philippotiger committed
Commit a987f51 · verified · Parent: d932e0c

Upload train.py with huggingface_hub

Files changed (1)
  1. train.py +117 -0
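The commit message is the default one that huggingface_hub generates for file uploads. For context, a minimal sketch of the kind of call that produces such a commit (the repo_id below is hypothetical, not taken from the commit):

    from huggingface_hub import upload_file

    upload_file(
        path_or_fileobj="train.py",           # local file to push
        path_in_repo="train.py",              # destination path in the repo
        repo_id="philippotiger/<repo-name>",  # hypothetical; the actual repo is not shown here
    )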
train.py ADDED
"""
Fine-tuning Qwen2.5-3B-Instruct for football prediction extraction
Fixes from original: target_modules, validation split, scheduler, checkpoint saving
"""

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTTrainer
import torch

# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"
OUTPUT_DIR = "./football-extractor"
TRAIN_FILE = "train_dataset.jsonl"
VAL_FILE = "val_dataset.jsonl"

# ─────────────────────────────────────────────
# LOAD DATA
# ─────────────────────────────────────────────
dataset = load_dataset("json", data_files={"train": TRAIN_FILE, "validation": VAL_FILE})
print(f"Train: {len(dataset['train'])} | Val: {len(dataset['validation'])}")
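# Each JSONL line is expected to hold a chat-style record with a "messages" field.
# An illustrative (hypothetical) example, since the dataset files are not part of this commit:
# {"messages": [{"role": "user", "content": "Report: Arsenal beat Spurs 2-1 ..."},
#               {"role": "assistant", "content": "{\"winner\": \"Arsenal\", \"score\": \"2-1\"}"}]}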

# ─────────────────────────────────────────────
# TOKENIZER
# ─────────────────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # important for causal LM training

# ─────────────────────────────────────────────
# QUANTIZATION (4-bit QLoRA)
# ─────────────────────────────────────────────
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 is more stable than float16
    bnb_4bit_use_double_quant=True,         # saves a bit more VRAM
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="eager",  # avoids flash-attn issues on Colab
)
model.config.use_cache = False  # required for gradient checkpointing
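# Note: no manual k-bit preparation is needed here; when a peft_config is supplied,
# SFTTrainer runs prepare_model_for_kbit_training on the quantized model internally.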

# ─────────────────────────────────────────────
# LORA CONFIG
# ─────────────────────────────────────────────
lora_config = LoraConfig(
    r=8,  # smaller r is fine for simple extraction
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # Explicitly target attention + MLP layers for Qwen2.5
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
)

# ─────────────────────────────────────────────
# FORMAT FUNCTION
# ─────────────────────────────────────────────
def format_example(example):
    """Apply the Qwen2.5 chat template to each training example."""
    return tokenizer.apply_chat_template(
        example["messages"],
        tokenize=False,
        add_generation_prompt=False,
    )
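# For reference, Qwen2.5 uses ChatML, so each rendered example looks roughly like:
#   <|im_start|>user\n<report text><|im_end|>\n<|im_start|>assistant\n<extracted JSON><|im_end|>\n
# (illustrative only; the actual content depends on the messages in the dataset)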

# ─────────────────────────────────────────────
# TRAINING ARGS
# ─────────────────────────────────────────────
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size = 1 x 4 = 4
    gradient_checkpointing=True,    # saves VRAM
    learning_rate=2e-4,
    num_train_epochs=3,
    lr_scheduler_type="cosine",     # smooth decay
    warmup_ratio=0.05,              # 5% warmup steps
    logging_steps=10,
    eval_strategy="epoch",          # evaluate after each epoch
    save_strategy="epoch",          # save a checkpoint each epoch
    save_total_limit=2,             # keep only the last 2 checkpoints
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    fp16=False,
    bf16=True,                      # use bfloat16 if your GPU supports it
    report_to="none",               # set to "wandb" if you want tracking
)

# ─────────────────────────────────────────────
# TRAINER
# ─────────────────────────────────────────────
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,  # pass the configured tokenizer so the pad-token fix is actually used
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    peft_config=lora_config,
    args=training_args,
    formatting_func=format_example,
    max_seq_length=512,   # extraction tasks are short (on recent trl, set this on SFTConfig instead)
)

trainer.train()
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)
print(f"✅ Adapter saved to {OUTPUT_DIR}")
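Once training finishes, the saved adapter can be loaded back onto the base model for inference. A minimal sketch (not part of the commit; the example message is hypothetical):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel
    import torch

    base = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2.5-3B-Instruct", torch_dtype=torch.bfloat16, device_map="auto"
    )
    model = PeftModel.from_pretrained(base, "./football-extractor")  # attach the LoRA adapter
    tokenizer = AutoTokenizer.from_pretrained("./football-extractor")

    messages = [{"role": "user", "content": "Extract the prediction: I think Arsenal win 2-1."}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    out = model.generate(**inputs, max_new_tokens=128)
    print(tokenizer.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))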