| """ |
| Data generation utilities for DFlash training. |
| |
| Generates training data by running the target model on prompts, |
| creating {prompt, response} pairs for drafter training. |
| """ |
|
|
| import json |
| from pathlib import Path |
| from typing import Optional, List, Dict, Any |
| import mlx.core as mx |
|
|
|
|
| def generate_training_data( |
| target_model, |
| tokenizer, |
| prompts_dataset: str, |
| output_path: str, |
| max_new_tokens: int = 2048, |
| temperature: float = 0.0, |
| num_samples: Optional[int] = None, |
| system_prompt: Optional[str] = None, |
| ) -> str: |
| """Generate training data by running target model on prompts. |
| |
| This creates the supervised data that DFlash drafters need: |
| pairs of (prompt, target_model_response). |
| |
| Args: |
| target_model: MLX target model |
| tokenizer: Tokenizer |
| prompts_dataset: HF dataset name or path to prompts file |
| output_path: Output JSONL file path |
| max_new_tokens: Max tokens per response |
| temperature: Generation temperature (0 for greedy) |
| num_samples: Max number of samples to generate (None = all) |
| system_prompt: Optional system prompt |
| |
| Returns: |
| Path to output file |
| """ |
| output_path = Path(output_path) |
| output_path.parent.mkdir(parents=True, exist_ok=True) |
|
|
| |
| prompts = _load_prompts(prompts_dataset) |
| if num_samples: |
| prompts = prompts[:num_samples] |
|
|
| print(f"[DataGen] Generating {len(prompts)} responses...") |
|
|
| with open(output_path, "w") as f: |
| for i, prompt in enumerate(prompts): |
| print(f"[DataGen] Sample {i+1}/{len(prompts)}...") |
|
|
| |
| response = _generate_with_model( |
| model=target_model, |
| tokenizer=tokenizer, |
| prompt=prompt, |
| max_new_tokens=max_new_tokens, |
| temperature=temperature, |
| system_prompt=system_prompt, |
| ) |
|
|
| |
| sample = { |
| "prompt": prompt, |
| "response": response, |
| "model": getattr(target_model, "config", {}).get("_name_or_path", "unknown"), |
| } |
| f.write(json.dumps(sample) + "\n") |
|
|
| print(f"[DataGen] Done! Saved to {output_path}") |
| return str(output_path) |
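
# Example usage (a sketch, not part of the module's API; assumes `mlx_lm` is
# installed and the model name below is purely illustrative):
#
#     from mlx_lm import load
#     model, tokenizer = load("mlx-community/Qwen2.5-7B-Instruct-4bit")
#     generate_training_data(
#         model,
#         tokenizer,
#         prompts_dataset="prompts.jsonl",
#         output_path="data/train.jsonl",
#         num_samples=100,
#     )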


def _load_prompts(dataset: str) -> List[str]:
    """Load prompts from a local JSONL file or a Hugging Face dataset name."""
    path = Path(dataset)
    if path.exists():
        # Local JSONL file: accept "prompt", "input", or "question" keys.
        prompts = []
        with open(path, "r") as f:
            for line in f:
                data = json.loads(line)
                prompt = data.get("prompt") or data.get("input") or data.get("question") or ""
                if prompt:
                    prompts.append(prompt)
        return prompts

    # Otherwise treat the string as a Hugging Face dataset name.
    try:
        from datasets import load_dataset

        ds = load_dataset(dataset, split="train")
        prompts = []
        for item in ds:
            prompt = (
                item.get("prompt")
                or item.get("input")
                or item.get("question")
                or item.get("text")
                or ""
            )
            if prompt:
                prompts.append(str(prompt))
        return prompts
    except Exception as e:
        print(f"[DataGen] Failed to load dataset: {e}")
        return []
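
# Expected format for a local prompts file: one JSON object per line, with the
# prompt under any of the keys _load_prompts checks, e.g.:
#
#     {"prompt": "Explain speculative decoding in two sentences."}
#     {"question": "What is 17 * 24?"}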


def _generate_with_model(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int,
    temperature: float = 0.0,
    system_prompt: Optional[str] = None,
) -> str:
    """Generate text with an MLX model (naive decode, no KV cache)."""
    # Build the input text, applying the chat template when the tokenizer has one.
    if hasattr(tokenizer, "apply_chat_template"):
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        text = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
    else:
        text = prompt

    input_ids = mx.array(tokenizer.encode(text)).reshape(1, -1)

    # Naive autoregressive loop: re-runs the full sequence every step, which is
    # O(n^2) in sequence length. Acceptable for offline data generation.
    generated = []
    for _ in range(max_new_tokens):
        result = model(input_ids)
        logits = result[0] if isinstance(result, tuple) else result

        next_logits = logits[:, -1, :]
        if temperature < 1e-5:
            next_token = mx.argmax(next_logits, axis=-1)
        else:
            # mx.random.categorical samples from unnormalized logits directly,
            # so no explicit softmax/log round-trip is needed.
            next_token = mx.random.categorical(next_logits / temperature)

        generated.append(int(next_token[0]))
        input_ids = mx.concatenate([input_ids, next_token.reshape(1, 1)], axis=1)

        # Stop at EOS if the tokenizer defines one.
        eos_id = getattr(tokenizer, "eos_token_id", None)
        if eos_id is not None and generated[-1] == eos_id:
            break

    return tokenizer.decode(generated)
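
# Quick smoke test (sketch; assumes a model/tokenizer pair is already loaded,
# e.g. via mlx_lm.load):
#
#     print(_generate_with_model(model, tokenizer, "What is 2 + 2?",
#                                max_new_tokens=16, temperature=0.0))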


def create_mixed_training_data(
    output_path: str,
    math_ratio: float = 0.30,
    code_ratio: float = 0.20,
    chat_ratio: float = 0.50,
    total_samples: int = 100000,
) -> str:
    """Create a mixed training dataset from public sources.

    This replicates the paper's data mixture recipe:
    - 50% instruction/chat (UltraChat, ShareGPT)
    - 30% math/reasoning (GSM8K, MATH)
    - 20% code (HumanEval, MBPP)

    Only the first source in each category is loaded below.

    Args:
        output_path: Output JSONL path
        math_ratio: Fraction of math samples
        code_ratio: Fraction of code samples
        chat_ratio: Fraction of chat samples
        total_samples: Total number of samples

    Returns:
        Path to the output file.
    """
    from datasets import load_dataset

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    samples = []

    # Chat: take the final (user, assistant) turn of each UltraChat conversation.
    chat_count = int(total_samples * chat_ratio)
    try:
        print("[DataGen] Loading UltraChat...")
        ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")
        for i, item in enumerate(ds):
            if i >= chat_count:
                break
            messages = item.get("messages", [])
            if len(messages) >= 2:
                prompt = messages[-2].get("content", "")
                response = messages[-1].get("content", "")
                if prompt and response:
                    samples.append(
                        {"prompt": prompt, "response": response, "category": "chat"}
                    )
    except Exception as e:
        print(f"[DataGen] UltraChat failed: {e}")

    # Math: GSM8K question/answer pairs.
    math_count = int(total_samples * math_ratio)
    try:
        print("[DataGen] Loading GSM8K...")
        ds = load_dataset("openai/gsm8k", "main", split="train")
        for i, item in enumerate(ds):
            if i >= math_count:
                break
            prompt = item.get("question", "")
            response = item.get("answer", "")
            if prompt and response:
                samples.append(
                    {"prompt": prompt, "response": response, "category": "math"}
                )
    except Exception as e:
        print(f"[DataGen] GSM8K failed: {e}")

    # Code: MBPP task descriptions paired with reference solutions.
    code_count = int(total_samples * code_ratio)
    try:
        print("[DataGen] Loading MBPP...")
        ds = load_dataset("mbpp", split="train")
        for i, item in enumerate(ds):
            if i >= code_count:
                break
            prompt = item.get("text", item.get("prompt", ""))
            response = item.get("code", item.get("canonical_solution", ""))
            if prompt and response:
                samples.append(
                    {"prompt": prompt, "response": response, "category": "code"}
                )
    except Exception as e:
        print(f"[DataGen] MBPP failed: {e}")

    with open(output_path, "w") as f:
        for sample in samples:
            f.write(json.dumps(sample) + "\n")

    print(f"[DataGen] Created {len(samples)} mixed samples at {output_path}")
    return str(output_path)
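
# Example usage (sketch; the ratios default to the paper's 50/30/20 mixture):
#
#     create_mixed_training_data("data/mixed.jsonl", total_samples=10_000)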