# Source: Zenith-28b-p300-V1 / data/preprocessing.py
# Uploaded by Zandy-Wandy ("Upload Zenith-28b-V1-Tenstorrent-Blackhole-p300 model")
# Commit: 8944ef7 (verified)
"""Advanced Preprocessing for OpenThoughts and Custom Datasets"""
import json
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
# Module-level logger, named after this module for hierarchical configuration.
logger = logging.getLogger(__name__)
# Special token markers used to delimit conversation roles and
# chain-of-thought spans in the flattened training text.
THOUGHT_START = "<think>"
THOUGHT_END = "</think>"
USER_START = "<user>"
USER_END = "</user>"
ASSISTANT_START = "<assistant>"
ASSISTANT_END = "</assistant>"
SYSTEM_START = "<system>"
SYSTEM_END = "</system>"
def preprocess_conversation(
    conversations: Any,
    include_thoughts: bool = True,
    include_reasoning: bool = True,
) -> Dict[str, Any]:
    """Preprocess conversation data into training format.

    Args:
        conversations: A list of ``{"role", "content"}`` message dicts, or a
            JSON string encoding such a list. Non-JSON strings are returned
            as-is in ``"text"``; any other value is stringified.
        include_thoughts: When True, extract ``<think>...</think>`` spans into
            a ``"thoughts"`` field and strip the tags from message content.
        include_reasoning: Reserved flag; no reasoning is currently extracted,
            so no ``"reasoning"`` field is ever emitted. Kept for interface
            compatibility.

    Returns:
        Dict with ``"text"`` (newline-joined, role-token formatted transcript),
        ``"conversations"`` (processed message dicts), and optionally
        ``"thoughts"`` (space-joined extracted thoughts).
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:
            # Not JSON: treat the raw string itself as the training text.
            return {"text": conversations, "conversations": []}
    if not isinstance(conversations, list):
        return {"text": str(conversations), "conversations": []}
    # Compile once instead of re-parsing the pattern for every message.
    thought_re = re.compile(r'<think>(.*?)</think>', re.DOTALL)
    processed_messages: List[Dict[str, str]] = []
    thoughts: List[str] = []
    for msg in conversations:
        if not isinstance(msg, dict):
            continue
        role = msg.get("role", "").lower()
        content = msg.get("content", "")
        if not content:
            continue
        # Pull out chain-of-thought spans and strip their tags from content.
        if include_thoughts and THOUGHT_START in content:
            thoughts.extend(t.strip() for t in thought_re.findall(content) if t.strip())
            content = thought_re.sub('', content).strip()
        # Wrap content in role-specific special tokens; unknown roles pass through.
        if role == "user":
            formatted = f"{USER_START} {content} {USER_END}"
        elif role == "assistant":
            formatted = f"{ASSISTANT_START} {content} {ASSISTANT_END}"
        elif role == "system":
            formatted = f"{SYSTEM_START} {content} {SYSTEM_END}"
        else:
            formatted = content
        processed_messages.append({
            "role": role,
            "content": content,
            "formatted": formatted,
        })
    result: Dict[str, Any] = {
        "text": "\n".join(m["formatted"] for m in processed_messages),
        "conversations": processed_messages,
    }
    if include_thoughts and thoughts:
        result["thoughts"] = " ".join(thoughts)
    # NOTE: the original tracked a `reasoning` string that was never assigned,
    # so a "reasoning" key was unreachable; the dead variable was removed.
    return result
def extract_thoughts(text: str) -> str:
    """Return every ``<think>...</think>`` span in *text*, joined by spaces.

    Empty or whitespace-only spans are dropped.
    """
    matches = re.findall(r'<think>(.*?)</think>', text, re.DOTALL)
    cleaned = [m.strip() for m in matches if m.strip()]
    return " ".join(cleaned)
def format_for_training(
    sample: Dict[str, Any],
    include_thoughts: bool = True,
    include_reasoning: bool = True,
) -> str:
    """Render a dataset sample as a single training string.

    Prefers a pre-rendered ``"text"`` field, then raw ``"conversations"``
    (processed via :func:`preprocess_conversation`), then bare ``"content"``.
    Stored thoughts are appended inside think tags when requested.
    """
    if "text" in sample:
        text = sample["text"]
    elif "conversations" in sample:
        processed = preprocess_conversation(
            sample["conversations"], include_thoughts, include_reasoning
        )
        text = processed["text"]
    elif "content" in sample:
        text = sample["content"]
    else:
        text = ""
    thoughts = sample.get("thoughts")
    if include_thoughts and thoughts:
        text += f"\n{THOUGHT_START} {thoughts} {THOUGHT_END}"
    return text
def detect_domain(conversations: Any) -> str:
    """Detect the domain of a conversation via keyword matching.

    Args:
        conversations: A list of message dicts, or a JSON string encoding one.

    Returns:
        One of ``"code"``, ``"mathematics"``, ``"science"``, ``"reasoning"``,
        ``"dialogue"``, or ``"unknown"`` when no keyword matches.
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:
            # Bug fix: was a bare `except:` that swallowed everything.
            conversations = []
    if not isinstance(conversations, list):
        # Guard: json.loads may yield scalars/dicts that are not iterable
        # message lists; the original crashed on non-iterables.
        conversations = []
    # Concatenate message contents (trailing space after each, matching the
    # original so keywords like "def " still match at end-of-message).
    text = "".join(
        msg.get("content", "") + " " for msg in conversations if isinstance(msg, dict)
    ).lower()
    # Domain keywords
    domain_keywords = {
        "code": ["def ", "class ", "import ", "function", "return", "if __name__", "```python", "```java", "```cpp"],
        "mathematics": ["equation", "theorem", "proof", "calculate", "solve", "integral", "derivative", "matrix", "vector"],
        "science": ["experiment", "hypothesis", "theory", "data", "analysis", "chemical", "physical", "biological"],
        "reasoning": ["because", "therefore", "thus", "hence", "since", "logic", "deduce", "infer"],
        "dialogue": ["how are you", "what do you think", "please help", "thank you", "could you"],
    }
    scores = {
        domain: sum(1 for kw in keywords if kw in text)
        for domain, keywords in domain_keywords.items()
    }
    best = max(scores, key=scores.get)
    # Bug fix: the old `if not scores` check was unreachable (the dict above
    # is never empty), so an all-zero score silently returned the first
    # domain ("code"). Report "unknown" when nothing matched at all.
    if scores[best] == 0:
        return "unknown"
    return best
