File size: 5,926 Bytes
1642b2a 31b65fb 1642b2a 745965e 1642b2a 4778444 1642b2a 0ce3e14 1642b2a 3686614 1642b2a 3686614 8c55461 48301ee 8c55461 48301ee 8c55461 0ce3e14 1642b2a 8c55461 745965e 48301ee 1642b2a cb6f87d 1642b2a 745965e 31b65fb 745965e 1642b2a 31b65fb 1642b2a 938e32b 1642b2a 31b65fb 1642b2a aa1a19d 1642b2a 31b65fb 1642b2a 745965e 1642b2a 745965e 1642b2a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 | """
Full-parameter supervised fine-tuning (SFT) script with memory-saving settings (bf16 weights, gradient checkpointing, batch size 1 with gradient accumulation):
Model: google/gemma-2-2b-it
Dataset: talkmap/telecom-conversation-corpus
"""
import gc
import os
from collections import defaultdict

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import trackio
import torch
# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------
# Base checkpoint and training corpus on the Hugging Face Hub.
MODEL_ID = "google/gemma-2-2b-it"
DATASET_ID = "talkmap/telecom-conversation-corpus"

# Local checkpoint directory and destination Hub repository.
OUTPUT_DIR = "./gemma-2b-it-telecom"
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"

# Data-size limits.
MAX_CONVERSATIONS = 10_000  # stop collecting examples after this many
MAX_TURNS = 6  # utterances kept from the start of each conversation
MAX_SEQ_LENGTH = 512  # per-example token limit handed to the SFT config
# ------------------------------------------------------------------
# Trackio monitoring
# ------------------------------------------------------------------
# Open the experiment-tracking run for this fine-tune.
# NOTE(review): the trainer is also configured with report_to=["trackio"];
# confirm the transformers trackio callback logs into this run rather than
# initialising a second one.
trackio.init(
    name="gemma-2b-it-full-sft",
    project="gemma-telecom-finetune",
)
# ------------------------------------------------------------------
# Load dataset
# ------------------------------------------------------------------
print("Loading dataset...")
ds = load_dataset(DATASET_ID, split="train")
print(f"Rows: {len(ds)}, Columns: {ds.column_names}")


def _group_turns_by_conversation(dataset):
    """Group per-utterance rows into conversations.

    Returns a dict mapping each ``conversation_id`` to its list of turn dicts
    (``speaker``, ``date_time``, ``text``) sorted by ``date_time``. ``None``
    texts are normalised to ``""``.
    """
    grouped = defaultdict(list)
    # Read whole columns instead of iterating ``dataset`` row by row: building
    # one Python dict per row is the slow path for Arrow-backed datasets.
    rows = zip(
        dataset["conversation_id"],
        dataset["speaker"],
        dataset["date_time"],
        dataset["text"],
    )
    for conv_id, speaker, date_time, text in rows:
        grouped[conv_id].append({
            "speaker": speaker,
            "date_time": date_time,
            "text": text if text is not None else "",
        })
    for turns in grouped.values():
        turns.sort(key=lambda turn: turn["date_time"])
    return grouped


def _turns_to_messages(turns):
    """Convert ordered turns into an alternating user/assistant chat.

    Speaker ``"client"`` becomes role ``"user"``; any other speaker becomes
    ``"assistant"``. Consecutive turns with the same role are merged with
    newlines. Blank utterances are skipped so they create neither empty
    messages nor stray blank lines.

    Because same-role turns are merged, roles strictly alternate. The result
    is therefore a valid chat exactly when it starts with ``"user"`` and ends
    with ``"assistant"``. Returns that message list, or ``None`` when no such
    example can be formed.
    """
    messages = []
    for turn in turns:
        text = turn["text"]
        if not text.strip():
            continue
        role = "user" if turn["speaker"] == "client" else "assistant"
        if messages and messages[-1]["role"] == role:
            messages[-1]["content"] += "\n" + text
        else:
            messages.append({"role": role, "content": text})
    # A trailing user message (typical when MAX_TURNS cuts right after a
    # client utterance) has no assistant target. Trim it so the example still
    # ends on an assistant reply, instead of discarding the whole conversation.
    if messages and messages[-1]["role"] == "user":
        messages.pop()
    if not messages or messages[0]["role"] != "user":
        return None
    return messages


