File size: 12,356 Bytes
31c9421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
#!/usr/bin/env python3
"""Resource-aware LoRA/QLoRA starter training for BlitzKode.



This script is intentionally conservative for local GPUs. It creates a small

coding-instruction dataset, fine-tunes a Qwen-family causal LM with LoRA, and

saves an adapter checkpoint that can later be merged/exported to GGUF.

"""

from __future__ import annotations

import argparse
import json
import os
from pathlib import Path
from typing import Any

import torch
from datasets import Dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)


# Repository root: two levels above this file's resolved path (the script
# presumably lives in a subdirectory such as scripts/ — confirm if moved).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Default base model; overridable with --model.
DEFAULT_MODEL = "Qwen/Qwen2.5-1.5B-Instruct"
# Default location for checkpoints and the final adapter; overridable with --output-dir.
DEFAULT_OUTPUT = REPO_ROOT / "checkpoints" / "available-lora"
# Default dataset path; seeded with LOCAL_SAMPLES by ensure_dataset if it is missing.
DEFAULT_DATASET = REPO_ROOT / "datasets" / "raw" / "available_training.jsonl"

# System turn prepended to every training conversation by format_sample().
SYSTEM_PROMPT = (
    "You are BlitzKode, a precise coding assistant. Understand the task, avoid hallucinating APIs or files, "
    "and provide correct, efficient, tested code when possible."
)

# Starter instruction/response pairs written to the dataset file on first run.
# NOTE: the multi-line responses are single-spaced on purpose — they are training
# targets, so any stray blank lines here would be learned as code formatting.
LOCAL_SAMPLES = [
    {
        "instruction": "Write a Python function two_sum(nums, target) that returns the indices of two numbers adding to target.",
        "response": """Use a hash map so each value is visited once.

```python
def two_sum(nums, target):
    seen = {}
    for index, value in enumerate(nums):
        needed = target - value
        if needed in seen:
            return [seen[needed], index]
        seen[value] = index
    return []
```

Time: O(n). Space: O(n).""",
    },
    {
        "instruction": "Explain how to debug a FastAPI endpoint returning 422 validation errors.",
        "response": """A 422 usually means the request does not match the Pydantic model or endpoint signature.
Check the JSON body, field names, types, required fields, and whether parameters are expected in the query instead of the body.
Reproduce with curl, inspect `response.json()`, and add a focused test for the failing payload.""",
    },
    {
        "instruction": "Refactor this Python pattern: opening a file without a context manager.",
        "response": """Use a context manager so the file is closed even if an exception is raised.

```python
from pathlib import Path

def read_text_file(path: str | Path) -> str:
    with Path(path).open(encoding="utf-8") as handle:
        return handle.read()
```""",
    },
    {
        "instruction": "Write a JavaScript debounce function and explain when to use it.",
        "response": """```javascript
function debounce(fn, delayMs) {
  let timerId;
  return function debounced(...args) {
    clearTimeout(timerId);
    timerId = setTimeout(() => fn.apply(this, args), delayMs);
  };
}
```

Use debounce when many rapid events should produce one action, such as search input, resize handlers, or autosave triggers.""",
    },
    {
        "instruction": "Give a safe plan for modifying an unfamiliar codebase.",
        "response": """1. Read the README, dependency files, tests, and main entry points.
2. Run the existing test suite before editing.
3. Make the smallest coherent change.
4. Add or update tests around the behavior.
5. Re-run lint, type checks, and tests.
6. Document assumptions and any follow-up risks.""",
    },
    {
        "instruction": "Implement binary search in Python with edge-case handling.",
        "response": """```python
def binary_search(values, target):
    left, right = 0, len(values) - 1
    while left <= right:
        mid = left + (right - left) // 2
        if values[mid] == target:
            return mid
        if values[mid] < target:
            left = mid + 1
        else:
            right = mid - 1
    return -1
```

The input must already be sorted. Empty lists return `-1`.""",
    },
]


