File size: 4,736 Bytes
8d18b7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Utility functions for data processing."""

import json
import logging
from collections import Counter, defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np

logger = logging.getLogger(__name__)


def load_jsonl(file_path: str) -> List[Dict[str, Any]]:
    """Read a JSONL file, skipping blank and unparseable lines.

    Args:
        file_path: Path to the JSONL file.

    Returns:
        A list of the successfully parsed records; malformed lines are
        logged as warnings and dropped.
    """
    records: List[Dict[str, Any]] = []
    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                records.append(json.loads(stripped))
            except json.JSONDecodeError as err:
                logger.warning(f"Failed to parse line: {err}")
    return records


def save_jsonl(data: List[Dict[str, Any]], file_path: str):
    """Write records to a JSONL file, one JSON object per line.

    Parent directories are created if they do not exist.

    Args:
        data: Records to serialize.
        file_path: Destination path.
    """
    Path(file_path).parent.mkdir(parents=True, exist_ok=True)
    lines = (json.dumps(record, ensure_ascii=False) + '\n' for record in data)
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.writelines(lines)


def load_json(file_path: str) -> Any:
    """Deserialize and return the contents of a JSON file.

    Args:
        file_path: Path to the JSON file.

    Returns:
        The parsed JSON value.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        content = handle.read()
    return json.loads(content)


def save_json(data: Any, file_path: str):
    """Serialize *data* to a pretty-printed JSON file.

    Parent directories are created if they do not exist.

    Args:
        data: JSON-serializable value.
        file_path: Destination path.
    """
    Path(file_path).parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(data, ensure_ascii=False, indent=2)
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.write(serialized)


def merge_datasets(
    datasets: List[Any],
    weights: Optional[List[float]] = None,
) -> List[Dict[str, Any]]:
    """Merge multiple datasets with optional weighting.

    Each dataset contributes ``int(len(ds) * normalized_weight)`` samples,
    drawn uniformly at random without replacement.

    Args:
        datasets: Indexable datasets to merge.
        weights: Optional relative weights, one per dataset. Defaults to
            uniform weighting.

    Returns:
        The merged list of sampled records.

    Raises:
        ValueError: If ``weights`` has a different length than ``datasets``
            (the old code silently zip-truncated), or if the weights do not
            sum to a positive value (the old code raised ZeroDivisionError).
    """
    # Empty input previously hit ZeroDivisionError in the normalization.
    if not datasets:
        return []

    if weights is None:
        weights = [1.0] * len(datasets)
    elif len(weights) != len(datasets):
        raise ValueError("weights must have the same length as datasets")

    # Normalize weights
    total_weight = sum(weights)
    if total_weight <= 0:
        raise ValueError("weights must sum to a positive value")
    weights = [w / total_weight for w in weights]

    merged = []
    for ds, weight in zip(datasets, weights):
        # After normalization each weight is <= 1, so n_samples <= len(ds)
        # already holds; min() keeps the invariant explicit.
        n_samples = min(int(len(ds) * weight), len(ds))
        if n_samples == 0:
            # Also skips empty datasets: np.random.choice rejects a=0.
            continue
        indices = np.random.choice(len(ds), size=n_samples, replace=False)
        merged.extend(ds[idx] for idx in indices)

    logger.info(f"Merged dataset size: {len(merged)}")
    return merged


def compute_dataset_statistics(
    dataset: Any,
    text_key: str = "text",
) -> Dict[str, Any]:
    """Compute comprehensive statistics for dataset.

    Args:
        dataset: Sized iterable of sample dicts.
        text_key: Key under which each sample stores its text.

    Returns:
        Dict with the sample count, whitespace-token length statistics,
        relative domain distribution, and fraction of samples that have a
        truthy ``"thoughts"`` field.
    """
    # Empty input previously hit ZeroDivisionError in the ratios below.
    if not dataset:
        return {
            "num_samples": 0,
            "length_stats": {},
            "domain_distribution": {},
            "thoughts_coverage": 0.0,
        }

    lengths = []
    domains = []
    has_thoughts = []

    for sample in dataset:
        text = sample.get(text_key, "")
        lengths.append(len(text.split()))
        domains.append(sample.get("domain", "unknown"))
        has_thoughts.append(1 if sample.get("thoughts") else 0)

    # Counter gives O(n) frequencies; domains.count(d) per domain was O(n^2).
    domain_counts = Counter(domains)
    return {
        "num_samples": len(dataset),
        "length_stats": compute_length_statistics(lengths),
        "domain_distribution": {d: c / len(domains) for d, c in domain_counts.items()},
        "thoughts_coverage": sum(has_thoughts) / len(has_thoughts),
    }


def validate_dataset(dataset: Any, required_keys: List[str]) -> List[str]:
    """Validate dataset structure.

    Args:
        dataset: Iterable of sample dicts.
        required_keys: Keys every sample must contain.

    Returns:
        One error message per (sample, missing key) pair; empty when valid.
    """
    return [
        f"Sample {index} missing required key: {key}"
        for index, sample in enumerate(dataset)
        for key in required_keys
        if key not in sample
    ]


def deduplicate_dataset(dataset: List[Dict[str, Any]], key: str = "text") -> List[Dict[str, Any]]:
    """Remove duplicate samples based on key.

    The first occurrence of each value wins; input order is preserved.

    Args:
        dataset: Samples to deduplicate.
        key: Key whose value identifies a duplicate (missing key reads as "").

    Returns:
        The deduplicated samples.
    """
    observed = set()
    unique: List[Dict[str, Any]] = []

    for record in dataset:
        marker = record.get(key, "")
        if marker in observed:
            continue
        observed.add(marker)
        unique.append(record)

    logger.info(f"Deduplicated: {len(dataset)} -> {len(unique)}")
    return unique


def balance_dataset(
    dataset: List[Dict[str, Any]],
    by_key: str = "domain",
    max_per_category: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """Balance dataset across categories.

    Args:
        dataset: Samples to balance.
        by_key: Key whose value defines a sample's category (missing key
            falls back to "unknown").
        max_per_category: Cap on samples kept per category. Defaults to the
            size of the smallest category.

    Returns:
        Samples with at most ``max_per_category`` per category; oversized
        categories are downsampled uniformly without replacement.
    """
    # Empty input previously raised ValueError from min() over no categories.
    if not dataset:
        return []

    categories: Dict[Any, List[Dict[str, Any]]] = defaultdict(list)
    for sample in dataset:
        categories[sample.get(by_key, "unknown")].append(sample)

    # Determine max samples per category
    if max_per_category is None:
        max_per_category = min(len(cat) for cat in categories.values())

    # Sample equally from each category
    balanced = []
    for samples in categories.values():
        if len(samples) > max_per_category:
            # Sample indices rather than the dicts themselves: passing a list
            # of dicts to np.random.choice forces an object-array conversion.
            chosen = np.random.choice(len(samples), size=max_per_category, replace=False)
            samples = [samples[i] for i in chosen]
        balanced.extend(samples)

    logger.info(f"Balanced dataset: {len(dataset)} -> {len(balanced)}")
    return balanced