blitzkode / scripts /train_available.py
neuralbroker's picture
Add scripts/train_available.py
31c9421 verified
raw
history blame
12.4 kB
#!/usr/bin/env python3
"""Resource-aware LoRA/QLoRA starter training for BlitzKode.
This script is intentionally conservative for local GPUs. It creates a small
coding-instruction dataset, fine-tunes a Qwen-family causal LM with LoRA, and
saves an adapter checkpoint that can later be merged/exported to GGUF.
"""
from __future__ import annotations
import argparse
import json
import os
from pathlib import Path
from typing import Any
import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
# Parent of the directory containing this script, i.e. the repository root when
# the script lives at scripts/train_available.py. Defaults below are anchored
# here so they do not depend on the current working directory.
REPO_ROOT = Path(__file__).resolve().parents[1]
DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_OUTPUT = REPO_ROOT.joinpath("checkpoints", "available-lora")
DEFAULT_DATASET = REPO_ROOT.joinpath("datasets", "raw", "available_training.jsonl")
# System turn placed at the start of every formatted training conversation.
SYSTEM_PROMPT = (
    "You are BlitzKode, a precise coding assistant. "
    "Understand the task, avoid hallucinating APIs or files, "
    "and provide correct, efficient, tested code when possible."
)
# Built-in seed data: ensure_dataset() writes these rows as JSONL when the
# dataset file does not exist yet, so a first run works without prepared data.
# The code inside each response must stay properly indented and syntactically
# valid, because the model is trained to reproduce exactly this text.
LOCAL_SAMPLES = [
    {
        "instruction": "Write a Python function two_sum(nums, target) that returns the indices of two numbers adding to target.",
        "response": """Use a hash map so each value is visited once.
```python
def two_sum(nums, target):
    seen = {}
    for index, value in enumerate(nums):
        needed = target - value
        if needed in seen:
            return [seen[needed], index]
        seen[value] = index
    return []
```
Time: O(n). Space: O(n).""",
    },
    {
        "instruction": "Explain how to debug a FastAPI endpoint returning 422 validation errors.",
        "response": """A 422 usually means the request does not match the Pydantic model or endpoint signature.
Check the JSON body, field names, types, required fields, and whether parameters are expected in the query instead of the body.
Reproduce with curl, inspect `response.json()`, and add a focused test for the failing payload.""",
    },
    {
        "instruction": "Refactor this Python pattern: opening a file without a context manager.",
        "response": """Use a context manager so the file is closed even if an exception is raised.
```python
from pathlib import Path
def read_text_file(path: str | Path) -> str:
    with Path(path).open(encoding="utf-8") as handle:
        return handle.read()
```""",
    },
    {
        "instruction": "Write a JavaScript debounce function and explain when to use it.",
        "response": """```javascript
function debounce(fn, delayMs) {
  let timerId;
  return function debounced(...args) {
    clearTimeout(timerId);
    timerId = setTimeout(() => fn.apply(this, args), delayMs);
  };
}
```
Use debounce when many rapid events should produce one action, such as search input, resize handlers, or autosave triggers.""",
    },
    {
        "instruction": "Give a safe plan for modifying an unfamiliar codebase.",
        "response": """1. Read the README, dependency files, tests, and main entry points.
2. Run the existing test suite before editing.
3. Make the smallest coherent change.
4. Add or update tests around the behavior.
5. Re-run lint, type checks, and tests.
6. Document assumptions and any follow-up risks.""",
    },
    {
        "instruction": "Implement binary search in Python with edge-case handling.",
        "response": """```python
def binary_search(values, target):
    left, right = 0, len(values) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if values[mid] == target:
            return mid
        if values[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1
```
The input must already be sorted. Empty lists return `-1`.""",
    },
]
def _positive_int(text: str) -> int:
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for a training run.

    Args:
        argv: Arguments to parse. ``None`` (the default) lets argparse read
            ``sys.argv[1:]``, exactly as before; passing a list makes the
            function usable from tests or other scripts.

    Returns:
        The parsed namespace. Count-like options (steps, sequence length,
        batch size, accumulation steps, LoRA rank/alpha, sample limit) are
        rejected unless they are >= 1. Without this check, 0 would be silently
        reinterpreted downstream: the dataset loader still keeps one sample,
        and the Trainer only honours a positive ``max_steps``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--model", default=DEFAULT_MODEL, help="Base model ID or local path.")
    parser.add_argument("--output-dir", type=Path, default=DEFAULT_OUTPUT, help="Where the LoRA adapter will be saved.")
    parser.add_argument("--dataset", type=Path, default=DEFAULT_DATASET, help="Optional JSON or JSONL dataset with instruction/response fields.")
    parser.add_argument("--max-steps", type=_positive_int, default=10, help="Training steps. Keep small for a first local run.")
    parser.add_argument("--seq-len", type=_positive_int, default=512, help="Token sequence length.")
    parser.add_argument("--batch-size", type=_positive_int, default=1, help="Per-device batch size.")
    parser.add_argument("--grad-accum", type=_positive_int, default=4, help="Gradient accumulation steps.")
    parser.add_argument("--learning-rate", type=float, default=2e-4, help="LoRA learning rate.")
    parser.add_argument("--lora-r", type=_positive_int, default=16, help="LoRA rank.")
    parser.add_argument("--lora-alpha", type=_positive_int, default=32, help="LoRA alpha.")
    parser.add_argument("--quantization", choices=("auto", "4bit", "none"), default="auto", help="Use 4-bit QLoRA when available.")
    parser.add_argument("--sample-limit", type=_positive_int, default=32, help="Maximum training samples loaded.")
    return parser.parse_args(argv)
