"""Utility functions for data processing."""
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
logger = logging.getLogger(__name__)
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file and return one parsed object per non-blank line.

    Lines that fail to parse as JSON are logged as warnings and skipped;
    blank lines are ignored.
    """
    records: List[Dict[str, Any]] = []
    with open(file_path, 'r', encoding='utf-8') as fh:
        for raw in fh:
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                records.append(json.loads(stripped))
            except json.JSONDecodeError as err:
                logging.getLogger(__name__).warning(f"Failed to parse line: {err}")
    return records
def save_jsonl(data: List[Dict[str, Any]], file_path: str):
    """Write records to *file_path* in JSONL form, one object per line.

    Parent directories are created if missing; non-ASCII text is written
    verbatim (``ensure_ascii=False``).
    """
    target = Path(file_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as fh:
        fh.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in data)
def load_json(file_path: str) -> Any:
    """Parse a UTF-8 JSON file and return the decoded value."""
    with open(file_path, 'r', encoding='utf-8') as fh:
        parsed = json.load(fh)
    return parsed
def save_json(data: Any, file_path: str):
    """Serialize *data* as pretty-printed UTF-8 JSON at *file_path*.

    Parent directories are created if missing; output keeps non-ASCII
    characters (``ensure_ascii=False``) and is indented for readability.
    """
    destination = Path(file_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as fh:
        json.dump(data, fh, ensure_ascii=False, indent=2)
def merge_datasets(
    datasets: List[Any],
    weights: Optional[List[float]] = None,
) -> List[Dict[str, Any]]:
    """Merge multiple datasets by weighted random subsampling.

    Weights are normalized to sum to 1; each dataset contributes
    ``int(len(ds) * weight)`` samples drawn uniformly without replacement
    (ordering within the merged result is therefore random).

    Args:
        datasets: Sequences of samples to merge.
        weights: Optional relative weights, one per dataset. Defaults to
            equal weighting.

    Returns:
        The merged list of sampled items.

    Raises:
        ValueError: If the provided weights sum to zero or less
            (normalization would otherwise divide by zero).
    """
    if not datasets:
        # Avoid ZeroDivisionError from sum([]) below; empty in, empty out.
        return []
    if weights is None:
        weights = [1.0] * len(datasets)
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")
    weights = [w / total_weight for w in weights]
    merged: List[Dict[str, Any]] = []
    for ds, weight in zip(datasets, weights):
        # Normalized weight is <= 1, so this never exceeds len(ds); min()
        # guards against float rounding at the boundary.
        n_samples = min(int(len(ds) * weight), len(ds))
        if n_samples <= 0:
            # Skip instead of calling np.random.choice(0, ...), which
            # raises on older NumPy versions.
            continue
        indices = np.random.choice(len(ds), size=n_samples, replace=False)
        for idx in indices:
            merged.append(ds[idx])
    logging.getLogger(__name__).info(f"Merged dataset size: {len(merged)}")
    return merged
def _compute_length_statistics(lengths: List[int]) -> Dict[str, float]:
    """Summarize whitespace-token counts: mean/std/min/max/median (zeros if empty)."""
    if not lengths:
        return {"mean": 0.0, "std": 0.0, "min": 0, "max": 0, "median": 0.0}
    arr = np.asarray(lengths)
    return {
        "mean": float(arr.mean()),
        "std": float(arr.std()),
        "min": int(arr.min()),
        "max": int(arr.max()),
        "median": float(np.median(arr)),
    }


def compute_dataset_statistics(
    dataset: Any,
    text_key: str = "text",
) -> Dict[str, Any]:
    """Compute comprehensive statistics for dataset.

    Args:
        dataset: Sized iterable of dict-like samples. Text length is the
            whitespace-token count of ``sample[text_key]``; samples may
            also carry "domain" and "thoughts" keys.
        text_key: Key under which the sample text is stored.

    Returns:
        Dict with "num_samples", "length_stats" (see
        ``_compute_length_statistics``), "domain_distribution" (fraction
        per domain, "unknown" when the key is absent), and
        "thoughts_coverage" (fraction of samples with a truthy "thoughts").
        An empty dataset yields an empty distribution and 0.0 coverage
        instead of dividing by zero.
    """
    lengths = []
    domains = []
    has_thoughts = []
    for sample in dataset:
        text = sample.get(text_key, "")
        lengths.append(len(text.split()))
        domains.append(sample.get("domain", "unknown"))
        has_thoughts.append(1 if sample.get("thoughts") else 0)
    n = len(domains)
    return {
        "num_samples": len(dataset),
        "length_stats": _compute_length_statistics(lengths),
        "domain_distribution": (
            {d: domains.count(d) / n for d in set(domains)} if n else {}
        ),
        "thoughts_coverage": sum(has_thoughts) / n if n else 0.0,
    }
def validate_dataset(dataset: Any, required_keys: List[str]) -> List[str]:
    """Check each sample for the required keys.

    Returns one error message per missing (sample, key) pair, in dataset
    order; an empty list means the dataset is structurally valid.
    """
    return [
        f"Sample {i} missing required key: {key}"
        for i, sample in enumerate(dataset)
        for key in required_keys
        if key not in sample
    ]
def deduplicate_dataset(dataset: List[Dict[str, Any]], key: str = "text") -> List[Dict[str, Any]]:
    """Keep only the first occurrence of each distinct value under *key*.

    Samples missing *key* are treated as having the value "" and so also
    collapse to a single survivor. Input order is preserved.
    """
    observed = set()
    deduplicated: List[Dict[str, Any]] = []
    for record in dataset:
        marker = record.get(key, "")
        if marker in observed:
            continue
        observed.add(marker)
        deduplicated.append(record)
    logging.getLogger(__name__).info(f"Deduplicated: {len(dataset)} -> {len(deduplicated)}")
    return deduplicated
def balance_dataset(
    dataset: List[Dict[str, Any]],
    by_key: str = "domain",
    max_per_category: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Balance the dataset across categories by random downsampling.

    Args:
        dataset: Samples to balance.
        by_key: Sample key used for grouping; samples without it fall
            into the "unknown" category.
        max_per_category: Cap on samples kept per category. Defaults to
            the size of the smallest category, giving an equal count
            from each.

    Returns:
        A new list with at most ``max_per_category`` samples per
        category; categories already at or under the cap are kept intact
        and in order. An empty input returns an empty list.
    """
    if not dataset:
        # Otherwise min() over zero categories raises ValueError.
        return []
    categories: Dict[Any, List[Dict[str, Any]]] = {}
    for sample in dataset:
        categories.setdefault(sample.get(by_key, "unknown"), []).append(sample)
    if max_per_category is None:
        max_per_category = min(len(members) for members in categories.values())
    balanced: List[Dict[str, Any]] = []
    for members in categories.values():
        if len(members) > max_per_category:
            # Sample indices rather than the dicts themselves:
            # np.random.choice over a list of dicts builds an object
            # array, which is slower and needlessly fragile.
            chosen = np.random.choice(len(members), size=max_per_category, replace=False)
            members = [members[i] for i in chosen]
        balanced.extend(members)
    logging.getLogger(__name__).info(f"Balanced dataset: {len(dataset)} -> {len(balanced)}")
    return balanced
|