print("Grouping conversations...")
conversations = _group_turns_by_conversation(ds)

print("Converting to messages format...")
messages_data = []
for turns in conversations.values():
    messages = _turns_to_messages(turns[:MAX_TURNS])
    if messages is None:
        continue
    messages_data.append({"messages": messages})
    if len(messages_data) >= MAX_CONVERSATIONS:
        break
print(f"Total conversations: {len(messages_data)}")
# ------------------------------------------------------------------
# Tokenizer
# ------------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Checkpoints without a pad token reuse EOS for padding.
if tokenizer.pad_token is None:
    eos_token, eos_token_id = tokenizer.eos_token, tokenizer.eos_token_id
    tokenizer.pad_token = eos_token
    tokenizer.pad_token_id = eos_token_id
# ------------------------------------------------------------------
# Render conversations to text with the chat template
# (tokenization and length truncation happen later, inside SFTTrainer)
# ------------------------------------------------------------------
print("Pre-tokenizing dataset...")

# True when plain ``tokenizer(text)`` already prepends BOS (Gemma does).
_bos_id = tokenizer.bos_token_id
_tokenizer_adds_bos = (
    _bos_id is not None and tokenizer("")["input_ids"][:1] == [_bos_id]
)


def apply_and_tokenize(example):
    """Render ``example["messages"]`` into a single training string.

    Despite the name, this does not tokenize: ``apply_chat_template`` is
    called with ``tokenize=False``. SFTTrainer tokenizes the returned text
    later. Returns ``{"text": ""}`` if the template raises, so the example
    can be filtered out (and counted) afterwards.
    """
    try:
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
    except Exception:  # best effort: failures are counted and dropped below
        return {"text": ""}
    # The chat template emits BOS itself, and SFTTrainer tokenizes plain-text
    # examples with the tokenizer's default special tokens, which would add a
    # second BOS. Keep only one.
    # NOTE(review): re-check if the TRL version changes how "text" datasets
    # are tokenized.
    if _tokenizer_adds_bos and text.startswith(tokenizer.bos_token):
        text = text[len(tokenizer.bos_token):]
    return {"text": text}


raw_dataset = Dataset.from_list(messages_data)
raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
num_rendered = len(raw_dataset)
raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
num_failed = num_rendered - len(raw_dataset)
if num_failed:
    # Surface examples the chat template rejected instead of losing them silently.
    print(f"Dropped {num_failed} conversations whose chat template failed")
print(f"Dataset after filtering: {len(raw_dataset)}")
# ------------------------------------------------------------------
# Model - bf16 weights, placed automatically by device_map="auto"
# ------------------------------------------------------------------
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    # transformers recommends eager attention for training Gemma-2: SDPA
    # does not apply Gemma-2's attention-logit soft-capping.
    attn_implementation="eager",
)
# The KV cache is unused during training and incompatible with gradient
# checkpointing; disable it explicitly rather than relying on a runtime warning.
model.config.use_cache = False
model.gradient_checkpointing_enable()

# Release memory freed during loading before training allocates its buffers.
gc.collect()
torch.cuda.empty_cache()
# ------------------------------------------------------------------
# Training arguments
# ------------------------------------------------------------------
args = SFTConfig(
    output_dir=OUTPUT_DIR,
    # Same name as the trackio run opened above, so tracked runs line up.
    run_name="gemma-2b-it-full-sft",
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    num_train_epochs=1,
    # One sequence in memory per step; 8 accumulated steps per optimizer update.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    # TRL renamed ``max_seq_length`` to ``max_length``; recent SFTConfig
    # releases no longer accept the old keyword.
    max_length=MAX_SEQ_LENGTH,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    report_to=["trackio"],
    remove_unused_columns=False,
)
# ------------------------------------------------------------------
# Trainer: build, train, then save locally and upload to the Hub
# ------------------------------------------------------------------
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=raw_dataset,
    args=args,
)

print("Starting training...")
trainer.train()

# Write the final weights to OUTPUT_DIR, then push the result to HUB_MODEL_ID.
print("Saving and pushing to hub...")
trainer.save_model(OUTPUT_DIR)
trainer.push_to_hub()
print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")
|