| """Advanced Preprocessing for OpenThoughts and Custom Datasets"""
|
|
|
| import json
|
| import logging
|
| import re
|
| from typing import Any, Dict, List, Optional, Tuple
|
|
|
| logger = logging.getLogger(__name__)
|
|
|
|
|
# Delimiter tokens used to mark message segments in the flattened training
# text produced by preprocess_conversation() / format_for_training().

# Wraps chain-of-thought spans.
THOUGHT_START = "<think>"
THOUGHT_END = "</think>"

# Wraps user-turn content.
USER_START = "<user>"
USER_END = "</user>"

# Wraps assistant-turn content.
ASSISTANT_START = "<assistant>"
ASSISTANT_END = "</assistant>"

# Wraps system-prompt content.
SYSTEM_START = "<system>"
SYSTEM_END = "</system>"
|
|
|
|
|
def preprocess_conversation(
    conversations: Any,
    include_thoughts: bool = True,
    include_reasoning: bool = True,
) -> Dict[str, Any]:
    """Preprocess conversation data into training format.

    Accepts either a JSON string or a list of ``{"role", "content"}``
    message dicts and returns a dict with the joined training text, the
    per-message records, and (optionally) extracted chain-of-thought.
    """
    # A string payload may be serialized JSON; fall back to raw text.
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:
            return {"text": conversations, "conversations": []}

    if not isinstance(conversations, list):
        return {"text": str(conversations), "conversations": []}

    processed_messages: List[Dict[str, str]] = []
    thoughts: List[str] = []
    reasoning = ""

    # Role -> (open tag, close tag) used to wrap each message.
    role_tags = {
        "user": (USER_START, USER_END),
        "assistant": (ASSISTANT_START, ASSISTANT_END),
        "system": (SYSTEM_START, SYSTEM_END),
    }

    for message in conversations:
        if not isinstance(message, dict):
            continue

        role = message.get("role", "").lower()
        content = message.get("content", "")
        if not content:
            continue

        # Pull <think>...</think> spans out of the content before formatting.
        if include_thoughts and THOUGHT_START in content:
            for span in re.findall(r'<think>(.*?)</think>', content, re.DOTALL):
                if span.strip():
                    thoughts.append(span.strip())
            content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()

        tags = role_tags.get(role)
        formatted = f"{tags[0]} {content} {tags[1]}" if tags else content

        processed_messages.append({
            "role": role,
            "content": content,
            "formatted": formatted,
        })

    result: Dict[str, Any] = {
        "text": "\n".join(m["formatted"] for m in processed_messages),
        "conversations": processed_messages,
    }

    if include_thoughts and thoughts:
        result["thoughts"] = " ".join(thoughts)

    # NOTE(review): `reasoning` is never populated in this function, so this
    # branch is currently inert; kept for interface stability.
    if include_reasoning and reasoning:
        result["reasoning"] = reasoning

    return result
|
|
|
|
|
def extract_thoughts(text: str) -> str:
    """Extract chain-of-thought spans from *text*.

    Returns the contents of every ``<think>...</think>`` block, stripped
    of surrounding whitespace and joined with single spaces.
    """
    spans = re.findall(r'<think>(.*?)</think>', text, re.DOTALL)
    stripped = [span.strip() for span in spans]
    return " ".join(span for span in stripped if span)
|
|
|
|
|
def format_for_training(
    sample: Dict[str, Any],
    include_thoughts: bool = True,
    include_reasoning: bool = True,
) -> str:
    """Format a dataset sample into a single training string.

    Prefers a pre-rendered "text" field, then raw "conversations", then a
    bare "content" field; appends stored chain-of-thought as a trailing
    <think> block when requested.
    """
    if "text" in sample:
        text = sample["text"]
    elif "conversations" in sample:
        processed = preprocess_conversation(
            sample["conversations"], include_thoughts, include_reasoning
        )
        text = processed["text"]
    elif "content" in sample:
        text = sample["content"]
    else:
        text = ""

    # Append any stored chain-of-thought as a trailing <think> block.
    if include_thoughts and sample.get("thoughts"):
        text += f"\n{THOUGHT_START} {sample['thoughts']} {THOUGHT_END}"

    return text
