"""Advanced Preprocessing for OpenThoughts and Custom Datasets"""

import itertools
import json
import logging
import re
from collections import Counter
from typing import Any, Dict, List, Optional, Tuple

logger = logging.getLogger(__name__)

# Special token markers
# NOTE(review): every marker below is an empty string in this source. Thought
# extraction and role wrapping only make sense with non-empty markers, so the
# intended tag text appears to have been lost -- confirm the real values.
THOUGHT_START = ""
THOUGHT_END = ""
USER_START = ""
USER_END = ""
ASSISTANT_START = ""
ASSISTANT_END = ""
SYSTEM_START = ""
SYSTEM_END = ""

# Captures the text between THOUGHT_START and THOUGHT_END (non-greedy, across
# newlines). Built from the constants so the markers are defined in one place.
# None when either marker is empty: an empty-delimiter pattern matches only
# empty strings, which would make extraction a silent no-op.
_THOUGHT_RE = (
    re.compile(re.escape(THOUGHT_START) + r"(.*?)" + re.escape(THOUGHT_END), re.DOTALL)
    if THOUGHT_START and THOUGHT_END
    else None
)

# (start, end) markers used to wrap each known role; other roles are unwrapped.
_ROLE_MARKERS: Dict[str, Tuple[str, str]] = {
    "user": (USER_START, USER_END),
    "assistant": (ASSISTANT_START, ASSISTANT_END),
    "system": (SYSTEM_START, SYSTEM_END),
}

# Substring keywords per domain; a domain's score is how many of them occur.
# Dict order matters: ties are resolved in favour of the earlier domain.
_DOMAIN_KEYWORDS: Dict[str, List[str]] = {
    "code": ["def ", "class ", "import ", "function", "return", "if __name__", "```python", "```java", "```cpp"],
    "mathematics": ["equation", "theorem", "proof", "calculate", "solve", "integral", "derivative", "matrix", "vector"],
    "science": ["experiment", "hypothesis", "theory", "data", "analysis", "chemical", "physical", "biological"],
    "reasoning": ["because", "therefore", "thus", "hence", "since", "logic", "deduce", "infer"],
    "dialogue": ["how are you", "what do you think", "please help", "thank you", "could you"],
}

# Curly (typographic) quotes mapped to their ASCII equivalents.
_QUOTE_TABLE = str.maketrans({
    "\u201c": '"',  # left double quotation mark
    "\u201d": '"',  # right double quotation mark
    "\u2018": "'",  # left single quotation mark
    "\u2019": "'",  # right single quotation mark
})


def _split_thoughts(content: str) -> Tuple[str, List[str]]:
    """Return *content* with thought spans removed, plus the non-blank thoughts found."""
    if _THOUGHT_RE is None or THOUGHT_START not in content:
        return content, []
    found = [t.strip() for t in _THOUGHT_RE.findall(content) if t.strip()]
    return _THOUGHT_RE.sub("", content).strip(), found


def _conversation_text(conversations: Any) -> str:
    """Flatten message contents into one string, each followed by a space.

    A JSON string is decoded first. A string that is not valid JSON is
    treated as raw text, matching preprocess_conversation. Any other non-list
    value yields "". Messages that are not dicts are skipped, and a missing
    or None ``content`` counts as "".
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:
            return conversations
    if not isinstance(conversations, (list, tuple)):
        return ""
    return "".join(
        f"{msg.get('content') or ''} " for msg in conversations if isinstance(msg, dict)
    )


def preprocess_conversation(
    conversations: Any,
    include_thoughts: bool = True,
    include_reasoning: bool = True,
) -> Dict[str, Any]:
    """Preprocess conversation data into training format.

    Args:
        conversations: A list of ``{"role": ..., "content": ...}`` dicts, or a
            JSON string encoding one. A string that is not valid JSON is
            returned as raw text; any other non-list value is stringified.
        include_thoughts: If True, spans delimited by THOUGHT_START and
            THOUGHT_END are removed from each message and collected under
            ``"thoughts"``.
        include_reasoning: Accepted for API compatibility; nothing here
            currently produces a ``"reasoning"`` entry.

    Returns:
        A dict with:
            ``"text"``: the formatted messages joined by newlines.
            ``"conversations"``: one dict per kept message, with ``role``,
            ``content`` and ``formatted``.
            ``"thoughts"``: the space-joined thoughts, present only when any
            were extracted.

        Non-dict messages and messages with empty content are skipped.
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:
            return {"text": conversations, "conversations": []}
    if not isinstance(conversations, list):
        return {"text": str(conversations), "conversations": []}

    processed_messages = []
    thoughts: List[str] = []
    for msg in conversations:
        if not isinstance(msg, dict):
            continue
        # `or ""` also covers an explicit None role, which would break .lower().
        role = str(msg.get("role") or "").lower()
        content = msg.get("content", "")
        if not content:
            continue

        if include_thoughts and isinstance(content, str):
            content, found = _split_thoughts(content)
            thoughts.extend(found)

        markers = _ROLE_MARKERS.get(role)
        formatted = f"{markers[0]} {content} {markers[1]}" if markers else content
        processed_messages.append({
            "role": role,
            "content": content,
            "formatted": formatted,
        })

    result: Dict[str, Any] = {
        "text": "\n".join(m["formatted"] for m in processed_messages),
        "conversations": processed_messages,
    }
    if include_thoughts and thoughts:
        result["thoughts"] = " ".join(thoughts)
    return result


