AmiDwivedi committed on
Commit
5fc7f70
·
verified ·
1 Parent(s): ab7b4f1

Upload hf_train_lr2e4.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. hf_train_lr2e4.py +92 -0
hf_train_lr2e4.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# /// script
# dependencies = ["trl>=0.12.0", "peft>=0.14.0", "trackio", "bitsandbytes", "accelerate"]
# ///
# NOTE(review): the SFTConfig/SFTTrainer arguments used below (`max_length`,
# `processing_class`, `eval_strategy`) may require newer trl/transformers
# releases than the lower bounds above -- confirm and tighten the pins if so.

"""
Underwood SFT Training - Learning Rate 2e-4
Fine-tunes Gemma 3 4B with QLoRA on strategic advisor conversations

The script runs top to bottom: load the 4-bit quantized base model and its
tokenizer, load the conversation dataset, attach a LoRA adapter through
SFTTrainer, train, and push the result to the Hugging Face Hub.
"""

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
# Not referenced directly; logging is selected via report_to="trackio" below.
# Presumably imported so a missing install fails before the model is loaded.
import trackio

# Identifiers that were previously repeated as string literals.
BASE_MODEL_ID = "google/gemma-3-4b-it"
DATASET_ID = "AmiDwivedi/underwood-conversations"
RUN_NAME = "underwood-lr2e4"
HUB_MODEL_ID = "AmiDwivedi/underwood-lr2e4"

# QLoRA config: 4-bit NF4 weights with double quantization, bf16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
# Only fall back to EOS when the tokenizer has no pad token of its own.
# Aliasing pad to EOS unconditionally means any collator that masks pad ids
# in the labels also masks the real end-of-sequence token, so the model is
# never trained to stop.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Load dataset
# Assumes the dataset exposes "train" and "validation" splits (both are
# indexed below) -- TODO confirm against the dataset card.
dataset = load_dataset(DATASET_ID)

# LoRA config (matching local setup)
lora_config = LoraConfig(
    r=128,
    lora_alpha=256,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
)

# Training config
training_args = SFTConfig(
    output_dir=RUN_NAME,
    num_train_epochs=10,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # 2 samples per device x 8 accumulation steps = 16 samples per optimizer
    # step on each device.
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=2,
    bf16=True,
    max_length=2048,
    packing=False,
    gradient_checkpointing=True,
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy="every_save",
    report_to="trackio",
    run_name=RUN_NAME,
)

# Create trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["validation"],
    peft_config=lora_config,
    processing_class=tokenizer,
)

# Train
trainer.train()
trainer.push_to_hub()
print(f"Training complete! Model pushed to {HUB_MODEL_ID}")