| """ |
Phase 4 — Efficient Fine-Tuning + Synthetic Data Generation
=============================================================
• Synthetic data generation from gap categories (LLM-driven)
• LoRA / QLoRA fine-tuning using PEFT + TRL SFTTrainer
• Adapter merging: trained LoRA weights folded into the base model (merge_and_unload)
• Iterative improvement loop: eval → gap detect → generate → fine-tune → re-eval

Usage:
    python -m phase4_finetuning.finetune --base mistralai/Mistral-7B-v0.3 --gap factual_recall
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import re |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import typer |
| from rich.console import Console |
| from rich.progress import track |
|
|
| from configs.settings import ( |
| FT_BASE_MODEL, FT_EPOCHS, FT_LR, FT_WARMUP_RATIO, FT_SAVE_STEPS, |
| CFG, ADAPTERS_DIR, DATA_DIR, HF_TOKEN |
| ) |
| from utils.logger import logger |
|
|
# Rich console instance for terminal output.
console = Console()
# Typer application; CLI commands in this module are registered on it via @app.command().
app = typer.Typer(help="Phase 4: Fine-tuning & synthetic data generation")
|
|
|
|
| |
| |
| |
|
|
| GAP_PROMPTS: dict[str, str] = { |
| "factual_recall": """Generate {n} high-quality QA pairs testing factual recall. |
| Format each as JSON: {{"context": "...", "question": "...", "answer": "..."}} |
| Focus on: historical dates, scientific facts, geography, key figures. |
| Return a JSON array only.""", |
|
|
| "multi_step_reasoning": """Generate {n} QA pairs requiring multi-step reasoning. |
| Format each as JSON: {{"context": "...", "question": "...", "answer": "..."}} |
| Each answer must show intermediate reasoning steps. |
| Return a JSON array only.""", |
|
|
| "numerical": """Generate {n} QA pairs involving numerical calculations or statistics. |
| Format each as JSON: {{"context": "...", "question": "...", "answer": "..."}} |
| Include percentages, comparisons, and mathematical relationships. |
| Return a JSON array only.""", |
|
|
| "code_generation": """Generate {n} coding QA pairs. |
| Format each as JSON: {{"context": "function specification", "question": "implementation task", "answer": "working code"}} |
| Cover: Python functions, algorithms, data structures. |
| Return a JSON array only.""", |
|
|
| "summarization": """Generate {n} summarization QA pairs. |
| Format each as JSON: {{"context": "long passage", "question": "summarize this", "answer": "concise summary"}} |
| Vary context length 200-800 words. |
| Return a JSON array only.""", |
| } |
|
|
|
|
@dataclass
class SyntheticSample:
    """One synthetic training example: a QA pair plus the gap category it targets."""

    # Passage the question is grounded in.
    context: str
    # Question asked about ``context``.
    question: str
    # Reference answer used as the training target.
    answer: str
    # Gap category this sample was generated for.
    gap_cat: str

    def to_dict(self) -> dict[str, str]:
        """Return the sample as a plain dict (field order preserved), e.g. for JSON export."""
        from dataclasses import asdict

        return asdict(self)
|
|
|
|
| def _parse_json_array(text: str) -> list[dict]: |
| """Extract first valid JSON array from LLM output.""" |
| |
| try: |
| return json.loads(text) |
| except json.JSONDecodeError: |
| pass |
| |
| match = re.search(r"\[.*\]", text, re.DOTALL) |
| if match: |
| try: |
| return json.loads(match.group()) |
| except json.JSONDecodeError: |
| pass |
| return [] |
|
|
|
|
def _resolve_gap_prompt(gap: str) -> str:
    """Return the prompt template for ``gap``, falling back to ``factual_recall``."""
    template = GAP_PROMPTS.get(gap)
    if template is None:
        logger.warning(f"[SynData] Unknown gap category '{gap}'; using the 'factual_recall' prompt")
        return GAP_PROMPTS["factual_recall"]
    return template


def _sample_completion(model, tok, prompt: str, max_new_tokens: int) -> str:
    """Sample one completion for ``prompt`` and return only the newly generated text."""
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.8,
            top_p=0.95,
            # Explicit pad id: the tokenizer's pad token may have been aliased to EOS.
            pad_token_id=tok.pad_token_id,
        )
    # generate() returns prompt + continuation; keep only the continuation.
    new_ids = out_ids[0][inputs["input_ids"].shape[1]:]
    return tok.decode(new_ids, skip_special_tokens=True)


def _item_to_sample(item, gap: str) -> Optional[SyntheticSample]:
    """Convert one parsed JSON item to a SyntheticSample, or None if it is unusable."""
    if not isinstance(item, dict):
        return None
    fields: dict[str, str] = {}
    for key in ("context", "question", "answer"):
        value = item.get(key)
        # A JSON null must not turn into the literal string "None".
        text = "" if value is None else str(value)
        if not text.strip():
            return None
        fields[key] = text
    return SyntheticSample(gap_cat=gap, **fields)


def _generate_for_gap(model, tok, gap: str, n_per_gap: int, batch_size: int,
                      max_new_tokens: int) -> list[SyntheticSample]:
    """Collect up to ``n_per_gap`` samples for one gap, requesting ``batch_size`` per call."""
    template = _resolve_gap_prompt(gap)
    logger.info(f"[SynData] Generating {n_per_gap} samples for: {gap}")

    samples: list[SyntheticSample] = []
    # Allow twice the minimum number of calls, so a few unparseable (e.g. truncated)
    # completions do not starve the gap, while runtime stays bounded.
    max_calls = 2 * max(0, -(-n_per_gap // batch_size))
    for _ in range(max_calls):
        remaining = n_per_gap - len(samples)
        if remaining <= 0:
            break
        prompt = template.format(n=min(batch_size, remaining))
        raw = _sample_completion(model, tok, prompt, max_new_tokens)
        for item in _parse_json_array(raw):
            sample = _item_to_sample(item, gap)
            if sample is not None:
                samples.append(sample)

    samples = samples[:n_per_gap]
    logger.info(f"[SynData] Got {len(samples)} valid samples for '{gap}'")
    return samples


def _save_samples(samples: list[SyntheticSample], out_path: Path) -> None:
    """Write ``samples`` to ``out_path`` as JSON Lines (one object per line)."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        for s in samples:
            record = {"context": s.context, "question": s.question,
                      "answer": s.answer, "gap_cat": s.gap_cat}
            f.write(json.dumps(record) + "\n")
    logger.success(f"[SynData] {len(samples)} samples saved → {out_path}")