|
|
|
|
|
def detect_domain(conversations: Any) -> str:
    """Detect the domain of a conversation based on keyword matches.

    Args:
        conversations: A list of message dicts, or a JSON string encoding one.

    Returns:
        One of "code", "mathematics", "science", "reasoning", "dialogue",
        or "unknown" when there is no text or no keyword matches at all.
        Ties between nonzero scores go to the first domain in declaration
        order (same as the original ``max`` behavior).
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        # Narrowed from a bare `except:`; keep the best-effort fallback.
        except json.JSONDecodeError:
            conversations = []

    # Guard against non-list payloads (e.g. a dict or number after decoding).
    if not isinstance(conversations, list):
        conversations = []

    text = ""
    for msg in conversations:
        if isinstance(msg, dict):
            text += msg.get("content", "") + " "

    text = text.lower()

    domain_keywords = {
        "code": ["def ", "class ", "import ", "function", "return", "if __name__", "```python", "```java", "```cpp"],
        "mathematics": ["equation", "theorem", "proof", "calculate", "solve", "integral", "derivative", "matrix", "vector"],
        "science": ["experiment", "hypothesis", "theory", "data", "analysis", "chemical", "physical", "biological"],
        "reasoning": ["because", "therefore", "thus", "hence", "since", "logic", "deduce", "infer"],
        "dialogue": ["how are you", "what do you think", "please help", "thank you", "could you"],
    }

    scores = {
        domain: sum(1 for kw in keywords if kw in text)
        for domain, keywords in domain_keywords.items()
    }

    best = max(scores, key=scores.get)
    # BUG FIX: previously an all-zero score (empty or unmatched text)
    # arbitrarily returned the first domain ("code") and the
    # `if not scores: return "unknown"` branch was unreachable.
    return best if scores[best] > 0 else "unknown"
|
|
|
|
|
def estimate_difficulty(conversations: Any, thoughts: str = "") -> float:
    """Estimate sample difficulty on a 0-1 scale from surface features.

    Args:
        conversations: A list of message dicts, or a JSON string encoding one.
        thoughts: Optional chain-of-thought text appended to the scored text.

    Returns:
        A weighted score in [0, 1] combining length, CamelCase technical
        terms, fenced code blocks, math symbols, and reasoning markers.
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        # Narrowed from a bare `except:`; keep the best-effort fallback.
        except json.JSONDecodeError:
            conversations = []

    # Guard against non-list payloads so the loop below cannot blow up.
    if not isinstance(conversations, list):
        conversations = []

    text = ""
    for msg in conversations:
        if isinstance(msg, dict):
            text += msg.get("content", "") + " "

    text += thoughts

    features = {
        "length": len(text.split()),
        # CamelCase-style identifiers act as a proxy for technical terms.
        "technical_terms": len(re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b', text)),
        "code_blocks": len(re.findall(r'```[\s\S]*?```', text)),
        # BUG FIX: original used re.findall[...] (square-bracket indexing),
        # which raises TypeError; must be a call with parentheses.
        "math_symbols": len(re.findall(r'[=+\-*/<>≤≥≠∈∉⊂⊆∪∩]', text)),
        "reasoning_markers": len(re.findall(r'\b(because|therefore|thus|hence|since)\b', text, re.IGNORECASE)),
    }

    # Weighted blend; each feature saturates at its cap so no single
    # feature dominates.
    difficulty = (
        min(features["length"] / 500, 1.0) * 0.3 +
        min(features["technical_terms"] / 20, 1.0) * 0.25 +
        min(features["code_blocks"] / 3, 1.0) * 0.25 +
        min(features["math_symbols"] / 10, 1.0) * 0.1 +
        min(features["reasoning_markers"] / 5, 1.0) * 0.1
    )

    return min(difficulty, 1.0)
|
|
|
|
|
def clean_text(text: str) -> str:
    """Clean and normalize text.

    Collapses whitespace runs to single spaces, strips ASCII control
    characters, converts curly ("smart") quotes to ASCII quotes, and trims
    leading/trailing whitespace.
    """
    # Collapse all runs of whitespace to a single space.
    text = re.sub(r'\s+', ' ', text)

    # Drop ASCII control characters (C0 range and DEL).
    text = re.sub(r'[\x00-\x1F\x7F]', '', text)

    # BUG FIX: the original smart-quote replacements were mojibake'd — one
    # replaced '"' with itself (a no-op) and the other used bare curly
    # single quotes, which parse as garbled string literals. Use explicit
    # Unicode escapes so the intent survives any file re-encoding.
    text = text.replace('\u201c', '"').replace('\u201d', '"')  # “ ” -> "
    text = text.replace('\u2018', "'").replace('\u2019', "'")  # ‘ ’ -> '

    return text.strip()
|
|
|
|
|
def truncate_with_overlap(
    text: str,
    max_length: int,
    stride: int,
    tokenizer: Any,
) -> List[Dict[str, Any]]:
    """Truncate long text into overlapping token windows.

    Args:
        text: Raw text to tokenize.
        max_length: Maximum tokens per window; must be positive.
        stride: Step between consecutive window starts; must be positive.
            A stride smaller than max_length produces overlapping windows.
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)``.

    Returns:
        A list of ``{"input_ids", "attention_mask"}`` chunks.

    Raises:
        ValueError: If ``max_length`` or ``stride`` is not positive.
            (BUG FIX: a non-positive stride previously never advanced
            ``start``, causing an infinite loop.)
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    tokens = tokenizer.encode(text, add_special_tokens=False)

    # Short inputs fit in a single window.
    if len(tokens) <= max_length:
        return [{"input_ids": tokens, "attention_mask": [1] * len(tokens)}]

    chunks: List[Dict[str, Any]] = []
    start = 0
    while start < len(tokens):
        chunk_tokens = tokens[start:start + max_length]
        chunks.append({
            "input_ids": chunk_tokens,
            "attention_mask": [1] * len(chunk_tokens),
        })
        # Stop once a window has reached the end of the token stream.
        if start + max_length >= len(tokens):
            break
        start += stride

    return chunks
|
|
|
|
|
def compute_length_statistics(lengths: List[int]) -> Dict[str, float]:
    """Compute summary statistics for a length distribution.

    Returns an empty dict for empty input; otherwise mean, std, min, max,
    and the 50th/90th/95th/99th percentiles, all as Python floats.
    """
    import numpy as np

    if not lengths:
        return {}

    values = np.asarray(lengths)
    stats: Dict[str, float] = {
        "mean": float(values.mean()),
        "std": float(values.std()),
        "min": float(values.min()),
        "max": float(values.max()),
    }
    for pct in (50, 90, 95, 99):
        stats[f"p{pct}"] = float(np.percentile(values, pct))
    return stats
|
|
|
|
|
def analyze_dataset_quality(dataset: Any, sample_size: int = 1000) -> Dict[str, Any]:
    """Analyze dataset quality metrics over at most ``sample_size`` samples.

    Computes per-domain counts, a difficulty distribution (mean/std and a
    10-bin histogram over [0, 1]), word-length statistics, chain-of-thought
    coverage, and conversation-turn statistics.

    Args:
        dataset: A sized dataset or any iterable of sample dicts.
        sample_size: Maximum number of samples to inspect.

    Returns:
        A dict with keys: total_samples, domains, difficulty_distribution,
        length_stats, thoughts_coverage, conversation_turns.
    """
    # BUG FIX: `np` was referenced below but never imported at module level,
    # raising NameError; import locally (matching compute_length_statistics).
    import numpy as np
    from collections import Counter

    logger.info("Analyzing dataset quality...")

    if hasattr(dataset, "__len__"):
        sample_size = min(sample_size, len(dataset))

    # BUG FIX: the sized branch previously computed indices but still
    # iterated the ENTIRE dataset while reporting total_samples as
    # sample_size. Materialize at most sample_size items in both cases.
    samples = []
    for i, sample in enumerate(dataset):
        if i >= sample_size:
            break
        samples.append(sample)
    dataset = samples
    sample_size = len(samples)

    analysis: Dict[str, Any] = {
        "total_samples": sample_size,
        "domains": {},
        "difficulty_distribution": {},
        "length_stats": {},
        "thoughts_coverage": 0.0,
        "conversation_turns": [],
    }

    domains = []
    difficulties = []
    lengths = []
    thoughts_counts = []
    turns = []

    for sample in dataset:
        # Domain: trust a precomputed label, else detect from content.
        domain = sample.get("domain", detect_domain(sample.get("conversations", [])))
        domains.append(domain)

        # Difficulty: precomputed if present, else heuristic estimate.
        difficulty = sample.get(
            "difficulty",
            estimate_difficulty(sample.get("conversations", []), sample.get("thoughts", "")),
        )
        difficulties.append(difficulty)

        # Length in whitespace-separated tokens of the training text.
        text = sample.get("text", "")
        if not text and "conversations" in sample:
            text = preprocess_conversation(sample["conversations"])["text"]
        lengths.append(len(text.split()))

        # Chain-of-thought coverage indicator (1 if non-empty thoughts).
        thoughts_counts.append(1 if sample.get("thoughts") else 0)

        if "conversations" in sample and isinstance(sample["conversations"], list):
            turns.append(len(sample["conversations"]))

    analysis["domains"] = dict(Counter(domains))
    analysis["difficulty_distribution"] = {
        "mean": float(np.mean(difficulties)) if difficulties else 0.0,
        "std": float(np.std(difficulties)) if difficulties else 0.0,
        "histogram": np.histogram(difficulties, bins=10, range=(0, 1))[0].tolist(),
    }
    analysis["length_stats"] = compute_length_statistics(lengths)
    analysis["thoughts_coverage"] = (
        sum(thoughts_counts) / len(thoughts_counts) if thoughts_counts else 0.0
    )
    analysis["conversation_turns"] = {
        "mean": float(np.mean(turns)) if turns else 0.0,
        "max": int(max(turns)) if turns else 0,
    }

    logger.info(f"Dataset analysis complete: {analysis}")
    return analysis
|
|
|