#!/usr/bin/env python3
"""Resource-aware LoRA/QLoRA starter training for BlitzKode.

This script is intentionally conservative for local GPUs. It creates a small
coding-instruction dataset, fine-tunes a Qwen-family causal LM with LoRA, and
saves an adapter checkpoint that can later be merged/exported to GGUF.
"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any

import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_OUTPUT = REPO_ROOT / "checkpoints" / "available-lora"
DEFAULT_DATASET = REPO_ROOT / "datasets" / "raw" / "available_training.jsonl"

SYSTEM_PROMPT = (
    "You are BlitzKode, a precise coding assistant. Understand the task, avoid hallucinating APIs or files, "
    "and provide correct, efficient, tested code when possible."
)

# Seed samples written to the dataset path when it does not exist yet, so a
# first run works without any prepared data.
LOCAL_SAMPLES = [
    {
        "instruction": "Write a Python function two_sum(nums, target) that returns the indices of two numbers adding to target.",
        "response": """Use a hash map so each value is visited once.

```python
def two_sum(nums, target):
    seen = {}
    for index, value in enumerate(nums):
        needed = target - value
        if needed in seen:
            return [seen[needed], index]
        seen[value] = index
    return []
```

Time: O(n). Space: O(n).""",
    },
    {
        "instruction": "Explain how to debug a FastAPI endpoint returning 422 validation errors.",
        "response": """A 422 usually means the request does not match the Pydantic model or endpoint signature.

Check the JSON body, field names, types, required fields, and whether parameters are expected in the query instead of the body.
Reproduce with curl, inspect `response.json()`, and add a focused test for the failing payload.""",
    },
    {
        "instruction": "Refactor this Python pattern: opening a file without a context manager.",
        "response": """Use a context manager so the file is closed even if an exception is raised.

```python
from pathlib import Path


def read_text_file(path: str | Path) -> str:
    with Path(path).open(encoding="utf-8") as handle:
        return handle.read()
```""",
    },
    {
        "instruction": "Write a JavaScript debounce function and explain when to use it.",
        "response": """```javascript
function debounce(fn, delayMs) {
  let timerId;
  return function debounced(...args) {
    clearTimeout(timerId);
    timerId = setTimeout(() => fn.apply(this, args), delayMs);
  };
}
```

Use debounce when many rapid events should produce one action, such as search input, resize handlers, or autosave triggers.""",
    },
    {
        "instruction": "Give a safe plan for modifying an unfamiliar codebase.",
        "response": """1. Read the README, dependency files, tests, and main entry points.
2. Run the existing test suite before editing.
3. Make the smallest coherent change.
4. Add or update tests around the behavior.
5. Re-run lint, type checks, and tests.
6. Document assumptions and any follow-up risks.""",
    },
    {
        "instruction": "Implement binary search in Python with edge-case handling.",
        "response": """```python
def binary_search(values, target):
    left, right = 0, len(values) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if values[mid] == target:
            return mid
        if values[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1
```

The input must already be sorted.
Empty lists return `-1`.""",
    },
]

# CLI options (argparse dest names) that must be >= 1. Zero or negative values
# otherwise fail late and confusingly; e.g. max_steps <= 0 makes Trainer fall
# back to epoch-based training instead of a short step-limited run.
_POSITIVE_INT_OPTIONS = (
    "max_steps",
    "seq_len",
    "batch_size",
    "grad_accum",
    "lora_r",
    "lora_alpha",
    "sample_limit",
)


def parse_args() -> argparse.Namespace:
    """Parse and validate command-line options.

    Exits with a usage error (via ``parser.error``) when a count-like option is
    not a positive integer or the learning rate is not positive.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Base model ID or local path.")
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT, help="Where the LoRA adapter will be saved.")
    parser.add_argument(
        "--dataset",
        type=Path,
        default=DEFAULT_DATASET,
        help="JSON or JSONL dataset with instruction/response (or prompt/output) fields; "
        "created with starter samples if missing.",
    )
    parser.add_argument("--max-steps", type=int, default=10, help="Training steps. Keep small for a first local run.")
    parser.add_argument("--seq-len", type=int, default=512, help="Token sequence length.")
    parser.add_argument("--batch-size", type=int, default=1, help="Per-device batch size.")
    parser.add_argument("--grad-accum", type=int, default=4, help="Gradient accumulation steps.")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="LoRA learning rate.")
    parser.add_argument("--lora-r", type=int, default=16, help="LoRA rank.")
    parser.add_argument("--lora-alpha", type=int, default=32, help="LoRA alpha.")
    parser.add_argument("--quantization", choices=("auto", "4bit", "none"), default="auto", help="Use 4-bit QLoRA when available.")
    parser.add_argument("--sample-limit", type=int, default=32, help="Maximum training samples loaded.")
    args = parser.parse_args()

    for name in _POSITIVE_INT_OPTIONS:
        if getattr(args, name) < 1:
            parser.error(f"--{name.replace('_', '-')} must be a positive integer")
    if args.learning_rate <= 0:
        parser.error("--learning-rate must be positive")
    return args


