File size: 10,600 Bytes
02a1aee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
"""Advanced Preprocessing for OpenThoughts and Custom Datasets"""

import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# Special token markers
THOUGHT_START = "<think>"
THOUGHT_END = "</think>"
USER_START = "<user>"
USER_END = "</user>"
ASSISTANT_START = "<assistant>"
ASSISTANT_END = "</assistant>"
SYSTEM_START = "<system>"
SYSTEM_END = "</system>"


def preprocess_conversation(
    conversations: Any,
    include_thoughts: bool = True,
    include_reasoning: bool = True,
) -> Dict[str, Any]:
    """Normalize raw conversation data into a training-ready record.

    Args:
        conversations: A list of ``{"role", "content"}`` dicts, a JSON
            string encoding such a list, or any other value (stringified).
        include_thoughts: Extract ``<think>...</think>`` spans out of the
            message content and report them under ``"thoughts"``.
        include_reasoning: Reserved flag; no reasoning is extracted yet,
            so the ``"reasoning"`` key is never emitted.

    Returns:
        Dict with ``"text"`` (role-tagged messages joined by newlines),
        ``"conversations"`` (per-message records), and optionally
        ``"thoughts"``.
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:
            # Not JSON: treat the raw string as the training text.
            return {"text": conversations, "conversations": []}

    if not isinstance(conversations, list):
        return {"text": str(conversations), "conversations": []}

    messages: List[Dict[str, str]] = []
    collected_thoughts: List[str] = []
    reasoning = ""  # placeholder: nothing populates this yet

    # Role -> (open tag, close tag); unknown roles pass content through.
    wrappers = {
        "user": (USER_START, USER_END),
        "assistant": (ASSISTANT_START, ASSISTANT_END),
        "system": (SYSTEM_START, SYSTEM_END),
    }

    for entry in conversations:
        if not isinstance(entry, dict):
            continue

        role = entry.get("role", "").lower()
        content = entry.get("content", "")
        if not content:
            continue

        # Pull out chain-of-thought spans and strip their tags.
        if include_thoughts and THOUGHT_START in content:
            found = re.findall(r'<think>(.*?)</think>', content, re.DOTALL)
            collected_thoughts.extend(t.strip() for t in found if t.strip())
            content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()

        tags = wrappers.get(role)
        formatted = f"{tags[0]} {content} {tags[1]}" if tags else content

        messages.append({
            "role": role,
            "content": content,
            "formatted": formatted,
        })

    result: Dict[str, Any] = {
        "text": "\n".join(m["formatted"] for m in messages),
        "conversations": messages,
    }

    if include_thoughts and collected_thoughts:
        result["thoughts"] = " ".join(collected_thoughts)

    if include_reasoning and reasoning:
        result["reasoning"] = reasoning

    return result


def extract_thoughts(text: str) -> str:
    """Return all ``<think>...</think>`` spans in *text*, stripped and space-joined."""
    matches = re.findall(r'<think>(.*?)</think>', text, re.DOTALL)
    stripped = (m.strip() for m in matches)
    return " ".join(s for s in stripped if s)


def format_for_training(
    sample: Dict[str, Any],
    include_thoughts: bool = True,
    include_reasoning: bool = True,
) -> str:
    """Render a sample dict as a single training string.

    Source text is picked by priority: ``"text"``, then a preprocessed
    ``"conversations"`` list, then ``"content"``, else empty. When
    requested and present, the sample's ``"thoughts"`` are appended on a
    new line inside thought tags.
    """
    if "text" in sample:
        base = sample["text"]
    elif "conversations" in sample:
        processed = preprocess_conversation(
            sample["conversations"], include_thoughts, include_reasoning
        )
        base = processed["text"]
    elif "content" in sample:
        base = sample["content"]
    else:
        base = ""

    # Append chain-of-thought when available and requested.
    thoughts = sample.get("thoughts") if include_thoughts else None
    if thoughts:
        base += f"\n{THOUGHT_START} {thoughts} {THOUGHT_END}"

    return base


def detect_domain(conversations: Any) -> str:
    """Detect the domain of a conversation via keyword matching.

    Args:
        conversations: A list of message dicts (or a JSON string encoding
            one). Non-list / unparseable input yields no text.

    Returns:
        One of "code", "mathematics", "science", "reasoning", "dialogue",
        or "unknown" when no keyword matches at all.
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:  # was a bare except: — narrow it
            conversations = []

    # Guard: json.loads can yield a scalar (e.g. "5" -> 5), which the
    # original code crashed iterating over.
    if not isinstance(conversations, list):
        conversations = []

    text = ""
    for msg in conversations:
        if isinstance(msg, dict):
            text += msg.get("content", "") + " "

    text = text.lower()

    # Domain keywords
    domain_keywords = {
        "code": ["def ", "class ", "import ", "function", "return", "if __name__", "```python", "```java", "```cpp"],
        "mathematics": ["equation", "theorem", "proof", "calculate", "solve", "integral", "derivative", "matrix", "vector"],
        "science": ["experiment", "hypothesis", "theory", "data", "analysis", "chemical", "physical", "biological"],
        "reasoning": ["because", "therefore", "thus", "hence", "since", "logic", "deduce", "infer"],
        "dialogue": ["how are you", "what do you think", "please help", "thank you", "could you"],
    }

    scores = {
        domain: sum(1 for kw in keywords if kw in text)
        for domain, keywords in domain_keywords.items()
    }

    best = max(scores, key=scores.get)
    # Bug fix: with zero matches, max() arbitrarily returned the first
    # key ("code"); report "unknown" instead.
    if scores[best] == 0:
        return "unknown"
    return best


