import json
import os
import subprocess
import sys

import torch

# HF token: first CLI argument wins, else fall back to the HF_TOKEN env var.
TOKEN = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("HF_TOKEN", "")
if not TOKEN:
    print("Usage: python3 train.py <HF_TOKEN> OR set HF_TOKEN env var")
    sys.exit(1)

# Install training deps at runtime. Use the current interpreter's pip through
# subprocess (list argv, no shell) instead of os.system, and fail loudly if
# the install errors out instead of silently continuing to the imports below.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", "transformers>=4.52", "peft"],
    check=True,
)
|
|
# Imported after the runtime pip install above so the pinned versions are used.
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
from huggingface_hub import login, hf_hub_download

# Authenticate to the Hub and fetch the training split into the workspace.
login(token=TOKEN)
os.makedirs("/workspace/data", exist_ok=True)
hf_hub_download(
    "Adiuk/adis-tool-calling-dataset",
    "train.jsonl",
    local_dir="/workspace/data",
    repo_type="dataset",
)
|
|
M = "microsoft/BitNet-b1.58-2B-4T"

# Guard the GPU report: the original unconditional torch.cuda calls raise on
# CPU-only hosts before training even starts.
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {props.total_memory/1e9:.1f}GB")
else:
    print("WARNING: no CUDA device detected; training will run on CPU")

print(f"Loading {M}...")
tok = AutoTokenizer.from_pretrained(M)
# Many causal-LM tokenizers ship without a pad token; reuse EOS for padding.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
|
|
# BitNet ships custom modeling code: try the stock loader first, and fall back
# to trust_remote_code=True only if the architecture is not known to the
# installed transformers version.
try:
    mdl = AutoModelForCausalLM.from_pretrained(
        M, torch_dtype=torch.bfloat16, device_map="auto"
    )
except Exception:
    mdl = AutoModelForCausalLM.from_pretrained(
        M, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto"
    )

print(f"Loaded: {mdl.num_parameters()/1e6:.0f}M params")

# Discover every attention/MLP projection Linear by suffix name; fall back to
# the common q/v pair if the model's naming scheme does not match.
_PROJ_NAMES = ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
tgt = list({
    n.split(".")[-1]
    for n, m in mdl.named_modules()
    if isinstance(m, torch.nn.Linear) and any(k in n for k in _PROJ_NAMES)
}) or ["q_proj", "v_proj"]
print(f"LoRA targets: {tgt}")

mdl = get_peft_model(
    mdl,
    LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=tgt,
        task_type=TaskType.CAUSAL_LM,
    ),
)
mdl.print_trainable_parameters()
|
|
# Load the JSONL corpus. Use a context manager (the original leaked the file
# handle) and an explicit encoding so the read is platform-independent.
with open("/workspace/data/train.jsonl", encoding="utf-8") as f:
    ex = [json.loads(line) for line in f]

# Tokenize every example in one fixed-length padded batch of tensors.
enc = tok(
    [e["text"] for e in ex],
    truncation=True,
    max_length=512,
    padding="max_length",
    return_tensors="pt",
)
class D(torch.utils.data.Dataset):
    """Minimal causal-LM dataset over a pre-tokenized padded batch.

    Labels are a clone of input_ids: HF causal-LM models shift labels
    internally, so a straight copy is the correct supervision signal.
    """

    def __init__(s, e):
        # e: dict/BatchEncoding holding "input_ids" and "attention_mask"
        # tensors of shape (num_examples, seq_len).
        s.e = e

    def __len__(s):
        return len(s.e["input_ids"])

    def __getitem__(s, i):
        return {
            "input_ids": s.e["input_ids"][i],
            "attention_mask": s.e["attention_mask"][i],
            # clone() so the Trainer's label manipulation cannot mutate the
            # shared input_ids storage.
            "labels": s.e["input_ids"][i].clone(),
        }
|
|
print(f"Training {len(ex)} examples...")
# Same hyperparameters as before, spelled out one per line for readability.
train_args = TrainingArguments(
    output_dir="/workspace/out",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,   # effective batch size of 16
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=10,
    logging_steps=5,
    bf16=True,
    report_to="none",
    remove_unused_columns=False,     # keep the dataset's dict keys intact
)
Trainer(model=mdl, args=train_args, train_dataset=D(enc)).train()
|
|
print("Merging LoRA...")
# Fold the LoRA adapters back into the base weights so the result is a
# standalone checkpoint that loads without peft.
mg = mdl.merge_and_unload()
mg.save_pretrained("/workspace/out/merged")
tok.save_pretrained("/workspace/out/merged")

print("Pushing to HuggingFace...")
repo = "Adiuk/bitnet-adis-tool-calling"
mg.push_to_hub(repo, token=TOKEN)
tok.push_to_hub(repo, token=TOKEN)
print(f"DONE! Model at https://huggingface.co/{repo}")
|
|