def gpu_summary() -> str:
    """Return a one-line description of the visible CUDA devices (or a CPU warning)."""
    if not torch.cuda.is_available():
        return "CUDA unavailable; training will use CPU and be slow."
    parts = []
    for index in range(torch.cuda.device_count()):
        props = torch.cuda.get_device_properties(index)
        parts.append(f"GPU {index}: {props.name}, {props.total_memory / 1024**3:.1f} GB VRAM")
    return "; ".join(parts)


def _parse_rows(text: str, path: Path) -> list[Any]:
    """Decode ``text`` as a JSON array (if it starts with ``[``) or as JSONL.

    Raises:
        SystemExit: with the offending line number (JSONL) when decoding fails.
    """
    stripped = text.strip()
    if stripped.startswith("["):
        try:
            return json.loads(stripped)
        except json.JSONDecodeError as exc:
            raise SystemExit(f"Invalid JSON in {path}: {exc}") from exc

    rows: list[Any] = []
    # Enumerate the unstripped text so reported line numbers match the file.
    for line_number, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue
        try:
            rows.append(json.loads(line))
        except json.JSONDecodeError as exc:
            raise SystemExit(f"Invalid JSON on line {line_number} of {path}: {exc}") from exc
    return rows


def ensure_dataset(path: Path, sample_limit: int) -> list[dict[str, str]]:
    """Load up to ``sample_limit`` instruction/response pairs from ``path``.

    If ``path`` does not exist it is created and seeded with ``LOCAL_SAMPLES``.
    Both a JSON array and JSONL (one object per line) are accepted; each row may
    use ``instruction``/``response`` or ``prompt``/``output`` keys. Rows that are
    not JSON objects, or that lack either field, are skipped.

    Returns:
        A list of ``{"instruction": ..., "response": ...}`` dicts with stripped text.

    Raises:
        SystemExit: if the file is empty, contains invalid JSON, or yields no usable samples.
    """
    if not path.exists():
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as handle:
            for sample in LOCAL_SAMPLES:
                handle.write(json.dumps(sample, ensure_ascii=False) + "\n")
        print(f"[INFO] Dataset not found; wrote {len(LOCAL_SAMPLES)} starter samples to {path}")

    # "utf-8-sig" drops a leading BOM: str.strip() would keep it, and json.loads
    # rejects text that starts with one.
    text = path.read_text(encoding="utf-8-sig")
    if not text.strip():
        raise SystemExit(f"Dataset is empty: {path}")

    samples: list[dict[str, str]] = []
    for item in _parse_rows(text, path):
        if not isinstance(item, dict):
            continue
        instruction = str(item.get("instruction") or item.get("prompt") or "").strip()
        response = str(item.get("response") or item.get("output") or "").strip()
        if instruction and response:
            samples.append({"instruction": instruction, "response": response})
            if len(samples) >= sample_limit:
                break

    if not samples:
        raise SystemExit(f"No usable samples found in {path}")
    return samples


def format_sample(sample: dict[str, str]) -> str:
    """Render one sample as a ChatML conversation (system, user, assistant turns)."""
    return (
        f"<|im_start|>system\n{SYSTEM_PROMPT}<|im_end|>\n"
        f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n"
        f"<|im_start|>assistant\n{sample['response']}<|im_end|>"
    )