def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    The module docstring is reused as the ``--help`` description. Numeric
    defaults are small so that a first local run finishes quickly.
    """
    cli = argparse.ArgumentParser(description=__doc__)
    add = cli.add_argument

    # Model and file locations.
    add("--model", default=DEFAULT_MODEL, help="Base model ID or local path.")
    add("--output-dir", type=Path, default=DEFAULT_OUTPUT, help="Where the LoRA adapter will be saved.")
    add("--dataset", type=Path, default=DEFAULT_DATASET, help="Optional JSON or JSONL dataset with instruction/response fields.")

    # Schedule and batch shape.
    add("--max-steps", type=int, default=10, help="Training steps. Keep small for a first local run.")
    add("--seq-len", type=int, default=512, help="Token sequence length.")
    add("--batch-size", type=int, default=1, help="Per-device batch size.")
    add("--grad-accum", type=int, default=4, help="Gradient accumulation steps.")
    add("--learning-rate", type=float, default=2e-4, help="LoRA learning rate.")

    # LoRA adapter shape.
    add("--lora-r", type=int, default=16, help="LoRA rank.")
    add("--lora-alpha", type=int, default=32, help="LoRA alpha.")

    # Memory and data limits.
    add("--quantization", choices=("auto", "4bit", "none"), default="auto", help="Use 4-bit QLoRA when available.")
    add("--sample-limit", type=int, default=32, help="Maximum training samples loaded.")

    namespace = cli.parse_args()
    return namespace


def gpu_summary() -> str:
    """Describe the visible CUDA devices on one human-readable line.

    Returns:
        ``"GPU <index>: <name>, <memory> GB VRAM"`` entries joined by ``"; "``,
        or a CPU warning string when CUDA is not available.
    """
    if torch.cuda.is_available():
        lines = []
        for device_id in range(torch.cuda.device_count()):
            device = torch.cuda.get_device_properties(device_id)
            vram_gb = device.total_memory / 1024**3
            lines.append(f"GPU {device_id}: {device.name}, {vram_gb:.1f} GB VRAM")
        return "; ".join(lines)
    return "CUDA unavailable; training will use CPU and be slow."


def ensure_dataset(path: Path, sample_limit: int) -> list[dict[str, str]]:
    """Load instruction/response training samples, seeding a starter file if needed.

    If ``path`` does not exist it is created (with parent directories) and filled
    with ``LOCAL_SAMPLES`` as JSONL. The file may be a JSON array (text starting
    with ``[``) or JSONL (one JSON value per non-blank line). Each row may use the
    ``instruction``/``response`` keys or the ``prompt``/``output`` aliases; values
    are converted with ``str()`` and stripped. Rows that are not JSON objects, or
    that lack either field, are skipped.

    Args:
        path: Dataset file to read (created if missing).
        sample_limit: Maximum number of usable samples to return; must be >= 1.

    Returns:
        Up to ``sample_limit`` dicts with ``instruction`` and ``response`` keys.

    Raises:
        SystemExit: If ``sample_limit`` < 1, the file is empty, a row is not valid
            JSON, or no usable samples are found.
    """
    if sample_limit < 1:
        raise SystemExit(f"sample_limit must be at least 1, got {sample_limit}")

    if not path.exists():
        # Seed a tiny starter dataset so a first run needs no preparation.
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as handle:
            for sample in LOCAL_SAMPLES:
                handle.write(json.dumps(sample, ensure_ascii=False) + "\n")

    # "utf-8-sig" drops a leading byte-order mark, which json.loads would reject.
    raw_text = path.read_text(encoding="utf-8-sig")
    content = raw_text.strip()
    if not content:
        raise SystemExit(f"Dataset is empty: {path}")

    rows: list[Any] = []
    if content.startswith("["):
        try:
            rows = json.loads(raw_text)
        except json.JSONDecodeError as exc:
            raise SystemExit(f"Invalid JSON in {path}: {exc}") from exc
    else:
        # Split on "\n" only: str.splitlines() would also break on U+2028/U+0085,
        # which may appear unescaped inside JSON strings (samples are written with
        # ensure_ascii=False). A trailing "\r" is JSON whitespace, so CRLF is fine.
        for line_no, line in enumerate(raw_text.split("\n"), start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as exc:
                raise SystemExit(f"Invalid JSON on line {line_no} of {path}: {exc}") from exc

    samples: list[dict[str, str]] = []
    for item in rows:
        if not isinstance(item, dict):
            continue  # stray scalars/arrays carry no instruction/response pair
        instruction = str(item.get("instruction") or item.get("prompt") or "").strip()
        response = str(item.get("response") or item.get("output") or "").strip()
        if instruction and response:
            samples.append({"instruction": instruction, "response": response})
            if len(samples) >= sample_limit:
                break
    if not samples:
        raise SystemExit(f"No usable samples found in {path}")
    return samples


def format_sample(sample: dict[str, str]) -> str:
    """Render one sample as a ChatML conversation: system, user, assistant.

    Each turn is ``<|im_start|>`` + role, a newline, the content, then
    ``<|im_end|>``. Turns are separated by a single newline, and nothing follows
    the final ``<|im_end|>``. Raises ``KeyError`` if ``instruction`` or
    ``response`` is missing from ``sample``.
    """
    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", sample["instruction"]),
        ("assistant", sample["response"]),
    )
    return "\n".join(f"<|im_start|>{role}\n{content}<|im_end|>" for role, content in turns)


def load_model(model_name: str, quantization: str) -> tuple[Any, Any, bool]:
    """Load the tokenizer and base model, preferring 4-bit QLoRA when possible.

    Args:
        model_name: Hugging Face model ID or local path passed to ``from_pretrained``.
        quantization: ``"auto"`` (4-bit only when CUDA is available), ``"4bit"``
            (always attempt 4-bit) or ``"none"``.

    Returns:
        ``(model, tokenizer, use_4bit)``. ``use_4bit`` is ``False`` if 4-bit was
        not requested or if building the 4-bit config or loading the quantized
        model failed. In that case an unquantized model is loaded instead.

    Raises:
        Whatever ``from_pretrained`` raises when the unquantized load fails.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # A pad token is required for batching; reuse EOS when the model has none.
        tokenizer.pad_token = tokenizer.eos_token

    cuda_available = torch.cuda.is_available()
    bf16_supported = cuda_available and torch.cuda.is_bf16_supported()
    if bf16_supported:
        dtype = torch.bfloat16
    elif cuda_available:
        dtype = torch.float16
    else:
        # Half precision on CPU is slow and, with no grad scaler (fp16/bf16 are
        # both disabled in TrainingArguments without CUDA), prone to underflow.
        dtype = torch.float32
    base_kwargs: dict[str, Any] = {
        "trust_remote_code": True,
        "device_map": "auto" if cuda_available else None,
    }

    use_4bit = quantization == "4bit" or (quantization == "auto" and cuda_available)
    quantization_config = None
    if use_4bit:
        try:
            from transformers import BitsAndBytesConfig  # noqa: PLC0415

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16 if bf16_supported else torch.float16,
                bnb_4bit_use_double_quant=True,
            )
        except Exception as exc:
            print(f"[WARN] 4-bit config unavailable, falling back to unquantized {dtype} load: {exc}")
            use_4bit = False

    model = None
    if quantization_config is not None:
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_name, quantization_config=quantization_config, **base_kwargs
            )
        except Exception as exc:
            # Best-effort: bitsandbytes/driver problems should not block a LoRA run.
            print(f"[WARN] 4-bit model load failed ({exc}); retrying with unquantized {dtype} LoRA.")
            use_4bit = False
    if model is None:
        # Unquantized path: failures here propagate to the caller.
        model = AutoModelForCausalLM.from_pretrained(model_name, dtype=dtype, **base_kwargs)

    if use_4bit:
        from peft import prepare_model_for_kbit_training  # noqa: PLC0415

        model = prepare_model_for_kbit_training(model)
    elif hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()

    return model, tokenizer, use_4bit