def estimate_difficulty(conversations: Any, thoughts: str = "") -> float:
    """Estimate sample difficulty on a 0-1 scale from surface features.

    Args:
        conversations: A list of message dicts, or a JSON string encoding one.
        thoughts: Optional chain-of-thought text appended to the analyzed text.

    Returns:
        Weighted, capped combination of length, technical terms, code blocks,
        math symbols, and reasoning markers, clamped to [0, 1].
    """
    if isinstance(conversations, str):
        try:
            conversations = json.loads(conversations)
        except json.JSONDecodeError:
            # Bug fix: was a bare `except:` that swallowed everything.
            conversations = []
    if not isinstance(conversations, list):
        # Guard against non-list JSON values (scalars, objects).
        conversations = []
    text = "".join(
        msg.get("content", "") + " " for msg in conversations if isinstance(msg, dict)
    )
    text += thoughts
    # Features for difficulty
    features = {
        "length": len(text.split()),
        "technical_terms": len(re.findall(r'\b[A-Z][a-z]+(?:[A-Z][a-z]+)+\b', text)),  # CamelCase
        "code_blocks": len(re.findall(r'```[\s\S]*?```', text)),
        # Bug fix: the original wrote `re.findall[...]` (square brackets),
        # which raises TypeError — functions are not subscriptable.
        "math_symbols": len(re.findall(r'[=+\-*/<>≤≥≠∈∉⊂⊆∪∩]', text)),
        "reasoning_markers": len(re.findall(r'\b(because|therefore|thus|hence|since)\b', text, re.IGNORECASE)),
    }
    # Normalize each feature to [0, 1] and combine with fixed weights.
    difficulty = (
        min(features["length"] / 500, 1.0) * 0.3 +
        min(features["technical_terms"] / 20, 1.0) * 0.25 +
        min(features["code_blocks"] / 3, 1.0) * 0.25 +
        min(features["math_symbols"] / 10, 1.0) * 0.1 +
        min(features["reasoning_markers"] / 5, 1.0) * 0.1
    )
    return min(difficulty, 1.0)
def clean_text(text: str) -> str:
    """Clean and normalize text.

    Collapses whitespace runs to single spaces, removes ASCII control
    characters, normalizes curly quotes to their ASCII equivalents, and
    strips leading/trailing whitespace.
    """
    # Remove excessive whitespace (any run of whitespace -> one space).
    text = re.sub(r'\s+', ' ', text)
    # Remove control characters
    text = re.sub(r'[\x00-\x1F\x7F]', '', text)
    # Normalize quotes.
    # Bug fix: the original literals were mojibake'd — `replace('"', '"')`
    # was a no-op and the single-quote lines contained broken `'''` literals.
    # Explicit escapes keep the intent safe across file re-encodings.
    text = text.replace('\u201c', '"').replace('\u201d', '"')
    text = text.replace('\u2018', "'").replace('\u2019', "'")
    # Strip
    text = text.strip()
    return text