def _select_dtype() -> torch.dtype:
    """Pick the unquantized load dtype: bf16/fp16 on CUDA, fp32 on CPU."""
    if not torch.cuda.is_available():
        # Trainer runs CPU jobs without mixed precision or a loss scaler, and many
        # PyTorch builds have slow or missing fp16 CPU kernels, so use full precision.
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def load_model(model_name: str, quantization: str) -> tuple[Any, Any, bool]:
    """Load the tokenizer and base model, preferring 4-bit QLoRA when allowed.

    Args:
        model_name: Hugging Face model ID or local path.
        quantization: ``"4bit"`` forces a 4-bit attempt, ``"auto"`` tries 4-bit
            only when CUDA is available, ``"none"`` never quantizes.

    Returns:
        ``(model, tokenizer, use_4bit)``, where ``use_4bit`` reports whether the
        4-bit load actually succeeded. Any 4-bit failure falls back to an
        unquantized load.
    """
    # NOTE(review): trust_remote_code executes Python shipped with the model repo;
    # only point --model at repositories you trust.
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    use_4bit = quantization == "4bit" or (quantization == "auto" and torch.cuda.is_available())
    quantization_config = None
    if use_4bit:
        try:
            from transformers import BitsAndBytesConfig  # noqa: PLC0415

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        except Exception as exc:
            print(f"[WARN] 4-bit config unavailable, falling back to an unquantized load: {exc}")
            use_4bit = False

    dtype = _select_dtype()
    device_map = "auto" if torch.cuda.is_available() else None
    model_kwargs: dict[str, Any] = {"trust_remote_code": True, "device_map": device_map}
    if quantization_config is not None:
        model_kwargs["quantization_config"] = quantization_config
    else:
        model_kwargs["dtype"] = dtype

    try:
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    except Exception as exc:
        if quantization_config is None:
            raise
        print(f"[WARN] 4-bit model load failed; retrying without quantization: {exc}")
        use_4bit = False
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )

    if use_4bit:
        from peft import prepare_model_for_kbit_training  # noqa: PLC0415

        model = prepare_model_for_kbit_training(model)
    elif hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
    return model, tokenizer, use_4bit


def main() -> None:
    """Run one LoRA/QLoRA fine-tuning job and save the adapter plus training metrics."""
    args = parse_args()
    # Set early, before any CUDA work in this process, so the allocator sees it.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

    print("=" * 72)
    print("BLITZKODE RESOURCE-AWARE TRAINING START")
    print("=" * 72)
    print(gpu_summary())
    print(f"Base model: {args.model}")
    print(f"Output: {args.output_dir}")
    print(f"Steps: {args.max_steps}")

    samples = ensure_dataset(args.dataset, args.sample_limit)
    texts = [format_sample(sample) for sample in samples]
    dataset = Dataset.from_dict({"text": texts})
    print(f"Dataset: {len(dataset)} samples from {args.dataset}")

    model, tokenizer, use_4bit = load_model(args.model, args.quantization)
    model.config.use_cache = False

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    def tokenize(batch: dict[str, list[str]]) -> dict[str, Any]:
        # No padding here: the collator pads each batch to its longest sequence and
        # sets pad positions in the labels to -100. Padding every sample to seq_len
        # only spent compute on masked-out tokens.
        return tokenizer(batch["text"], truncation=True, max_length=args.seq_len)

    tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    training_args = TrainingArguments(
        output_dir=str(args.output_dir),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        warmup_steps=0,
        logging_steps=1,
        save_steps=max(1, args.max_steps),
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
        fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        gradient_checkpointing=True,
        optim="paged_adamw_8bit" if use_4bit else "adamw_torch",
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=tokenized, data_collator=collator)
    train_result = trainer.train()

    final_dir = args.output_dir / "final"
    trainer.save_model(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))

    metrics = train_result.metrics
    metrics_path = args.output_dir / "train_metrics.json"
    with metrics_path.open("w", encoding="utf-8") as handle:
        json.dump(metrics, handle, indent=2)

    print("=" * 72)
    print("TRAINING COMPLETE")
    print(f"Adapter saved to: {final_dir}")
    print(f"Metrics saved to: {metrics_path}")
    print("Next: run a longer training job or merge/export with scripts/export_gguf.py")


if __name__ == "__main__":
    main()