def main() -> None:
    """Run one LoRA/QLoRA fine-tuning job end to end.

    Parses CLI arguments, loads (or seeds) the dataset, loads the base model, wraps
    it with a LoRA adapter, trains it with the Hugging Face ``Trainer``, and writes
    the adapter + tokenizer to ``<output-dir>/final`` and the training metrics to
    ``<output-dir>/train_metrics.json``.
    """
    args = parse_args()
    # setdefault keeps any allocator configuration the user already exported.
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

    print("=" * 72)
    print("BLITZKODE RESOURCE-AWARE TRAINING START")
    print("=" * 72)
    print(gpu_summary())
    print(f"Base model: {args.model}")
    print(f"Output:     {args.output_dir}")
    print(f"Steps:      {args.max_steps}")

    samples = ensure_dataset(args.dataset, args.sample_limit)
    texts = [format_sample(sample) for sample in samples]
    dataset = Dataset.from_dict({"text": texts})
    print(f"Dataset:    {len(dataset)} samples from {args.dataset}")

    model, tokenizer, use_4bit = load_model(args.model, args.quantization)
    # The generation KV cache is not needed during training.
    model.config.use_cache = False

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
        # Attention + MLP projection names as used by Qwen-family models;
        # presumably other architectures need different names — verify per model.
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    def tokenize(batch: dict[str, list[str]]) -> dict[str, Any]:
        # No padding here: the collator pads each batch to its longest sample
        # instead of padding every sample to --seq-len.
        return tokenizer(batch["text"], truncation=True, max_length=args.seq_len)

    tokenized = dataset.map(tokenize, batched=True, remove_columns=["text"])
    # Samples that hit the limit lose their tail (including the closing <|im_end|>).
    at_limit = sum(len(ids) >= args.seq_len for ids in tokenized["input_ids"])
    if at_limit:
        print(f"[WARN] {at_limit} sample(s) reached --seq-len {args.seq_len} and were likely truncated.")
    # mlm=False: causal-LM labels built from input_ids, with padding added per batch.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    args.output_dir.mkdir(parents=True, exist_ok=True)
    training_args = TrainingArguments(
        output_dir=str(args.output_dir),
        max_steps=args.max_steps,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.grad_accum,
        learning_rate=args.learning_rate,
        warmup_steps=0,
        logging_steps=1,
        save_steps=max(1, args.max_steps),
        save_total_limit=2,
        report_to="none",
        remove_unused_columns=False,
        fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
        bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
        gradient_checkpointing=True,
        # Paged 8-bit AdamW pairs with the 4-bit path; plain AdamW otherwise.
        optim="paged_adamw_8bit" if use_4bit else "adamw_torch",
    )

    trainer = Trainer(model=model, args=training_args, train_dataset=tokenized, data_collator=collator)
    train_result = trainer.train()

    final_dir = args.output_dir / "final"
    trainer.save_model(str(final_dir))
    tokenizer.save_pretrained(str(final_dir))

    metrics = train_result.metrics
    metrics_path = args.output_dir / "train_metrics.json"
    with metrics_path.open("w", encoding="utf-8") as handle:
        json.dump(metrics, handle, indent=2)

    print("=" * 72)
    print("TRAINING COMPLETE")
    print(f"Adapter saved to: {final_dir}")
    print(f"Metrics saved to: {metrics_path}")
    print("Next: run a longer training job or merge/export with scripts/export_gguf.py")


if __name__ == "__main__":
    main()