# train.py — LoRA fine-tune of microsoft/BitNet-b1.58-2B-4T on a tool-calling dataset.
# Provenance: uploaded by Adiuk with huggingface_hub (commit 3aaafb8, verified).
import json, os, torch, sys

# Token comes from argv[1] or the HF_TOKEN env var; bail out with usage otherwise.
TOKEN = sys.argv[1] if len(sys.argv) > 1 else os.environ.get("HF_TOKEN", "")
if not TOKEN:
    print("Usage: python3 train.py <HF_TOKEN> OR set HF_TOKEN env var")
    sys.exit(1)

# Upgrade transformers BEFORE the imports below so native BitNet support
# (transformers >= 4.52) is what this process actually imports. Install via
# the current interpreter's pip so the packages land in the right environment.
os.system(f"{sys.executable} -m pip install -q 'transformers>=4.52' peft")

from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model, TaskType
from huggingface_hub import login, hf_hub_download

# Authenticate and fetch the training data into a fixed workspace path.
login(token=TOKEN)
os.makedirs("/workspace/data", exist_ok=True)
hf_hub_download(
    "Adiuk/adis-tool-calling-dataset",
    "train.jsonl",
    local_dir="/workspace/data",
    repo_type="dataset",
)
M = "microsoft/BitNet-b1.58-2B-4T"
print(f"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB")
print(f"Loading {M}...")

tok = AutoTokenizer.from_pretrained(M)
if tok.pad_token is None:
    # Reuse EOS as the padding token so batched tokenization can pad.
    tok.pad_token = tok.eos_token

# Try native loading first (transformers 4.52+ ships BitNet support);
# fall back to the repo's custom modeling code for older versions.
try:
    mdl = AutoModelForCausalLM.from_pretrained(M, torch_dtype=torch.bfloat16, device_map="auto")
except Exception:
    mdl = AutoModelForCausalLM.from_pretrained(M, trust_remote_code=True, torch_dtype=torch.bfloat16, device_map="auto")
print(f"Loaded: {mdl.num_parameters()/1e6:.0f}M params")

# Collect the attention/MLP projection Linear layers as LoRA targets;
# fall back to the conventional q/v projections if none were found.
tgt = list(set(
    n.split(".")[-1]
    for n, m in mdl.named_modules()
    if isinstance(m, torch.nn.Linear)
    and any(k in n for k in ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"))
)) or ["q_proj", "v_proj"]
print(f"LoRA targets: {tgt}")

mdl = get_peft_model(mdl, LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, target_modules=tgt, task_type=TaskType.CAUSAL_LM))
mdl.print_trainable_parameters()
# Load the JSONL dataset (one {"text": ...} object per line) and tokenize
# every example to a fixed length. A context manager closes the file handle
# promptly instead of leaving it to the garbage collector.
with open("/workspace/data/train.jsonl", encoding="utf-8") as fh:
    ex = [json.loads(line) for line in fh]
enc = tok([e["text"] for e in ex], truncation=True, max_length=512, padding="max_length", return_tensors="pt")
class D(torch.utils.data.Dataset):
    """Minimal map-style dataset over a pre-tokenized batch encoding.

    Expects a mapping with "input_ids" and "attention_mask" tensors of
    shape (num_examples, seq_len), as produced by a HF tokenizer called
    with return_tensors="pt".
    """

    def __init__(self, enc):
        self.enc = enc

    def __len__(self):
        return len(self.enc["input_ids"])

    def __getitem__(self, i):
        ids = self.enc["input_ids"][i]
        mask = self.enc["attention_mask"][i]
        # Labels are the inputs shifted internally by the model, but padded
        # positions must be masked with -100 so cross-entropy ignores them:
        # pad token == eos token here, so training on raw padding would
        # teach the model to emit endless EOS.
        labels = ids.clone()
        labels[mask == 0] = -100
        return {"input_ids": ids, "attention_mask": mask, "labels": labels}
print(f"Training {len(ex)} examples...")

# Identical hyperparameters to the one-liner form, just named for readability.
train_args = TrainingArguments(
    output_dir="/workspace/out",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=10,
    logging_steps=5,
    bf16=True,
    report_to="none",
    remove_unused_columns=False,
)
trainer = Trainer(model=mdl, args=train_args, train_dataset=D(enc))
trainer.train()

print("Merging LoRA...")
# Fold the LoRA adapters back into the base weights and persist locally.
mg = mdl.merge_and_unload()
mg.save_pretrained("/workspace/out/merged")
tok.save_pretrained("/workspace/out/merged")

print("Pushing to HuggingFace...")
repo_id = "Adiuk/bitnet-adis-tool-calling"
mg.push_to_hub(repo_id, token=TOKEN)
tok.push_to_hub(repo_id, token=TOKEN)
print("DONE! Model at https://huggingface.co/Adiuk/bitnet-adis-tool-calling")