# NOTE: The three lines below are Hugging Face upload-page residue captured by
# the scrape (uploader avatar caption, commit message, commit hash). They are
# not Python and are commented out so the module parses.
# Zandy-Wandy's picture
# Upload Zenith-28b-V1-Tenstorrent-Blackhole-p300 model
# 8944ef7 verified
"""Utility functions for data processing."""
import json
import logging
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
logger = logging.getLogger(__name__)
def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file, skipping blank and malformed lines.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        Parsed records, one per valid line. Lines that fail to parse
        are logged as warnings and dropped rather than raising.
    """
    records: List[Dict[str, Any]] = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            raw = raw.strip()
            if not raw:
                continue
            try:
                records.append(json.loads(raw))
            except json.JSONDecodeError as err:
                logger.warning(f"Failed to parse line: {err}")
    return records
def save_jsonl(data: List[Dict[str, Any]], file_path: str):
    """Write records to a JSONL file, one JSON object per line.

    Parent directories are created if they do not exist.

    Args:
        data: Records to serialize.
        file_path: Destination path.
    """
    target = Path(file_path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in data
        )
def load_json(file_path: str) -> Any:
    """Parse and return the contents of a JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The deserialized JSON value.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
def save_json(data: Any, file_path: str):
    """Serialize ``data`` to a pretty-printed JSON file.

    Parent directories are created if they do not exist.

    Args:
        data: JSON-serializable value to write.
        file_path: Destination path.
    """
    destination = Path(file_path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, 'w', encoding='utf-8') as handle:
        json.dump(data, handle, ensure_ascii=False, indent=2)
def merge_datasets(
    datasets: List[Any],
    weights: Optional[List[float]] = None,
) -> List[Dict[str, Any]]:
    """Merge multiple datasets by weighted subsampling.

    Weights are normalized to sum to 1; each dataset then contributes
    ``int(len(ds) * normalized_weight)`` samples, drawn uniformly
    without replacement. With equal weights across k datasets, each
    dataset is subsampled to roughly 1/k of its size.

    Args:
        datasets: Sequences of samples to merge.
        weights: Optional relative weights, one per dataset. Defaults
            to equal weighting.

    Returns:
        The merged list of sampled records.

    Raises:
        ValueError: If ``weights`` does not match ``datasets`` in length,
            or the weights do not sum to a positive value.
    """
    # Empty input previously raised ZeroDivisionError during normalization.
    if not datasets:
        return []
    if weights is None:
        weights = [1.0] * len(datasets)
    if len(weights) != len(datasets):
        raise ValueError("weights must have one entry per dataset")
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")
    weights = [w / total_weight for w in weights]
    merged: List[Dict[str, Any]] = []
    for ds, weight in zip(datasets, weights):
        # Normalized weight is <= 1, so n_samples can never exceed len(ds);
        # min() keeps that invariant explicit.
        n_samples = min(int(len(ds) * weight), len(ds))
        indices = np.random.choice(len(ds), size=n_samples, replace=False)
        for idx in indices:
            merged.append(ds[idx])
    logger.info(f"Merged dataset size: {len(merged)}")
    return merged
def compute_dataset_statistics(
    dataset: Any,
    text_key: str = "text",
) -> Dict[str, Any]:
    """Compute summary statistics for a dataset of dict samples.

    Args:
        dataset: Iterable of samples (dict-like, accessed via ``.get``).
        text_key: Key holding the sample text; length is measured in
            whitespace-separated tokens.

    Returns:
        Dict with ``num_samples``, ``length_stats`` (delegated to
        ``compute_length_statistics``), ``domain_distribution``
        (fraction of samples per domain), and ``thoughts_coverage``
        (fraction of samples with a truthy ``"thoughts"`` field).
    """
    lengths: List[int] = []
    domains: List[str] = []
    has_thoughts: List[int] = []
    for sample in dataset:
        text = sample.get(text_key, "")
        lengths.append(len(text.split()))
        domains.append(sample.get("domain", "unknown"))
        has_thoughts.append(1 if sample.get("thoughts") else 0)
    num_samples = len(lengths)
    if num_samples == 0:
        # Previously an empty dataset raised ZeroDivisionError computing
        # thoughts_coverage; return a well-formed empty summary instead.
        return {
            "num_samples": 0,
            "length_stats": {},
            "domain_distribution": {},
            "thoughts_coverage": 0.0,
        }
    return {
        "num_samples": num_samples,
        "length_stats": compute_length_statistics(lengths),
        "domain_distribution": {d: domains.count(d) / num_samples for d in set(domains)},
        "thoughts_coverage": sum(has_thoughts) / num_samples,
    }
def validate_dataset(dataset: Any, required_keys: List[str]) -> List[str]:
    """Check that every sample contains all required keys.

    Args:
        dataset: Iterable of dict-like samples.
        required_keys: Keys that must be present in each sample.

    Returns:
        One error string per missing key; an empty list means the
        dataset is structurally valid.
    """
    return [
        f"Sample {i} missing required key: {key}"
        for i, sample in enumerate(dataset)
        for key in required_keys
        if key not in sample
    ]
def deduplicate_dataset(dataset: List[Dict[str, Any]], key: str = "text") -> List[Dict[str, Any]]:
    """Remove duplicates, keeping the first sample for each ``key`` value.

    Args:
        dataset: Samples to deduplicate.
        key: Field whose value identifies duplicates (missing values
            are treated as the empty string, so they dedupe together).

    Returns:
        Samples in original order with later duplicates dropped.
    """
    seen_values = set()
    unique_samples: List[Dict[str, Any]] = []
    for sample in dataset:
        marker = sample.get(key, "")
        if marker in seen_values:
            continue
        seen_values.add(marker)
        unique_samples.append(sample)
    logger.info(f"Deduplicated: {len(dataset)} -> {len(unique_samples)}")
    return unique_samples
def balance_dataset(
    dataset: List[Dict[str, Any]],
    by_key: str = "domain",
    max_per_category: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Balance a dataset so no category exceeds a per-category cap.

    Args:
        dataset: Samples to balance.
        by_key: Field used to group samples into categories; samples
            without it fall into ``"unknown"``.
        max_per_category: Cap per category. Defaults to the size of the
            smallest category (i.e. fully balanced output).

    Returns:
        Samples with each over-represented category randomly
        subsampled down to the cap.
    """
    # Empty input previously raised ValueError from min() over no categories.
    if not dataset:
        logger.info(f"Balanced dataset: 0 -> 0")
        return []
    categories: Dict[Any, List[Dict[str, Any]]] = {}
    for sample in dataset:
        categories.setdefault(sample.get(by_key, "unknown"), []).append(sample)
    if max_per_category is None:
        max_per_category = min(len(cat) for cat in categories.values())
    balanced: List[Dict[str, Any]] = []
    for samples in categories.values():
        if len(samples) > max_per_category:
            # Sample indices instead of objects: np.random.choice over a
            # list of dicts would first coerce it to an object ndarray,
            # which is fragile for dict-like/sequence-like samples.
            picked = np.random.choice(len(samples), size=max_per_category, replace=False)
            samples = [samples[i] for i in picked]
        balanced.extend(samples)
    logger.info(f"Balanced dataset: {len(dataset)} -> {len(balanced)}")
    return balanced