def extract_thoughts(text: str) -> str:
    """Return every thought span in *text*, stripped and space-joined ("" if none)."""
    if _THOUGHT_RE is None:
        return ""
    return " ".join(t.strip() for t in _THOUGHT_RE.findall(text) if t.strip())


def format_for_training(
    sample: Dict[str, Any],
    include_thoughts: bool = True,
    include_reasoning: bool = True,
) -> str:
    """Format a dataset sample as a single training string.

    The text comes from the first field present among ``"text"``,
    ``"conversations"`` (formatted with preprocess_conversation) and
    ``"content"``. A missing or None value yields "".

    When ``include_thoughts`` is set, thoughts are appended wrapped in
    THOUGHT_START/THOUGHT_END. The sample's own non-empty ``"thoughts"`` field
    is used if present; otherwise the thoughts that preprocess_conversation
    stripped out of the conversation are used.
    """
    thoughts = sample.get("thoughts")
    if "text" in sample:
        text = sample["text"]
    elif "conversations" in sample:
        processed = preprocess_conversation(sample["conversations"], include_thoughts, include_reasoning)
        text = processed["text"]
        # preprocess_conversation removes thought spans from message content
        # when include_thoughts is True; reuse them so they are not lost.
        thoughts = thoughts or processed.get("thoughts")
    elif "content" in sample:
        text = sample["content"]
    else:
        text = ""
    text = text or ""

    if include_thoughts and thoughts:
        text += f"\n{THOUGHT_START} {thoughts} {THOUGHT_END}"
    return text


def detect_domain(conversations: Any) -> str:
    """Detect the domain of a conversation from keyword hits in its content.

    Returns the domain in _DOMAIN_KEYWORDS with the most distinct keyword
    hits in the lower-cased text. Ties go to the earlier domain. Returns
    ``"unknown"`` when no keyword matches at all.
    """
    text = _conversation_text(conversations).lower()
    scores = {
        domain: sum(1 for kw in keywords if kw in text)
        for domain, keywords in _DOMAIN_KEYWORDS.items()
    }
    best = max(scores, key=scores.get)
    # With zero hits every score is 0 and max() would just return the first
    # domain ("code"); report the absence of signal instead.
    return best if scores[best] > 0 else "unknown"


