"""Utility functions for data processing.""" import json import logging from pathlib import Path from typing import Any, Dict, List, Optional import numpy as np logger = logging.getLogger(__name__) def load_jsonl(file_path: str) -> List[Dict[str, Any]]: """Load JSONL file.""" data = [] with open(file_path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if line: try: data.append(json.loads(line)) except json.JSONDecodeError as e: logger.warning(f"Failed to parse line: {e}") return data def save_jsonl(data: List[Dict[str, Any]], file_path: str): """Save to JSONL file.""" Path(file_path).parent.mkdir(parents=True, exist_ok=True) with open(file_path, 'w', encoding='utf-8') as f: for item in data: f.write(json.dumps(item, ensure_ascii=False) + '\n') def load_json(file_path: str) -> Any: """Load JSON file.""" with open(file_path, 'r', encoding='utf-8') as f: return json.load(f) def save_json(data: Any, file_path: str): """Save to JSON file.""" Path(file_path).parent.mkdir(parents=True, exist_ok=True) with open(file_path, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) def merge_datasets( datasets: List[Any], weights: Optional[List[float]] = None, ) -> List[Dict[str, Any]]: """Merge multiple datasets with optional weighting.""" if weights is None: weights = [1.0] * len(datasets) # Normalize weights total_weight = sum(weights) weights = [w / total_weight for w in weights] merged = [] for ds, weight in zip(datasets, weights): # Calculate number of samples to take n_samples = int(len(ds) * weight) if n_samples > len(ds): n_samples = len(ds) # Sample indices = np.random.choice(len(ds), size=n_samples, replace=False) for idx in indices: merged.append(ds[idx]) logger.info(f"Merged dataset size: {len(merged)}") return merged def compute_dataset_statistics( dataset: Any, text_key: str = "text", ) -> Dict[str, Any]: """Compute comprehensive statistics for dataset.""" lengths = [] domains = [] has_thoughts = [] for sample in dataset: text = sample.get(text_key, "") lengths.append(len(text.split())) domain = sample.get("domain", "unknown") domains.append(domain) has_thoughts.append(1 if sample.get("thoughts") else 0) return { "num_samples": len(dataset), "length_stats": compute_length_statistics(lengths), "domain_distribution": {d: domains.count(d) / len(domains) for d in set(domains)}, "thoughts_coverage": sum(has_thoughts) / len(has_thoughts), } def validate_dataset(dataset: Any, required_keys: List[str]) -> List[str]: """Validate dataset structure.""" errors = [] for i, sample in enumerate(dataset): for key in required_keys: if key not in sample: errors.append(f"Sample {i} missing required key: {key}") return errors def deduplicate_dataset(dataset: List[Dict[str, Any]], key: str = "text") -> List[Dict[str, Any]]: """Remove duplicate samples based on key.""" seen = set() deduplicated = [] for sample in dataset: value = sample.get(key, "") if value not in seen: seen.add(value) deduplicated.append(sample) logger.info(f"Deduplicated: {len(dataset)} -> {len(deduplicated)}") return deduplicated def balance_dataset( dataset: List[Dict[str, Any]], by_key: str = "domain", max_per_category: Optional[int] = None, ) -> List[Dict[str, Any]]: """Balance dataset across categories.""" categories = {} for sample in dataset: category = sample.get(by_key, "unknown") if category not in categories: categories[category] = [] categories[category].append(sample) # Determine max samples per category if max_per_category is None: max_per_category = min(len(cat) for cat in categories.values()) # Sample equally from each category balanced = [] for category, samples in categories.items(): if len(samples) > max_per_category: samples = np.random.choice(samples, size=max_per_category, replace=False).tolist() balanced.extend(samples) logger.info(f"Balanced dataset: {len(dataset)} -> {len(balanced)}") return balanced