"""
Data generation utilities for DFlash training.
Generates training data by running the target model on prompts,
creating {prompt, response} pairs for drafter training.
"""
import json
from pathlib import Path
from typing import List, Optional

import mlx.core as mx
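
# Example usage (a sketch, not part of the module API): assumes the target
# model and tokenizer are loaded with the mlx_lm package, which this module
# does not import itself, and a hypothetical local model path.
#
#     from mlx_lm import load
#     model, tokenizer = load("models/target-8b")
#     generate_training_data(model, tokenizer, "prompts.jsonl", "data/train.jsonl")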


def generate_training_data(
    target_model,
    tokenizer,
    prompts_dataset: str,
    output_path: str,
    max_new_tokens: int = 2048,
    temperature: float = 0.0,
    num_samples: Optional[int] = None,
    system_prompt: Optional[str] = None,
) -> str:
"""Generate training data by running target model on prompts.
This creates the supervised data that DFlash drafters need:
pairs of (prompt, target_model_response).
Args:
target_model: MLX target model
tokenizer: Tokenizer
prompts_dataset: HF dataset name or path to prompts file
output_path: Output JSONL file path
max_new_tokens: Max tokens per response
temperature: Generation temperature (0 for greedy)
num_samples: Max number of samples to generate (None = all)
system_prompt: Optional system prompt
Returns:
Path to output file
"""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Load prompts
    prompts = _load_prompts(prompts_dataset)
    if num_samples is not None:
        prompts = prompts[:num_samples]

    print(f"[DataGen] Generating {len(prompts)} responses...")

    # Resolve a model identifier for provenance; `config` may be a dict or a
    # config object depending on how the model was loaded.
    config = getattr(target_model, "config", None)
    if isinstance(config, dict):
        model_name = config.get("_name_or_path", "unknown")
    else:
        model_name = getattr(config, "_name_or_path", "unknown")

    with open(output_path, "w") as f:
        for i, prompt in enumerate(prompts):
            print(f"[DataGen] Sample {i + 1}/{len(prompts)}...")

            # Generate a response with the target model
            response = _generate_with_model(
                model=target_model,
                tokenizer=tokenizer,
                prompt=prompt,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                system_prompt=system_prompt,
            )

            # Save the sample as one JSON line
            sample = {
                "prompt": prompt,
                "response": response,
                "model": model_name,
            }
            f.write(json.dumps(sample) + "\n")

    print(f"[DataGen] Done! Saved to {output_path}")
    return str(output_path)
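
# Each line of the output JSONL is one {prompt, response} pair plus the model
# identifier, e.g. (values illustrative):
#
#     {"prompt": "What is 2 + 2?", "response": "2 + 2 = 4.", "model": "unknown"}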


def _load_prompts(dataset: str) -> List[str]:
    """Load prompts from a local JSONL file or a Hugging Face dataset."""
    path = Path(dataset)
    if path.exists():
        # Local JSONL file: accept "prompt", "input", or "question" keys
        prompts = []
        with open(path, "r") as f:
            for line in f:
                data = json.loads(line)
                prompt = data.get("prompt", data.get("input", data.get("question", "")))
                if prompt:
                    prompts.append(prompt)
        return prompts

    # Otherwise treat the argument as a Hugging Face dataset name
    try:
        from datasets import load_dataset

        ds = load_dataset(dataset, split="train")
        prompts = []
        for item in ds:
            prompt = item.get("prompt", item.get("input", item.get("question", item.get("text", ""))))
            if prompt:
                prompts.append(str(prompt))
        return prompts
    except Exception as e:
        print(f"[DataGen] Failed to load dataset: {e}")
        return []
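
# The local-file branch above expects JSONL with one object per line; any of
# the keys it checks will work, e.g. (content illustrative):
#
#     {"prompt": "Explain speculative decoding in one paragraph."}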


def _generate_with_model(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 0.0,
    system_prompt: Optional[str] = None,
) -> str:
    """Generate text with an MLX model (greedy or temperature sampling)."""
    # Build the prompt, using the chat template when the tokenizer has one
    if hasattr(tokenizer, "apply_chat_template"):
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    else:
        text = prompt

    # Tokenize
    input_ids = mx.array(tokenizer.encode(text)).reshape(1, -1)

    # Decode token by token. The model is re-run on the full sequence each
    # step (no KV cache), so this is simple but not fast.
    eos_token_id = getattr(tokenizer, "eos_token_id", None)
    generated = []
    for _ in range(max_new_tokens):
        result = model(input_ids)
        # Some models return (logits, cache); keep only the logits
        logits = result[0] if isinstance(result, tuple) else result

        # Sample the next token from the last position's logits
        next_logits = logits[:, -1, :]
        if temperature < 1e-5:
            next_token = mx.argmax(next_logits, axis=-1)
        else:
            # mx.random.categorical samples directly from unnormalized logits
            next_token = mx.random.categorical(next_logits / temperature)

        # Stop at EOS without including it in the response text
        token_id = int(next_token[0])
        if eos_token_id is not None and token_id == eos_token_id:
            break

        generated.append(token_id)
        input_ids = mx.concatenate([input_ids, next_token.reshape(1, 1)], axis=1)

    return tokenizer.decode(generated)
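
# A faster alternative to the cache-less loop above (a sketch, assuming the
# mlx_lm package is installed; check its generate() signature against the
# installed version, as it has changed across releases):
#
#     from mlx_lm import generate
#     response = generate(model, tokenizer, prompt=text, max_tokens=max_new_tokens)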


def create_mixed_training_data(
    output_path: str,
    math_ratio: float = 0.30,
    code_ratio: float = 0.20,
    chat_ratio: float = 0.50,
    total_samples: int = 100000,
) -> str:
    """Create a mixed training dataset from public sources.

    This replicates the paper's data-mixture recipe:
    - 50% instruction/chat (UltraChat, ShareGPT)
    - 30% math/reasoning (GSM8K, MATH)
    - 20% code (HumanEval, MBPP)

    Args:
        output_path: Output JSONL path.
        math_ratio: Fraction of math samples.
        code_ratio: Fraction of code samples.
        chat_ratio: Fraction of chat samples.
        total_samples: Total number of samples.

    Returns:
        Path to the output file.
    """
    from datasets import load_dataset

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    samples = []

    # Chat data: last user turn as the prompt, final assistant turn as the
    # response. Count only samples actually kept, so the mixture ratios hold
    # even when rows are skipped.
    chat_count = int(total_samples * chat_ratio)
    try:
        print("[DataGen] Loading UltraChat...")
        ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
        kept = 0
        for item in ds:
            if kept >= chat_count:
                break
            messages = item.get("messages", [])
            if len(messages) >= 2:
                prompt = messages[-2].get("content", "")
                response = messages[-1].get("content", "")
                if prompt and response:
                    samples.append({"prompt": prompt, "response": response, "category": "chat"})
                    kept += 1
    except Exception as e:
        print(f"[DataGen] UltraChat failed: {e}")

    # Math data
    math_count = int(total_samples * math_ratio)
    try:
        print("[DataGen] Loading GSM8K...")
        ds = load_dataset("openai/gsm8k", "main", split="train")
        kept = 0
        for item in ds:
            if kept >= math_count:
                break
            prompt = item.get("question", "")
            response = item.get("answer", "")
            if prompt and response:
                samples.append({"prompt": prompt, "response": response, "category": "math"})
                kept += 1
    except Exception as e:
        print(f"[DataGen] GSM8K failed: {e}")

    # Code data
    code_count = int(total_samples * code_ratio)
    try:
        print("[DataGen] Loading MBPP...")
        ds = load_dataset("mbpp", split="train")
        kept = 0
        for item in ds:
            if kept >= code_count:
                break
            prompt = item.get("text", item.get("prompt", ""))
            response = item.get("code", item.get("canonical_solution", ""))
            if prompt and response:
                samples.append({"prompt": prompt, "response": response, "category": "code"})
                kept += 1
    except Exception as e:
        print(f"[DataGen] MBPP failed: {e}")

    # Save as JSONL
    with open(output_path, "w") as f:
        for sample in samples:
            f.write(json.dumps(sample) + "\n")

    print(f"[DataGen] Created {len(samples)} mixed samples at {output_path}")
    return str(output_path)
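
# Example: with the defaults (total_samples=100000, chat_ratio=0.50,
# math_ratio=0.30, code_ratio=0.20) the target counts are 50000 chat,
# 30000 math, and 20000 code samples; actual counts can be lower when a
# source dataset is smaller than its target or fails to load. A typical
# call (output path illustrative):
#
#     create_mixed_training_data("data/mixed_train.jsonl", total_samples=10000)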