def estimate_difficulty(conversations: Any, thoughts: str = "") -> float:
    """Estimate sample difficulty on a 0-1 scale from surface features.

    Args:
        conversations: A list of message dicts (or a JSON string encoding
            one); message contents are concatenated.
        thoughts: Optional chain-of-thought text appended to the analyzed
            text.

    Returns:
        Weighted, capped combination of length, CamelCase terms, code
        blocks, math symbols and reasoning markers, in [0, 1].
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:  # was a bare except: — narrow it
            conversations = []

    # Guard against non-list JSON values (scalar/dict) which would break
    # or mis-drive the iteration below.
    if not isinstance(conversations, list):
        conversations = []

    text = ""
    for msg in conversations:
        if isinstance(msg, dict):
            text += msg.get("content", "") + " "

    text += thoughts

    # Surface features used as difficulty proxies.
    features = {
        "length": len(text.split()),
        "technical_terms": len(re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b', text)),  # CamelCase
        "code_blocks": len(re.findall(r'```[\s\S]*?```', text)),
        # Bug fix: this call used square brackets (re.findall[...]), which
        # raised TypeError at runtime.
        "math_symbols": len(re.findall(r'[=+\-*/<>≤≥≠∈∉⊂⊆∪∩]', text)),
        "reasoning_markers": len(re.findall(r'\b(because|therefore|thus|hence|since)\b', text, re.IGNORECASE)),
    }

    # Normalize each feature to [0, 1] and combine with fixed weights.
    difficulty = (
        min(features["length"] / 500, 1.0) * 0.3 +
        min(features["technical_terms"] / 20, 1.0) * 0.25 +
        min(features["code_blocks"] / 3, 1.0) * 0.25 +
        min(features["math_symbols"] / 10, 1.0) * 0.1 +
        min(features["reasoning_markers"] / 5, 1.0) * 0.1
    )

    return min(difficulty, 1.0)


def clean_text(text: str) -> str:
    """Clean and normalize text.

    Collapses whitespace runs to a single space, removes ASCII control
    characters, converts curly quotes to ASCII equivalents, and strips
    surrounding whitespace.
    """
    # Remove excessive whitespace (tabs/newlines collapse to one space)
    text = re.sub(r'\s+', ' ', text)

    # Remove control characters
    text = re.sub(r'[\x00-\x1F\x7F]', '', text)

    # Normalize quotes. Bug fix: the original literals were mangled —
    # one replaced a character with itself and the other was a syntax
    # error (bare ''' ); use explicit Unicode escapes instead.
    text = text.replace('\u201c', '"').replace('\u201d', '"')
    text = text.replace('\u2018', "'").replace('\u2019', "'")

    # Strip
    return text.strip()


def truncate_with_overlap(
    text: str,
    max_length: int,
    stride: int,
    tokenizer: Any,
) -> List[Dict[str, Any]]:
    """Truncate long text into overlapping token windows.

    Args:
        text: Raw text to tokenize.
        max_length: Maximum tokens per window (must be positive).
        stride: Step between window starts (must be positive); a stride
            smaller than ``max_length`` produces overlapping windows.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.

    Returns:
        List of ``{"input_ids", "attention_mask"}`` chunks covering the
        token sequence.

    Raises:
        ValueError: If ``max_length`` or ``stride`` is not positive.
    """
    # Bug fix: a non-positive stride made the while-loop below spin
    # forever; validate both window parameters up front.
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if stride <= 0:
        raise ValueError("stride must be positive")

    tokens = tokenizer.encode(text, add_special_tokens=False)

    if len(tokens) <= max_length:
        return [{"input_ids": tokens, "attention_mask": [1] * len(tokens)}]

    chunks = []
    start = 0

    while start < len(tokens):
        end = min(start + max_length, len(tokens))
        chunk_tokens = tokens[start:end]

        chunks.append({
            "input_ids": chunk_tokens,
            "attention_mask": [1] * len(chunk_tokens),
        })

        # Stop once the window reached the end of the sequence.
        if end >= len(tokens):
            break

        start += stride

    return chunks


def compute_length_statistics(lengths: List[int]) -> Dict[str, float]:
    """Summarize a length distribution: mean/std/min/max plus percentiles."""
    import numpy as np

    if not lengths:
        return {}

    data = np.array(lengths)
    stats = {
        "mean": float(np.mean(data)),
        "std": float(np.std(data)),
        "min": float(np.min(data)),
        "max": float(np.max(data)),
    }
    for q in (50, 90, 95, 99):
        stats[f"p{q}"] = float(np.percentile(data, q))
    return stats


def analyze_dataset_quality(dataset: Any, sample_size: int = 1000) -> Dict[str, Any]:
    """Analyze quality metrics over (a sample of) a dataset.

    Args:
        dataset: An iterable of sample dicts; may be sized (map-style) or
            streaming. Samples may carry "domain", "difficulty", "text",
            "thoughts" and "conversations" keys; missing domain and
            difficulty are derived from the conversations.
        sample_size: Maximum number of samples to inspect.

    Returns:
        Dict with domain counts, difficulty mean/std/histogram, length
        statistics, chain-of-thought coverage ratio, and conversation
        turn statistics.
    """
    # Bug fix: this function used `np` without importing numpy anywhere,
    # raising NameError at runtime. Imported locally to mirror
    # compute_length_statistics' style.
    import numpy as np
    from collections import Counter

    logger.info("Analyzing dataset quality...")

    # Take at most sample_size samples. Bug fix: the original computed
    # indices for sized datasets but then iterated the WHOLE dataset;
    # now both flavors are truncated identically.
    if hasattr(dataset, "__len__"):
        sample_size = min(sample_size, len(dataset))
    collected = []
    for i, sample in enumerate(dataset):
        if i >= sample_size:
            break
        collected.append(sample)
    dataset = collected
    sample_size = len(collected)

    analysis: Dict[str, Any] = {
        "total_samples": sample_size,
        "domains": {},
        "difficulty_distribution": {},
        "length_stats": {},
        "thoughts_coverage": 0.0,
        "conversation_turns": [],
    }

    domains = []
    difficulties = []
    lengths = []
    thoughts_counts = []
    turns = []

    for sample in dataset:
        # Domain: trust an explicit label. Fix: dict.get evaluates its
        # default eagerly, so detection previously ran even for labeled
        # samples; use an explicit check to skip the wasted work.
        if "domain" in sample:
            domain = sample["domain"]
        else:
            domain = detect_domain(sample.get("conversations", []))
        domains.append(domain)

        # Difficulty: same lazy-default fix as above.
        if "difficulty" in sample:
            difficulty = sample["difficulty"]
        else:
            difficulty = estimate_difficulty(
                sample.get("conversations", []), sample.get("thoughts", "")
            )
        difficulties.append(difficulty)

        # Length (whitespace-token count of the rendered text)
        text = sample.get("text", "")
        if not text and "conversations" in sample:
            text = preprocess_conversation(sample["conversations"])["text"]
        lengths.append(len(text.split()))

        # Chain-of-thought presence (1/0 per sample)
        thoughts_counts.append(1 if sample.get("thoughts") else 0)

        # Conversation turns
        convs = sample.get("conversations")
        if isinstance(convs, list):
            turns.append(len(convs))

    # Aggregate statistics
    analysis["domains"] = dict(Counter(domains))
    analysis["difficulty_distribution"] = {
        "mean": float(np.mean(difficulties)) if difficulties else 0.0,
        "std": float(np.std(difficulties)) if difficulties else 0.0,
        "histogram": np.histogram(difficulties, bins=10, range=(0, 1))[0].tolist(),
    }
    analysis["length_stats"] = compute_length_statistics(lengths)
    analysis["thoughts_coverage"] = (
        sum(thoughts_counts) / len(thoughts_counts) if thoughts_counts else 0.0
    )
    analysis["conversation_turns"] = {
        "mean": float(np.mean(turns)) if turns else 0.0,
        "max": int(max(turns)) if turns else 0,
    }

    logger.info(f"Dataset analysis complete: {analysis}")
    return analysis