ligaments-dev committed on
Commit
873cfbf
·
verified ·
1 Parent(s): d977c0e

Add training script for Gemma-2B telco fine-tuning

Browse files
Files changed (1) hide show
  1. train_gemma_telco.py +165 -0
train_gemma_telco.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Parameter-efficient fine-tuning (QLoRA: LoRA adapters on a 4-bit NF4 quantized base) of Google Gemma-2B-IT on the Bitext telco chatbot dataset.
3
+ Deploys the trained model to Hugging Face Hub.
4
+
5
+ Note: True 2-bit training is not supported by standard libraries (bitsandbytes only supports 4-bit/8-bit).
6
+ We use 4-bit NF4 (NormalFloat4) which is the industry-standard memory-efficient quantization approach.
7
+ This provides ~4x memory savings compared to FP16, enabling fine-tuning on consumer GPUs.
8
+ """
9
+
10
+ import os
11
+ import torch
12
+ from datasets import load_dataset
13
+ from transformers import (
14
+ AutoModelForCausalLM,
15
+ AutoTokenizer,
16
+ BitsAndBytesConfig,
17
+ )
18
+ from peft import LoraConfig, get_peft_model
19
+ from trl import SFTTrainer, SFTConfig
20
+ import trackio
21
+
22
# ── Configuration ─────────────────────────────────────────────────────────────
# Hub identifiers for the base checkpoint, the training data and the published
# adapter repo, plus the local directory that receives checkpoints.
MODEL_ID = "google/gemma-2b-it"
DATASET_ID = "bitext/Bitext-telco-llm-chatbot-training-dataset"
HUB_MODEL_ID = "ligaments-dev/gemma-2b-telco-sft"
OUTPUT_DIR = "./gemma-telco-sft-output"

# ── Initialize Trackio for monitoring ──────────────────────────────────────
# Metadata recorded alongside the run. These values are descriptive only: the
# LoRA and trainer configuration further down use their own literals, so the
# two must be kept in sync by hand.
_tracked_hparams = dict(
    model=MODEL_ID,
    dataset=DATASET_ID,
    quantization="4bit-nf4",
    lora_r=16,
    lora_alpha=32,
    epochs=3,
    learning_rate=2e-4,
)
trackio.init(
    project="gemma-telco-sft",
    name="gemma-2b-telco-qlora-4bit",
    config=_tracked_hparams,
)
42
+
43
# ── 1. Load & format dataset ───────────────────────────────────────────────
print("Loading dataset...")
# Only the "train" split is requested; it is used as the full training set.
dataset = load_dataset(DATASET_ID, split="train")
_num_examples = len(dataset)
print(f"Dataset loaded: {_num_examples} examples")
47
+
48
+
49
def format_to_messages(
    example: dict,
    *,
    instruction_key: str = "instruction",
    response_key: str = "response",
) -> dict:
    """Convert one instruction/response record into conversational format.

    Args:
        example: Mapping for a single dataset row. It must contain the fields
            named by ``instruction_key`` and ``response_key``; any other
            fields are ignored.
        instruction_key: Name of the field holding the user's prompt.
        response_key: Name of the field holding the assistant's reply.

    Returns:
        A dict with a single ``"messages"`` key whose value is a two-turn
        list: a ``"user"`` message followed by an ``"assistant"`` message,
        each of the form ``{"role": ..., "content": ...}``.

    Raises:
        KeyError: If either field is missing from ``example``.
    """
    user_turn = {"role": "user", "content": example[instruction_key]}
    assistant_turn = {"role": "assistant", "content": example[response_key]}
    return {"messages": [user_turn, assistant_turn]}
57
+
58
+
59
# Apply the formatter to every row and drop all original columns, so only the
# "messages" column produced by format_to_messages remains.
_original_columns = dataset.column_names
dataset = dataset.map(format_to_messages, remove_columns=_original_columns)
_first_row = dataset[0]
print(f"Formatted dataset sample: {_first_row}")
61
+
62
# ── 2. Load tokenizer ──────────────────────────────────────────────────────
print("Loading tokenizer...")
# trust_remote_code is intentionally left at its default (False): Gemma is
# supported natively by transformers, so nothing requires executing Python
# code shipped inside the Hub repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # No dedicated padding token: reuse end-of-sequence for padding.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
68
+
69
# ── 3. Quantization config (4-bit NF4 — closest practical to 2-bit) ────────
print("Setting up 4-bit NF4 quantization...")
_quantization_options = {
    "load_in_4bit": True,
    # NormalFloat4 — optimal for weight distributions
    "bnb_4bit_quant_type": "nf4",
    # Nested quantization saves more memory
    "bnb_4bit_use_double_quant": True,
    # Compute in BF16 for stability
    "bnb_4bit_compute_dtype": torch.bfloat16,
}
bnb_config = BitsAndBytesConfig(**_quantization_options)
77
+
78
# ── 4. Load model with quantization ──────────────────────────────────────────
print("Loading model with 4-bit quantization...")
# trust_remote_code is intentionally left at its default (False): Gemma is
# implemented inside transformers, so no repository-supplied code has to run.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.config.use_cache = False  # Required for gradient checkpointing
print("Model loaded. Trainable params info will be shown after LoRA setup.")
89
+
90
# ── 5. LoRA config (PEFT adapters for efficient fine-tuning) ─────────────────
# Imported here because only this step needs it.
from peft import prepare_model_for_kbit_training

print("Applying LoRA adapters...")
peft_config = LoraConfig(
    r=16,                         # LoRA rank
    lora_alpha=32,                # Scaling factor
    target_modules="all-linear",  # Auto-detect all linear layers
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# PEFT's documented preparation step for training on a quantized base model;
# it must run before the adapters are attached. Gradient checkpointing is
# requested here because the training configuration enables it as well.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
103
+
104
# ── 6. Training config ───────────────────────────────────────────────────────
print("Configuring training arguments...")
_sft_settings = {
    # Output location and reproducibility
    "output_dir": OUTPUT_DIR,
    "seed": 42,
    # Batching: 4 sequences x 4 accumulation steps per optimizer step, per device
    "per_device_train_batch_size": 4,
    "gradient_accumulation_steps": 4,
    "max_length": 512,
    "packing": False,
    # Optimization schedule
    "num_train_epochs": 3,
    "learning_rate": 2e-4,
    "lr_scheduler_type": "cosine",
    "warmup_ratio": 0.1,
    "optim": "paged_adamw_8bit",  # Paged optimizer for memory efficiency
    # Precision / memory
    "bf16": True,
    "gradient_checkpointing": True,  # Trade compute for memory
    # Logging
    "logging_strategy": "steps",
    "logging_steps": 10,
    "logging_first_step": True,
    "report_to": ["trackio"],
    "disable_tqdm": True,
    # Checkpointing and Hub upload
    "save_strategy": "epoch",
    "save_total_limit": 2,
    "push_to_hub": True,
    "hub_model_id": HUB_MODEL_ID,
    "hub_private_repo": False,
}
training_args = SFTConfig(**_sft_settings)
131
+
132
# ── 7. Initialize trainer ──────────────────────────────────────────────────
print("Initializing SFTTrainer...")
# `model` has already been wrapped with get_peft_model(), so it carries its
# LoRA adapters. peft_config is therefore deliberately NOT passed here: the
# trainer expects either a plain base model plus a peft_config, or an
# already-wrapped PeftModel on its own, never both.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    processing_class=tokenizer,
)
141
+
142
# ── 8. Train ─────────────────────────────────────────────────────────────────
print("Starting training...")
trainer.train()

# ── 9. Save & deploy ───────────────────────────────────────────────────────────
# Write the final weights locally first, then upload them to the Hub repo
# configured in the training arguments.
print("Saving final model...")
trainer.save_model(OUTPUT_DIR)

print("Pushing to Hugging Face Hub...")
_adapter_commit_message = (
    "Fine-tuned Gemma-2B-IT on Bitext telco chatbot dataset (QLoRA 4-bit NF4)"
)
trainer.push_to_hub(commit_message=_adapter_commit_message)

_adapter_url = f"https://huggingface.co/{HUB_MODEL_ID}"
print("Training complete! Model deployed to:")
print(f" {_adapter_url}")
157
+
158
# ── 10. Merge adapters for inference (optional but recommended) ──────────────
# Imported here because only the merge step needs it.
from peft import PeftModel

print("Merging LoRA adapters with base model for optimized inference...")
_merged_repo_id = f"{HUB_MODEL_ID}-merged"

# Merge into a fresh, unquantized BF16 copy of the base model rather than into
# the 4-bit training model: merging into quantized weights re-quantizes the
# result (rounding error) and would publish a quantized checkpoint. No
# device_map is given, so the copy is not placed alongside the training model.
_base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
)
# OUTPUT_DIR holds the adapter written by trainer.save_model() above.
merged_model = PeftModel.from_pretrained(_base_model, OUTPUT_DIR).merge_and_unload()
merged_model.push_to_hub(
    _merged_repo_id,
    commit_message="Merged Gemma-2B-IT + LoRA adapters for inference",
)
# Upload the tokenizer too, so the merged repo can be loaded on its own.
tokenizer.push_to_hub(
    _merged_repo_id,
    commit_message="Add tokenizer for merged model",
)
print(f"Merged model deployed to: https://huggingface.co/{HUB_MODEL_ID}-merged")