| | """Utility functions for data processing."""
|
| |
|
import json
import logging
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
|
| |
|
| | logger = logging.getLogger(__name__)
|
| |
|
| |
|
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file into a list of records.

    Blank lines are skipped; lines that fail to parse are logged as
    warnings and dropped rather than aborting the whole load.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        One dict per successfully parsed line, in file order.
    """
    records: List[Dict[str, Any]] = []
    with open(file_path, 'r', encoding='utf-8') as fh:
        for raw in fh:
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                records.append(json.loads(stripped))
            except json.JSONDecodeError as err:
                logger.warning(f"Failed to parse line: {err}")
    return records
|
| |
|
| |
|
def save_jsonl(data: List[Dict[str, Any]], file_path: str):
    """Write records to a JSONL file, one JSON object per line.

    Parent directories are created if they do not already exist.

    Args:
        data: Records to serialize.
        file_path: Destination path.
    """
    target = Path(file_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )
|
| |
|
| |
|
def load_json(file_path: str) -> Any:
    """Parse a JSON file and return the deserialized value."""
    with open(file_path, 'r', encoding='utf-8') as fh:
        raw = fh.read()
    return json.loads(raw)
|
| |
|
| |
|
def save_json(data: Any, file_path: str):
    """Serialize ``data`` as pretty-printed JSON (2-space indent).

    Parent directories are created if they do not already exist.

    Args:
        data: Any JSON-serializable value.
        file_path: Destination path.
    """
    destination = Path(file_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    payload = json.dumps(data, ensure_ascii=False, indent=2)
    with open(file_path, 'w', encoding='utf-8') as out:
        out.write(payload)
|
| |
|
| |
|
def merge_datasets(
    datasets: List[Any],
    weights: Optional[List[float]] = None,
) -> List[Dict[str, Any]]:
    """Merge multiple datasets by randomly subsampling each one.

    Each dataset contributes ``int(len(ds) * w)`` samples, where ``w`` is its
    normalized weight, drawn uniformly without replacement.  NOTE: sampling
    uses the global NumPy RNG, so results are not reproducible unless the
    caller seeds it.

    Args:
        datasets: Indexable datasets to merge.
        weights: Optional per-dataset weights; defaults to uniform.  They are
            normalized to sum to 1 before sampling.

    Returns:
        List of sampled entries from all datasets.

    Raises:
        ValueError: If ``weights`` length does not match ``datasets``, or the
            weights do not sum to a positive value.
    """
    # Empty input previously crashed with ZeroDivisionError during weight
    # normalization; an empty merge is the sensible result instead.
    if not datasets:
        return []

    if weights is None:
        weights = [1.0] * len(datasets)
    if len(weights) != len(datasets):
        raise ValueError(
            f"Expected {len(datasets)} weights, got {len(weights)}"
        )

    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("Weights must sum to a positive value")
    weights = [w / total_weight for w in weights]

    merged = []
    for ds, weight in zip(datasets, weights):
        # A normalized weight is at most 1.0, so n_samples cannot exceed
        # len(ds); min() is kept only as a cheap guard against float rounding.
        n_samples = min(int(len(ds) * weight), len(ds))
        indices = np.random.choice(len(ds), size=n_samples, replace=False)
        merged.extend(ds[idx] for idx in indices)

    logger.info(f"Merged dataset size: {len(merged)}")
    return merged
|
| |
|
| |
|
def compute_dataset_statistics(
    dataset: Any,
    text_key: str = "text",
) -> Dict[str, Any]:
    """Compute comprehensive statistics for dataset.

    Args:
        dataset: Sized iterable of sample dicts.
        text_key: Key under which each sample stores its text.

    Returns:
        Dict with the sample count, whitespace-token length statistics
        (delegated to ``compute_length_statistics``), the relative domain
        distribution, and the fraction of samples with a truthy "thoughts"
        field.
    """
    # Guard: an empty dataset previously raised ZeroDivisionError when
    # computing the distribution/coverage ratios.
    if not dataset:
        return {
            "num_samples": 0,
            "length_stats": {},
            "domain_distribution": {},
            "thoughts_coverage": 0.0,
        }

    lengths = []
    domains = []
    has_thoughts = []

    for sample in dataset:
        text = sample.get(text_key, "")
        lengths.append(len(text.split()))
        domains.append(sample.get("domain", "unknown"))
        has_thoughts.append(1 if sample.get("thoughts") else 0)

    # Counter counts every domain in one pass instead of the original
    # O(unique * n) repeated list.count() calls.
    domain_counts = Counter(domains)
    return {
        "num_samples": len(dataset),
        "length_stats": compute_length_statistics(lengths),
        "domain_distribution": {
            d: c / len(domains) for d, c in domain_counts.items()
        },
        "thoughts_coverage": sum(has_thoughts) / len(has_thoughts),
    }
|
| |
|
| |
|
def validate_dataset(dataset: Any, required_keys: List[str]) -> List[str]:
    """Check that every sample defines all required keys.

    Args:
        dataset: Iterable of sample dicts.
        required_keys: Keys each sample must contain.

    Returns:
        Human-readable error strings, one per missing key, in dataset
        order; empty when the dataset is valid.
    """
    return [
        f"Sample {i} missing required key: {key}"
        for i, sample in enumerate(dataset)
        for key in required_keys
        if key not in sample
    ]
|
| |
|
| |
|
def deduplicate_dataset(dataset: List[Dict[str, Any]], key: str = "text") -> List[Dict[str, Any]]:
    """Drop samples whose value under ``key`` was already seen.

    The first occurrence of each value wins; later duplicates are
    discarded.  Samples missing ``key`` share the "" bucket and are
    therefore deduplicated against each other.

    Args:
        dataset: Samples to filter.
        key: Field used for duplicate detection.

    Returns:
        Samples with duplicates removed, original order preserved.
    """
    unique: List[Dict[str, Any]] = []
    observed = set()
    for record in dataset:
        marker = record.get(key, "")
        if marker in observed:
            continue
        observed.add(marker)
        unique.append(record)

    logger.info(f"Deduplicated: {len(dataset)} -> {len(unique)}")
    return unique
|
| |
|
| |
|
def balance_dataset(
    dataset: List[Dict[str, Any]],
    by_key: str = "domain",
    max_per_category: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Balance dataset across categories.

    Args:
        dataset: Samples to balance.
        by_key: Key whose value determines a sample's category; samples
            missing the key fall into "unknown".
        max_per_category: Cap on samples kept per category; defaults to the
            size of the smallest category.  Oversized categories are randomly
            downsampled without replacement (global NumPy RNG, so results are
            not reproducible unless the caller seeds it).

    Returns:
        The balanced sample list, grouped by category.
    """
    # Empty input previously crashed with an opaque ValueError from min()
    # over zero categories.
    if not dataset:
        return []

    categories: Dict[str, List[Dict[str, Any]]] = {}
    for sample in dataset:
        categories.setdefault(sample.get(by_key, "unknown"), []).append(sample)

    if max_per_category is None:
        max_per_category = min(len(group) for group in categories.values())

    balanced = []
    for samples in categories.values():
        if len(samples) > max_per_category:
            # Sample indices rather than the dicts themselves: np.random.choice
            # on a list of dicts needlessly converts every sample into a NumPy
            # object array just to pick a few.
            picked = np.random.choice(
                len(samples), size=max_per_category, replace=False
            )
            samples = [samples[i] for i in picked]
        balanced.extend(samples)

    logger.info(f"Balanced dataset: {len(dataset)} -> {len(balanced)}")
    return balanced
|
| |
|