def truncate_with_overlap(
    text: str,
    max_length: int,
    stride: int,
    tokenizer: Any,
) -> List[Dict[str, Any]]:
    """Split long text into overlapping token windows.

    Args:
        text: Raw text to tokenize.
        max_length: Maximum number of tokens per window.
        stride: Step between successive window starts; a stride smaller than
            ``max_length`` yields overlapping windows.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.

    Returns:
        List of ``{"input_ids", "attention_mask"}`` dicts covering the text.

    Raises:
        ValueError: If ``stride`` is not positive (the loop would never
            advance). Only checked when chunking is actually needed.
    """
    tokens = tokenizer.encode(text, add_special_tokens=False)
    # Short input: single chunk, no windowing needed.
    if len(tokens) <= max_length:
        return [{"input_ids": tokens, "attention_mask": [1] * len(tokens)}]
    if stride <= 0:
        # Bug fix: a non-positive stride previously caused an infinite loop.
        raise ValueError(f"stride must be positive, got {stride}")
    chunks = []
    start = 0
    while start < len(tokens):
        end = min(start + max_length, len(tokens))
        window = tokens[start:end]
        chunks.append({
            "input_ids": window,
            "attention_mask": [1] * len(window),
        })
        if end >= len(tokens):
            break
        start += stride
    return chunks
def compute_length_statistics(lengths: List[int]) -> Dict[str, float]:
    """Summarize a length distribution: mean, std, extremes, and percentiles.

    Returns an empty dict for empty input.
    """
    import numpy as np
    if not lengths:
        return {}
    data = np.asarray(lengths)
    percentiles = {
        f"p{q}": float(np.percentile(data, q)) for q in (50, 90, 95, 99)
    }
    summary = {
        "mean": float(data.mean()),
        "std": float(data.std()),
        "min": float(data.min()),
        "max": float(data.max()),
    }
    summary.update(percentiles)
    return summary
def analyze_dataset_quality(dataset: Any, sample_size: int = 1000) -> Dict[str, Any]:
    """Analyze dataset quality metrics over at most ``sample_size`` samples.

    Args:
        dataset: An iterable (optionally sized) of sample dicts with optional
            ``"conversations"``, ``"text"``, ``"thoughts"``, ``"domain"``,
            and ``"difficulty"`` keys.
        sample_size: Maximum number of samples to inspect.

    Returns:
        Dict with ``total_samples``, per-domain counts, difficulty
        distribution (mean/std/histogram), length statistics, thoughts
        coverage ratio, and conversation-turn summary.
    """
    # Bug fix: `np` was referenced below but never imported anywhere in the
    # module; local imports match this file's existing style (see
    # compute_length_statistics).
    import numpy as np
    from collections import Counter
    from itertools import islice

    logger.info("Analyzing dataset quality...")
    # Take a bounded sample. Bug fix: the sized-dataset branch previously
    # built an `indices` list but then iterated the WHOLE dataset, ignoring
    # sample_size entirely.
    if hasattr(dataset, "__len__"):
        sample_size = min(sample_size, len(dataset))
    samples = list(islice(iter(dataset), sample_size))
    sample_size = len(samples)

    domains = []
    difficulties = []
    lengths = []
    thoughts_flags = []
    turns = []
    for sample in samples:
        conversations = sample.get("conversations", [])
        # Domain: explicit label wins; otherwise keyword detection.
        domains.append(sample.get("domain", detect_domain(conversations)))
        # Difficulty: explicit label wins; otherwise heuristic estimate.
        difficulties.append(
            sample.get(
                "difficulty",
                estimate_difficulty(conversations, sample.get("thoughts", "")),
            )
        )
        # Length in whitespace-separated words of the rendered text.
        text = sample.get("text", "")
        if not text and "conversations" in sample:
            text = preprocess_conversation(sample["conversations"])["text"]
        lengths.append(len(text.split()))
        # Thoughts coverage: 1 when a non-empty "thoughts" field exists.
        thoughts_flags.append(1 if sample.get("thoughts") else 0)
        # Turn count only for explicit list-typed conversations.
        if "conversations" in sample and isinstance(sample["conversations"], list):
            turns.append(len(sample["conversations"]))

    analysis: Dict[str, Any] = {
        "total_samples": sample_size,
        "domains": dict(Counter(domains)),
        "difficulty_distribution": {
            "mean": float(np.mean(difficulties)) if difficulties else 0.0,
            "std": float(np.std(difficulties)) if difficulties else 0.0,
            "histogram": np.histogram(difficulties, bins=10, range=(0, 1))[0].tolist(),
        },
        "length_stats": compute_length_statistics(lengths),
        "thoughts_coverage": sum(thoughts_flags) / len(thoughts_flags) if thoughts_flags else 0.0,
        "conversation_turns": {
            "mean": float(np.mean(turns)) if turns else 0.0,
            "max": int(max(turns)) if turns else 0,
        },
    }
    logger.info(f"Dataset analysis complete: {analysis}")
    return analysis