def generate_synthetic_data(
    gap_categories: list[str],
    n_per_gap: int = 50,
    generator_model: str = FT_BASE_MODEL,
    batch_size: int = 5,
    max_new_tokens: int = 2048,
) -> list[SyntheticSample]:
    """Generate targeted synthetic QA data for each detected knowledge gap.

    ``generator_model`` is loaded 4-bit quantized. For every gap it is prompted
    with the matching ``GAP_PROMPTS`` template, and its JSON output is parsed
    into ``SyntheticSample`` objects. Samples are requested ``batch_size`` at a
    time: asking for all ``n_per_gap`` pairs in a single ``max_new_tokens``
    completion tends to truncate the JSON array mid-way, and a truncated array
    cannot be parsed at all.

    Args:
        gap_categories: Gap names; unknown names fall back to the
            ``factual_recall`` prompt (with a warning).
        n_per_gap: Target number of samples per gap (may yield fewer if the
            model keeps producing unparseable output).
        generator_model: Hugging Face model id used for generation.
        batch_size: QA pairs requested per generate() call (minimum 1).
        max_new_tokens: Generation budget per call.

    Returns:
        All valid samples, in gap order. They are also written to
        ``DATA_DIR / "synthetic_data.jsonl"``, overwriting any previous file.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

    batch_size = max(1, batch_size)
    logger.info(f"[SynData] Generating data for gaps: {gap_categories}")
    bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    tok = AutoTokenizer.from_pretrained(generator_model, token=HF_TOKEN or None, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        generator_model,
        quantization_config=bnb,
        device_map="auto",
        token=HF_TOKEN or None,
        trust_remote_code=True,
    )
    model.eval()

    all_samples: list[SyntheticSample] = []
    try:
        for gap in gap_categories:
            all_samples.extend(
                _generate_for_gap(model, tok, gap, n_per_gap, batch_size, max_new_tokens)
            )
    finally:
        # Release the generator's GPU memory even if generation fails part-way,
        # so a subsequent fine-tuning run has the memory available.
        del model
        torch.cuda.empty_cache()

    _save_samples(all_samples, DATA_DIR / "synthetic_data.jsonl")
    return all_samples
|
|
|
|
| |
| |
| |
|
|
# Mistral-style [INST] prompt that renders each training example into one
# "text" string. The leading BOS token ("<s>") is deliberately omitted:
# LLaMA/Mistral tokenizers prepend BOS themselves when special tokens are
# added (their default), so a literal "<s>" here would yield a doubled BOS.
# The trailing "</s>" is kept so every example ends with an explicit EOS.
CHAT_TEMPLATE = """[INST] Context: {context}