def estimate_difficulty(conversations: Any, thoughts: str = "") -> float:
    """Estimate difficulty on scale 0-1.

    The score is a weighted sum of five features of the conversation text
    plus *thoughts*. Each feature is capped at 1.0 before weighting:

        word count / 500 (weight 0.3)
        CamelCase terms / 20 (weight 0.25)
        fenced code blocks / 3 (weight 0.25)
        math symbols / 10 (weight 0.1)
        reasoning connectives / 5 (weight 0.1)

    The weights sum to 1.0.
    """
    text = _conversation_text(conversations) + (thoughts or "")

    features = {
        "length": len(text.split()),
        "technical_terms": len(re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b', text)),  # CamelCase
        "code_blocks": len(re.findall(r'```[\s\S]*?```', text)),
        "math_symbols": len(re.findall(r'[=+\-*/<>≤≥≠∈∉⊂⊆∪∩]', text)),
        "reasoning_markers": len(re.findall(r'\b(because|therefore|thus|hence|since)\b', text, re.IGNORECASE)),
    }

    difficulty = (
        min(features["length"] / 500, 1.0) * 0.3
        + min(features["technical_terms"] / 20, 1.0) * 0.25
        + min(features["code_blocks"] / 3, 1.0) * 0.25
        + min(features["math_symbols"] / 10, 1.0) * 0.1
        + min(features["reasoning_markers"] / 5, 1.0) * 0.1
    )
    return min(difficulty, 1.0)


def clean_text(text: str) -> str:
    """Clean and normalize text.

    Collapses all whitespace runs (including newlines and tabs) to a single
    space. Removes the remaining ASCII control characters and converts curly
    quotes to ASCII quotes. Finally strips leading and trailing whitespace.
    """
    text = re.sub(r'\s+', ' ', text)
    # Whitespace controls are already spaces; drop the rest (e.g. NUL, DEL).
    text = re.sub(r'[\x00-\x1F\x7F]', '', text)
    # Removing a control char that sat between two spaces leaves a double space.
    text = re.sub(r' {2,}', ' ', text)
    text = text.translate(_QUOTE_TABLE)
    return text.strip()


def truncate_with_overlap(
    text: str,
    max_length: int,
    stride: int,
    tokenizer: Any,
) -> List[Dict[str, Any]]:
    """Split tokenized *text* into windows of at most *max_length* tokens.

    *stride* is the step between window starts. Consecutive windows therefore
    overlap by ``max_length - stride`` tokens, and tokens are skipped if
    ``stride > max_length``. The last window ends at the final token.

    *tokenizer* must provide ``encode(text, add_special_tokens=False)``
    returning a list of token ids.

    Raises:
        ValueError: if *max_length* or *stride* is not positive (a non-positive
            stride would never advance the window).
    """
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    tokens = tokenizer.encode(text, add_special_tokens=False)
    if len(tokens) <= max_length:
        return [{"input_ids": tokens, "attention_mask": [1] * len(tokens)}]

    chunks = []
    for start in range(0, len(tokens), stride):
        chunk_tokens = tokens[start:start + max_length]
        chunks.append({
            "input_ids": chunk_tokens,
            "attention_mask": [1] * len(chunk_tokens),
        })
        if start + max_length >= len(tokens):
            break
    return chunks


def compute_length_statistics(lengths: List[int]) -> Dict[str, float]:
    """Compute statistics for length distribution.

    Returns the mean, std, min, max and the 50/90/95/99th percentiles as
    floats, or ``{}`` for an empty input.
    """
    import numpy as np  # local import keeps the module importable without numpy

    if not lengths:
        return {}

    arr = np.array(lengths)
    return {
        "mean": float(np.mean(arr)),
        "std": float(np.std(arr)),
        "min": float(np.min(arr)),
        "max": float(np.max(arr)),
        "p50": float(np.percentile(arr, 50)),
        "p90": float(np.percentile(arr, 90)),
        "p95": float(np.percentile(arr, 95)),
        "p99": float(np.percentile(arr, 99)),
    }


def analyze_dataset_quality(dataset: Any, sample_size: int = 1000) -> Dict[str, Any]:
    """Analyze dataset quality metrics over the first *sample_size* samples.

    *dataset* may be any iterable of dict-like samples (sized or streaming).
    Per sample, existing ``"domain"`` and ``"difficulty"`` fields are used
    as-is; otherwise they are computed with detect_domain and
    estimate_difficulty. Word length is taken from ``"text"`` or, if that is
    empty, from the formatted ``"conversations"``.

    Returns a dict with these keys:

        ``total_samples``: number of samples examined.
        ``domains``: count of samples per domain.
        ``difficulty_distribution``: mean, std and a 10-bin histogram over [0, 1].
        ``length_stats``: word-length statistics.
        ``thoughts_coverage``: fraction of samples with non-empty ``"thoughts"``.
        ``conversation_turns``: mean and max conversation length.
    """
    import numpy as np  # local import, same as compute_length_statistics

    logger.info("Analyzing dataset quality...")

    # Take a prefix of the dataset; works for both sized and streaming inputs.
    samples = list(itertools.islice(dataset, max(sample_size, 0)))

    domains = []
    difficulties = []
    lengths = []
    turns = []
    with_thoughts = 0

    for sample in samples:
        conversations = sample.get("conversations", [])

        # Only compute fallbacks when the field is absent (they are not cheap).
        domain = sample["domain"] if "domain" in sample else detect_domain(conversations)
        domains.append(domain)

        if "difficulty" in sample:
            difficulty = sample["difficulty"]
        else:
            difficulty = estimate_difficulty(conversations, sample.get("thoughts", ""))
        difficulties.append(difficulty)

        text = sample.get("text") or ""
        if not text and "conversations" in sample:
            text = preprocess_conversation(sample["conversations"])["text"]
        lengths.append(len(text.split()))

        if sample.get("thoughts"):
            with_thoughts += 1

        if isinstance(sample.get("conversations"), list):
            turns.append(len(sample["conversations"]))

    analysis = {
        "total_samples": len(samples),
        "domains": dict(Counter(domains)),
        "difficulty_distribution": {
            "mean": float(np.mean(difficulties)) if difficulties else 0.0,
            "std": float(np.std(difficulties)) if difficulties else 0.0,
            "histogram": np.histogram(difficulties, bins=10, range=(0, 1))[0].tolist(),
        },
        "length_stats": compute_length_statistics(lengths),
        "thoughts_coverage": with_thoughts / len(samples) if samples else 0.0,
        "conversation_turns": {
            "mean": float(np.mean(turns)) if turns else 0.0,
            "max": int(max(turns)) if turns else 0,
        },
    }

    logger.info("Dataset analysis complete: %s", analysis)
    return analysis