def gpu_summary() -> str:
    """Return a one-line description of the visible CUDA devices.

    Each device is rendered as ``"GPU <i>: <name>, <memory> GB VRAM"``, with
    memory in units of 1024**3 bytes to one decimal place. Entries are joined
    with ``"; "``. When CUDA is unavailable a CPU warning is returned instead.
    """
    if torch.cuda.is_available():

        def describe(device_index: int) -> str:
            device = torch.cuda.get_device_properties(device_index)
            return f"GPU {device_index}: {device.name}, {device.total_memory / 1024**3:.1f} GB VRAM"

        return "; ".join(map(describe, range(torch.cuda.device_count())))
    return "CUDA unavailable; training will use CPU and be slow."
def ensure_dataset(path: Path, sample_limit: int) -> list[dict[str, str]]:
    """Load instruction/response pairs, seeding the file with built-in samples if absent.

    If ``path`` does not exist, its parent directory is created and
    ``LOCAL_SAMPLES`` is written there as JSONL. The file is then parsed as a
    single JSON array when its text starts with ``[``, otherwise as JSONL (one
    JSON value per non-blank line).

    Each JSON object contributes ``instruction`` (falling back to ``prompt``)
    and ``response`` (falling back to ``output``), both stripped. Rows that are
    not objects, or that lack either field, are skipped.

    Args:
        path: JSON or JSONL dataset file.
        sample_limit: Stop once this many usable samples are collected. The
            check runs after each append, so at least one sample is returned
            even for values below 1.

    Returns:
        The usable samples in file order.

    Raises:
        SystemExit: if the file is empty, contains malformed JSON, or yields
            no usable samples.
    """
    if not path.exists():
        # Directories are only needed when the seed file has to be written.
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as handle:
            handle.writelines(json.dumps(sample, ensure_ascii=False) + "\n" for sample in LOCAL_SAMPLES)

    # "utf-8-sig" drops a leading byte-order mark if present. Plain "utf-8"
    # keeps it as U+FEFF, which defeats the "[" check and makes json.loads fail.
    text = path.read_text(encoding="utf-8-sig")
    raw_text = text.strip()
    if not raw_text:
        raise SystemExit(f"Dataset is empty: {path}")

    if raw_text.startswith("["):
        try:
            rows = json.loads(raw_text)
        except json.JSONDecodeError as exc:
            raise SystemExit(f"Invalid JSON array in {path}: {exc}") from exc
    else:
        rows = []
        # Split on "\n" only. str.splitlines() also breaks on U+2028, U+2029
        # and U+0085, which json.dumps(..., ensure_ascii=False) leaves
        # unescaped inside string values, so valid JSONL would be torn apart.
        for line_number, line in enumerate(text.split("\n"), start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as exc:
                raise SystemExit(f"Invalid JSON on line {line_number} of {path}: {exc.msg}") from exc

    samples: list[dict[str, str]] = []
    for item in rows:
        if not isinstance(item, dict):
            continue  # e.g. a stray array or scalar: it has no named fields to read
        instruction = str(item.get("instruction") or item.get("prompt") or "").strip()
        response = str(item.get("response") or item.get("output") or "").strip()
        if instruction and response:
            samples.append({"instruction": instruction, "response": response})
            if len(samples) >= sample_limit:
                break
    if not samples:
        raise SystemExit(f"No usable samples found in {path}")
    return samples
def format_sample(sample: dict[str, str]) -> str:
    """Render one instruction/response pair as a ChatML-style conversation string.

    The turn markers are written out by hand rather than via a tokenizer chat
    template. The result holds three turns (system = SYSTEM_PROMPT,
    user = ``sample["instruction"]``, assistant = ``sample["response"]``),
    each wrapped in ``<|im_start|>role`` ... ``<|im_end|>`` and separated by a
    single newline. There is no trailing newline after the last ``<|im_end|>``.
    """
    conversation = (
        ("system", SYSTEM_PROMPT),
        ("user", sample["instruction"]),
        ("assistant", sample["response"]),
    )
    return "\n".join(f"<|im_start|>{role}\n{content}<|im_end|>" for role, content in conversation)
def load_model(model_name: str, quantization: str) -> tuple[Any, Any, bool]:
    """Load the tokenizer and base causal LM, attempting 4-bit QLoRA when appropriate.

    Args:
        model_name: Hugging Face model ID or local path. Both tokenizer and
            model are loaded with ``trust_remote_code=True``.
        quantization: ``"4bit"`` always attempts a 4-bit load. ``"auto"``
            attempts it only when CUDA is available. Any other value
            (``"none"``) skips quantization.

    Returns:
        ``(model, tokenizer, use_4bit)``. ``use_4bit`` is True only if the
        model was actually loaded 4-bit quantized. Failure to build the 4-bit
        config, or to load the quantized model, falls back to an unquantized
        load with a printed warning instead of aborting.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Batching needs a pad token; reuse EOS when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    has_cuda = torch.cuda.is_available()
    # Weight dtype for the unquantized load:
    # - CUDA: bf16, or fp16 on GPUs without bf16 support.
    # - CPU: float32. Half-precision training on CPU is slow, not every CPU
    #   kernel supports it, and this script only enables fp16/bf16 mixed
    #   precision when CUDA is present.
    if has_cuda:
        dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        dtype = torch.float32

    use_4bit = quantization == "4bit" or (quantization == "auto" and has_cuda)
    quantization_config = None
    if use_4bit:
        try:
            from transformers import BitsAndBytesConfig  # noqa: PLC0415

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        except Exception as exc:
            print(f"[WARN] 4-bit config unavailable, falling back to unquantized load: {exc}")
            use_4bit = False

    # Keyword arguments shared by the quantized attempt and the unquantized load.
    shared_kwargs: dict[str, Any] = {
        "trust_remote_code": True,
        "device_map": "auto" if has_cuda else None,
    }
    model = None
    if quantization_config is not None:
        try:
            # No dtype here; the compute dtype is set via bnb_4bit_compute_dtype above.
            model = AutoModelForCausalLM.from_pretrained(
                model_name, quantization_config=quantization_config, **shared_kwargs
            )
        except Exception as exc:
            print(f"[WARN] 4-bit model load failed ({exc}); retrying without quantization.")
            use_4bit = False
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(model_name, dtype=dtype, **shared_kwargs)

    if use_4bit:
        from peft import prepare_model_for_kbit_training  # noqa: PLC0415

        model = prepare_model_for_kbit_training(model)
    elif hasattr(model, "gradient_checkpointing_enable"):
        # Trade recomputation for activation memory on the unquantized path.
        model.gradient_checkpointing_enable()
    return model, tokenizer, use_4bit
def _lora_config(args: argparse.Namespace) -> LoraConfig:
    """Build the LoRA adapter config from the CLI rank/alpha options."""
    return LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        # Attention and MLP projection names as used by Qwen-family decoder
        # layers; presumably also valid for other Llama-style models (verify
        # when changing --model).
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )


def _training_arguments(args: argparse.Namespace, use_4bit: bool) -> TrainingArguments:
    """Build Trainer settings for a short, memory-conscious local run."""
    has_cuda = torch.cuda.is_available()
    use_bf16 = has_cuda and torch.cuda.is_bf16_supported()
    return TrainingArguments(
        output_dir=str(args.output_dir),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        warmup_steps=0,
        logging_steps=1,
        # save_steps == max_steps: a single checkpoint when the last step is reached.
        save_steps=max(1, args.max_steps),
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
        # Mixed precision only on CUDA: bf16 where supported, fp16 otherwise.
        fp16=has_cuda and not use_bf16,
        bf16=use_bf16,
        gradient_checkpointing=True,
        # The paged 8-bit optimizer relies on bitsandbytes, which is only known
        # to work here if the 4-bit load succeeded.
        optim="paged_adamw_8bit" if use_4bit else "adamw_torch",
    )


def main() -> None:
    """Run a short LoRA fine-tune end to end and save the adapter plus metrics.

    Steps:
    1. Parse CLI options.
    2. Load (or seed) the dataset and format it as chat text.
    3. Load the base model and wrap it with LoRA adapters.
    4. Train for ``--max-steps`` optimizer steps.
    5. Write the adapter and tokenizer to ``<output-dir>/final`` and the
       training metrics to ``<output-dir>/train_metrics.json``.
    """
    args = parse_args()
    # Read by PyTorch's CUDA caching allocator, so it is set before any CUDA
    # work below. setdefault keeps a value the user already exported.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

    print("=" * 72)
    print("BLITZKODE RESOURCE-AWARE TRAINING START")
    print("=" * 72)
    print(gpu_summary())
    print(f"Base model: {args.model}")
    print(f"Output: {args.output_dir}")
    print(f"Steps: {args.max_steps}")

    samples = ensure_dataset(args.dataset, args.sample_limit)
    dataset = Dataset.from_dict({"text": [format_sample(sample) for sample in samples]})
    print(f"Dataset: {len(dataset)} samples from {args.dataset}")

    model, tokenizer, use_4bit = load_model(args.model, args.quantization)
    # Generation KV cache is not used in training and clashes with gradient checkpointing.
    model.config.use_cache = False
    model = get_peft_model(model, _lora_config(args))
    model.print_trainable_parameters()

    def tokenize(batch: dict[str, list[str]]) -> dict[str, Any]:
        # Truncate only. The collator pads each batch to its own longest
        # sample, instead of padding every example to --seq-len tokens.
        return tokenizer(batch["text"], truncation=True, max_length=args.seq_len)

    tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
    # mlm=False -> causal-LM labels: input_ids copied, with padding positions ignored (-100).
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    trainer = Trainer(
        model=model,
        args=_training_arguments(args, use_4bit),
        train_dataset=tokenized,
        data_collator=collator,
    )
    train_result = trainer.train()

    final_dir = args.output_dir / "final"
    trainer.save_model(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))

    metrics_path = args.output_dir / "train_metrics.json"
    with metrics_path.open("w", encoding="utf-8") as handle:
        json.dump(train_result.metrics, handle, indent=2)

    print("=" * 72)
    print("TRAINING COMPLETE")
    print(f"Adapter saved to: {final_dir}")
    print(f"Metrics saved to: {metrics_path}")
    print("Next: run a longer training job or merge/export with scripts/export_gguf.py")


if __name__ == "__main__":
    main()