Question: {question} [/INST] {answer}</s>"""
|
|
|
|
def format_as_hf_dataset(samples: list[SyntheticSample]):
    """Convert ``SyntheticSample`` objects into a Hugging Face ``Dataset``.

    Each usable sample is rendered through ``CHAT_TEMPLATE`` into a ``text``
    column, which is the field ``fine_tune`` trains on. ``gap_cat`` is carried
    along for bookkeeping. Samples whose context, question or answer is empty,
    ``None`` or whitespace-only are dropped.

    Returns:
        ``datasets.Dataset`` with ``text`` and ``gap_cat`` columns. The columns
        are present even when no sample survives filtering, so downstream code
        receives a well-formed (if empty) dataset.
    """
    from datasets import Dataset

    texts: list[str] = []
    gap_cats: list[str] = []
    for s in samples:
        # Whitespace-only fields carry no training signal.
        if not all(v and v.strip() for v in (s.context, s.question, s.answer)):
            continue
        texts.append(CHAT_TEMPLATE.format(context=s.context, question=s.question, answer=s.answer))
        gap_cats.append(s.gap_cat)

    if not texts:
        logger.warning("[FT] No usable samples to format; returning an empty dataset")
    # from_dict keeps the column schema even for zero rows (from_list([]) would not).
    return Dataset.from_dict({"text": texts, "gap_cat": gap_cats})
|
|
|
|
def load_jsonl_dataset(path: str):
    """Load a JSONL file as a Hugging Face ``Dataset`` ready for ``fine_tune``.

    ``fine_tune`` trains on a ``text`` column. Files that already have one are
    returned unchanged. Files written by ``generate_synthetic_data`` hold
    ``context``/``question``/``answer`` columns instead. For those, rows with an
    empty field are dropped and the rest are rendered through ``CHAT_TEMPLATE``
    into a new ``text`` column; the original columns are kept.

    Args:
        path: Path to the JSONL file (``str`` or ``Path``).

    Raises:
        ValueError: if the file has neither a ``text`` column nor all of
            ``context``, ``question`` and ``answer``.
    """
    from datasets import load_dataset

    ds = load_dataset("json", data_files=str(path), split="train")
    if "text" in ds.column_names:
        return ds

    fields = ("context", "question", "answer")
    missing = [f for f in fields if f not in ds.column_names]
    if missing:
        raise ValueError(
            f"{path}: expected a 'text' column or {list(fields)} columns; missing {missing}"
        )

    ds = ds.filter(lambda row: all(row[f] and str(row[f]).strip() for f in fields))
    return ds.map(
        lambda row: {
            "text": CHAT_TEMPLATE.format(
                context=row["context"], question=row["question"], answer=row["answer"]
            )
        }
    )
|
|
|
|
| |
| |
| |
|
|
def build_lora_config(
    r: int = CFG["lora_r"],
    alpha: int = CFG["lora_alpha"],
    dropout: float = CFG["lora_dropout"],
    target_modules: Optional[list[str]] = None,
):
    """Return a PEFT ``LoraConfig`` for causal-LM fine-tuning.

    Args:
        r: LoRA rank (default ``CFG["lora_r"]``).
        alpha: ``lora_alpha`` scaling factor (default ``CFG["lora_alpha"]``).
        dropout: Dropout inside the LoRA layers (default ``CFG["lora_dropout"]``).
        target_modules: Module names to wrap with LoRA adapters. When ``None``,
            a fixed list of attention projections (``q/k/v/o_proj``) and MLP
            projections (``gate/up/down_proj``) is used. Nothing is auto-detected:
            these names follow LLaMA/Mistral/Qwen-style layer naming, so pass an
            explicit list for architectures that name their layers differently.
    """
    from peft import LoraConfig, TaskType

    modules = target_modules if target_modules is not None else [
        "q_proj", "k_proj", "v_proj", "o_proj",   # attention projections
        "gate_proj", "up_proj", "down_proj",      # MLP projections
    ]
    settings = dict(
        task_type=TaskType.CAUSAL_LM,
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=modules,
        bias="none",            # bias terms stay frozen
        inference_mode=False,   # adapters are built in trainable mode
    )
    return LoraConfig(**settings)
|
|
|
|
def build_bnb_config():
    """Return the bitsandbytes quantization config used to load models for QLoRA.

    4-bit loading is toggled by ``CFG["load_in_4bit"]``. The 4-bit settings
    select the NF4 quant type with double quantization and bfloat16 compute.
    """
    from transformers import BitsAndBytesConfig

    quant_kwargs = {
        "load_in_4bit": CFG["load_in_4bit"],
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_compute_dtype": torch.bfloat16,
        "bnb_4bit_use_double_quant": True,
    }
    return BitsAndBytesConfig(**quant_kwargs)
|
|
|
|
| |
| |
| |
|
|
def fine_tune(
    base_model_id: str,
    dataset,
    output_dir: Path,
    run_name: str = "lora-ft",
    epochs: int = FT_EPOCHS,
    lr: float = FT_LR,
    lora_config=None,
    use_wandb: bool = False,
) -> Path:
    """QLoRA fine-tune ``base_model_id`` on ``dataset`` with TRL's ``SFTTrainer``.

    Steps:
      1. Load the base model with ``build_bnb_config()`` quantization.
      2. Wrap it with LoRA adapters (``lora_config``, or ``build_lora_config()``).
      3. Train on the dataset's ``text`` column. ``DataCollatorForCompletionOnlyLM``
         masks everything before the ``[/INST]`` marker, so loss comes only from
         the answer part.

    Args:
        base_model_id: Hugging Face id or local path of the model to fine-tune.
        dataset: ``datasets.Dataset`` with a ``text`` column.
        output_dir: Directory for checkpoints and the final adapter (``str`` or ``Path``).
        run_name: Run name passed to ``TrainingArguments`` (e.g. for W&B).
        epochs: Number of training epochs.
        lr: Peak learning rate (cosine schedule with warmup).
        lora_config: Optional PEFT ``LoraConfig``; defaults to ``build_lora_config()``.
        use_wandb: Report metrics to Weights & Biases when True.

    Returns:
        ``output_dir / "adapter"``, containing the LoRA adapter weights and the
        tokenizer. The adapter is NOT merged into the base model; use
        ``extract_delta_adapter`` for that.
    """
    from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
    from peft import get_peft_model, prepare_model_for_kbit_training
    from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

    output_dir = Path(output_dir)  # also accept plain string paths
    output_dir.mkdir(parents=True, exist_ok=True)

    logger.info(f"[FT] Loading base model: {base_model_id}")
    bnb = build_bnb_config()
    tok = AutoTokenizer.from_pretrained(base_model_id, token=HF_TOKEN or None, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "right"

    model = AutoModelForCausalLM.from_pretrained(
        base_model_id,
        quantization_config=bnb,
        device_map="auto",
        token=HF_TOKEN or None,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        use_cache=False,  # the KV cache is not used during training
    )
    model = prepare_model_for_kbit_training(model)

    if lora_config is None:
        lora_config = build_lora_config()

    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=CFG["per_device_train_batch_size"],
        gradient_accumulation_steps=CFG["gradient_accumulation_steps"],
        gradient_checkpointing=True,
        optim="paged_adamw_32bit",
        learning_rate=lr,
        weight_decay=0.001,
        warmup_ratio=FT_WARMUP_RATIO,
        lr_scheduler_type="cosine",
        fp16=False,
        bf16=True,
        logging_steps=10,
        save_steps=FT_SAVE_STEPS,
        save_total_limit=2,
        report_to="wandb" if use_wandb else "none",
        run_name=run_name,
        dataloader_num_workers=4,
        group_by_length=True,
    )

    # Completion-only loss: labels before the response template are masked.
    # NOTE(review): sentencepiece tokenizers can encode " [/INST]" differently in
    # isolation than inside a prompt. If the collator warns that it cannot find
    # the response template, pass token ids taken from an in-context encoding
    # instead (see the TRL DataCollatorForCompletionOnlyLM docs).
    response_template = " [/INST]"
    collator = DataCollatorForCompletionOnlyLM(
        response_template=response_template, tokenizer=tok
    )

    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset,
        args=training_args,
        tokenizer=tok,
        data_collator=collator,
        dataset_text_field="text",
        max_seq_length=CFG["max_seq_length"],
        # Packing must stay off. SFTTrainer rejects DataCollatorForCompletionOnlyLM
        # when packing=True, and packing would concatenate examples anyway, which
        # breaks the per-example completion mask.
        packing=False,
    )

    logger.info("[FT] Starting training...")
    trainer.train()
    adapter_dir = output_dir / "adapter"
    trainer.save_model(str(adapter_dir))
    tok.save_pretrained(str(adapter_dir))

    logger.success(f"[FT] Adapter saved → {adapter_dir}")
    return adapter_dir
|
|
|
|
| |
| |
| |
|
|
def extract_delta_adapter(
    base_model_id: str,
    finetuned_path: str,
    output_dir: Path,
) -> Path:
    """Merge a trained LoRA adapter into its base model and save the merged model.

    Despite the name, no standalone delta is written. The function:
      1. Loads ``base_model_id`` on CPU in float16.
      2. Applies the adapter found at ``finetuned_path``.
      3. Folds the LoRA weights into the base weights with ``merge_and_unload()``.
      4. Saves the full merged model and the base model's tokenizer to ``output_dir``.

    The result loads like any regular model; ``improvement_loop`` uses it as
    the next base.

    Args:
        base_model_id: Hugging Face id or path of the model the adapter was trained on.
        finetuned_path: Directory containing the saved LoRA adapter.
        output_dir: Destination directory (``str`` or ``Path``); created if missing.

    Returns:
        ``output_dir`` as a ``Path``.
    """
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    output_dir = Path(output_dir)

    logger.info("[Delta] Loading base model for delta extraction...")
    base = AutoModelForCausalLM.from_pretrained(
        base_model_id, device_map="cpu", torch_dtype=torch.float16,
        token=HF_TOKEN or None, trust_remote_code=True,
    )
    peft_model = PeftModel.from_pretrained(base, finetuned_path)

    logger.info("[Delta] Merging and unloading LoRA weights...")
    merged = peft_model.merge_and_unload()

    output_dir.mkdir(parents=True, exist_ok=True)
    merged.save_pretrained(str(output_dir))
    # trust_remote_code matches how the model itself was loaded above.
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_id, token=HF_TOKEN or None, trust_remote_code=True
    )
    tokenizer.save_pretrained(str(output_dir))
    logger.success(f"[Delta] Merged adapter model saved → {output_dir}")
    return output_dir
|
|
|
|
| |
| |
| |
|
|
def improvement_loop(
    base_model_id: str,
    eval_samples_fn,
    max_iterations: int = 3,
    target_rouge: float = 0.45,
    n_syn_per_gap: int = 50,
    use_wandb: bool = False,
) -> Path:
    """Iteratively fine-tune until ROUGE-1 reaches ``target_rouge``.

    The starting model is evaluated once. Then each iteration:
      1. Stops if the best model already meets the target or has no detected gaps.
      2. Generates synthetic data for the best model's gaps.
      3. LoRA fine-tunes the best model and merges the adapter.
      4. Evaluates the merged candidate.
      5. Accepts the candidate only if its ROUGE-1 beats the current best.

    Rejected candidates are discarded, and the next iteration retries from the
    best model with freshly sampled data.

    Args:
        base_model_id: Starting model id or path.
        eval_samples_fn: Zero-argument callable returning the samples passed to
            ``phase3_evaluation.evaluate.evaluate``. It is called once per evaluation.
        max_iterations: Maximum number of fine-tuning rounds.
        target_rouge: ROUGE-1 score at which the loop stops early.
        n_syn_per_gap: Synthetic samples requested per detected gap.
        use_wandb: Forwarded to ``fine_tune``.

    Returns:
        Path to the LoRA adapter of the last accepted iteration, or
        ``Path(base_model_id)`` if no candidate improved on the start model.
        Iteration k's adapter was trained on top of the model accepted before
        it (not necessarily the original base). The standalone merged model is
        in the sibling ``merged`` directory.
    """
    from phase3_evaluation.evaluate import evaluate, EvalResult

    best_model = base_model_id
    best_adapter: Optional[Path] = None

    # Baseline: score the starting model once; later evaluations score new candidates only.
    best_result = evaluate(best_model, eval_samples_fn(), "iter_0", run_judge=False)
    history: list[EvalResult] = [best_result]
    logger.info(
        f"[Loop] Baseline ROUGE-1: {best_result.avg_rouge1:.3f} | gaps: {best_result.gap_categories}"
    )

    for iteration in range(1, max_iterations + 1):
        logger.info(f"\n{'='*60}\n[Loop] Iteration {iteration}/{max_iterations}\n{'='*60}")

        # Stopping criteria are checked against the best model so far.
        if best_result.avg_rouge1 >= target_rouge:
            logger.success(
                f"[Loop] Target reached ({best_result.avg_rouge1:.3f} ≥ {target_rouge}). Stopping."
            )
            break
        if not best_result.gap_categories:
            logger.info("[Loop] No gaps detected. Stopping.")
            break

        syn_samples = generate_synthetic_data(best_result.gap_categories, n_per_gap=n_syn_per_gap)
        if not syn_samples:
            logger.warning("[Loop] No synthetic samples generated. Stopping.")
            break
        dataset = format_as_hf_dataset(syn_samples)

        adapter_dir = ADAPTERS_DIR / f"iter_{iteration}"
        adapter_path = fine_tune(
            base_model_id=best_model,
            dataset=dataset,
            output_dir=adapter_dir,
            run_name=f"iter-{iteration}",
            use_wandb=use_wandb,
        )
        candidate_model = str(
            extract_delta_adapter(best_model, str(adapter_path), adapter_dir / "merged")
        )

        # Score the fine-tuned candidate before deciding whether to build on it.
        result = evaluate(candidate_model, eval_samples_fn(), f"iter_{iteration}", run_judge=False)
        history.append(result)
        logger.info(
            f"[Loop] Candidate ROUGE-1: {result.avg_rouge1:.3f} "
            f"(best so far {best_result.avg_rouge1:.3f}) | gaps: {result.gap_categories}"
        )

        if result.avg_rouge1 > best_result.avg_rouge1:
            best_model, best_result, best_adapter = candidate_model, result, adapter_path
            logger.info(f"[Loop] Iteration {iteration} accepted. Next model: {best_model}")
        else:
            logger.warning(
                f"[Loop] Iteration {iteration} did not improve ROUGE-1; keeping {best_model}"
            )

    scores = ", ".join(f"{r.avg_rouge1:.3f}" for r in history)
    logger.success(f"[Loop] Finished. Best ROUGE-1: {best_result.avg_rouge1:.3f} (history: {scores})")
    return best_adapter or Path(base_model_id)
|
|
|
|
| |
| |
| |
|
|
@app.command()
def run(
    base: str = typer.Option(FT_BASE_MODEL, "--base", help="Base model ID"),
    gaps: list[str] = typer.Option([], "--gap", "-g", help="Gap categories (repeat flag)"),
    data_path: Optional[str] = typer.Option(None, help="Existing synthetic JSONL path"),
    n_syn: int = typer.Option(50, help="Synthetic samples per gap"),
    output: Path = typer.Option(ADAPTERS_DIR / "run", "--output", "-o"),
    epochs: int = typer.Option(FT_EPOCHS),
    loop: bool = typer.Option(False, "--loop", help="Run iterative improvement loop"),
    max_iter: int = typer.Option(3, help="Max iterations (--loop mode)"),
    wandb: bool = typer.Option(False, "--wandb"),
):
    """Fine-tune on synthetic data, or run the iterative improvement loop.

    Modes:
      * --loop: run the improvement loop, evaluating on load_squad(50).
        Uses --base, --max-iter, --n-syn and --wandb; --epochs and --output
        are not used in this mode.
      * --data-path: fine-tune on an existing JSONL dataset.
      * --gap (repeatable): generate synthetic data for those gaps, then fine-tune.
    """
    if loop:
        from phase3_evaluation.evaluate import load_squad
        improvement_loop(
            base_model_id=base,
            eval_samples_fn=lambda: load_squad(50),
            max_iterations=max_iter,
            n_syn_per_gap=n_syn,  # honor --n-syn in loop mode too
            use_wandb=wandb,
        )
        return

    if data_path:
        dataset = load_jsonl_dataset(data_path)
    elif gaps:
        syn = generate_synthetic_data(list(gaps), n_per_gap=n_syn)
        dataset = format_as_hf_dataset(syn)
    else:
        raise typer.BadParameter("Provide --gap or --data-path")

    # Fail fast with a clear message instead of letting the trainer crash on an empty dataset.
    if len(dataset) == 0:
        logger.error("[FT] Training dataset is empty; nothing to fine-tune.")
        raise typer.Exit(code=1)

    fine_tune(base, dataset, output, epochs=epochs, use_wandb=wandb)
|
|
|
|
# Script entry point: hand control to the Typer CLI application `app`.
if __name__ == "__main__":
    app()
|
|