# Commit 31b65fb (verified) by ligaments-dev:
#   Increase grad_accum to 8, add gc.collect and empty_cache before training
"""
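Full-parameter supervised fine-tuning (SFT) script with memory optimizations
(bf16 weights, gradient checkpointing, per-device batch size 1 with gradient
accumulation, sequences capped at MAX_SEQ_LENGTH tokens).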
Model: google/gemma-2-2b-it
Dataset: talkmap/telecom-conversation-corpus
"""
import gc
import os
from collections import defaultdict
from itertools import groupby

import torch
import trackio
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer
# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------
# Base checkpoint, training corpus and destinations for the fine-tuned model.
MODEL_ID = "google/gemma-2-2b-it"
DATASET_ID = "talkmap/telecom-conversation-corpus"
OUTPUT_DIR = "./gemma-2b-it-telecom"
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"

# Data-size limits.
MAX_CONVERSATIONS = 10_000  # stop collecting once this many conversations pass the filters
MAX_TURNS = 6  # raw utterances kept per conversation, counted before same-speaker merging
MAX_SEQ_LENGTH = 512  # token cap handed to the SFT config
# ------------------------------------------------------------------
# Trackio monitoring
# ------------------------------------------------------------------
# Open the experiment-tracking run before any data loading starts.
# NOTE(review): the training config uses report_to=["trackio"]; the Trainer's
# Trackio callback may call trackio.init() again with its own project/run
# name. Confirm the metrics actually land in this project.
trackio.init(name="gemma-2b-it-full-sft", project="gemma-telecom-finetune")
# ------------------------------------------------------------------
# Load dataset
# ------------------------------------------------------------------
print("Loading dataset...")
ds = load_dataset(DATASET_ID, split="train")
print(f"Rows: {len(ds)}, Columns: {ds.column_names}")

# Bucket utterances by conversation_id, then put each bucket in time order.
# A missing (None) text is normalised to "" so later string joins never see
# None.
print("Grouping conversations...")
conversations = defaultdict(list)
for record in ds:
    raw_text = record["text"]
    conversations[record["conversation_id"]].append(
        dict(
            speaker=record["speaker"],
            date_time=record["date_time"],
            text="" if raw_text is None else raw_text,
        )
    )
for utterances in conversations.values():
    # list.sort is stable: utterances sharing a timestamp keep dataset order.
    utterances.sort(key=lambda utt: utt["date_time"])
def _chat_role(turn):
    """Map a raw utterance to its chat role: "user" for the client, else "assistant"."""
    return "user" if turn["speaker"] == "client" else "assistant"


# Collapse each conversation into chat messages. Consecutive utterances with
# the same role are merged into one newline-joined message (groupby over
# consecutive keys). Adjacent messages therefore always differ in role, and
# with only two roles they strictly alternate. No separate alternation check
# is needed; it is enough to require that the chat opens with the user and
# closes with the assistant.
print("Converting to messages format...")
messages_data = []
for turns in conversations.values():
    # MAX_TURNS caps raw utterances *before* same-role merging.
    messages = [
        {"role": role, "content": "\n".join(turn["text"] for turn in group)}
        for role, group in groupby(turns[:MAX_TURNS], key=_chat_role)
    ]
    # NOTE: a conversation whose kept prefix ends on a user turn is dropped
    # entirely rather than trimmed back to its last assistant reply.
    opens_with_user = bool(messages) and messages[0]["role"] == "user"
    if not opens_with_user or messages[-1]["role"] != "assistant":
        continue
    messages_data.append({"messages": messages})
    if len(messages_data) >= MAX_CONVERSATIONS:
        break
print(f"Total conversations: {len(messages_data)}")
# ------------------------------------------------------------------
# Tokenizer
# ------------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Make sure a pad token exists; fall back to EOS when the checkpoint has none.
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = (
        tokenizer.eos_token,
        tokenizer.eos_token_id,
    )
# ------------------------------------------------------------------
# Render conversations to chat-formatted text
# ------------------------------------------------------------------
# Only the chat template is applied here (tokenize=False). Tokenization and
# truncation to MAX_SEQ_LENGTH are left to SFTTrainer.
print("Pre-tokenizing dataset...")


def apply_and_tokenize(example):
    """Render one conversation with the tokenizer's chat template.

    Despite its name, this does not tokenize: it returns the formatted string
    under the "text" key. Conversations the template rejects are mapped to ""
    so the filter below can drop them.
    """
    try:
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
    except Exception:
        # Deliberate best-effort: chat templates raise template-specific
        # errors (e.g. on an unsupported role order). Such rows are dropped
        # and counted below instead of aborting the whole run.
        return {"text": ""}
    # The Gemma chat template already starts with the BOS token, and
    # SFTTrainer tokenizes a plain "text" column with the tokenizer's special
    # tokens enabled (Gemma prepends BOS). Left as is, every sample would
    # start with a doubled BOS, so strip the template's copy.
    bos = tokenizer.bos_token
    if bos and text.startswith(bos):
        text = text[len(bos):]
    return {"text": text}


if not messages_data:
    # Dataset.from_list([]) has no "messages" column, so map() below would
    # fail with a much less helpful error.
    raise RuntimeError("No valid conversations were extracted from the dataset.")
raw_dataset = Dataset.from_list(messages_data)
num_before_filter = len(raw_dataset)
raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
print(f"Dataset after filtering: {len(raw_dataset)}")
num_rejected = num_before_filter - len(raw_dataset)
if num_rejected:
    print(f"Dropped {num_rejected} conversations rejected by the chat template")
if len(raw_dataset) == 0:
    raise RuntimeError("Every conversation was rejected by the chat template.")
# ------------------------------------------------------------------
# Model - load directly in bf16 and let accelerate place it
# ------------------------------------------------------------------
print("Loading model...")
# NOTE(review): device_map="auto" may offload layers to CPU when the GPU is
# too small, which full fine-tuning handles poorly. Confirm the model fits on
# a single device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    # Gemma-2 soft-caps its attention logits, which the SDPA kernel does not
    # implement, so transformers recommends eager attention when training
    # Gemma-2. With the short sequences used here the extra memory of the
    # eager path should be modest.
    attn_implementation="eager",
)
# The KV cache is not used in training and is incompatible with gradient
# checkpointing. Disable it up front instead of relying on a runtime fallback.
model.config.use_cache = False
model.gradient_checkpointing_enable()

# Release Python garbage and cached-but-unused CUDA blocks left over from
# loading, to reduce fragmentation before training allocates memory.
gc.collect()
torch.cuda.empty_cache()
# ------------------------------------------------------------------
# Training arguments
# ------------------------------------------------------------------
args = SFTConfig(
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    # Run name reported to the tracking integration; matches trackio.init().
    run_name="gemma-2b-it-full-sft",
    num_train_epochs=1,
    # Effective batch per device: 1 sequence x 8 accumulation steps = 8
    # sequences per optimizer step.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    # TRL renamed `max_seq_length` to `max_length`; recent releases no longer
    # accept the old keyword.
    max_length=MAX_SEQ_LENGTH,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    report_to=["trackio"],
    remove_unused_columns=False,
)
# ------------------------------------------------------------------
# Trainer + training loop
# ------------------------------------------------------------------
# Wire the model, tokenizer, rendered dataset and config together, then run
# the training loop.
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=raw_dataset,
    args=args,
)

print("Starting training...")
trainer.train()
# ------------------------------------------------------------------
# Save & Push
# ------------------------------------------------------------------
# Write the final model to OUTPUT_DIR, then upload it to the Hub repo named
# by the config's hub_model_id.
print("Saving and pushing to hub...")
trainer.save_model(OUTPUT_DIR)
trainer.push_to_hub()
hub_url = f"https://huggingface.co/{HUB_MODEL_ID}"
print(f"Done! Model at {hub_url}")