# gemma-telecom-training / train_telecom_gemma.py
# (uploaded by ligaments-dev)
# Commit fee27f5 (verified): "Add training script for gemma-2b-it telecom full fine-tuning"
"""
Full-parameter supervised fine-tuning (SFT) script.

Model:   google/gemma-2-2b-it
Dataset: talkmap/telecom-conversation-corpus

Converts the turn-based telecom dialogues into the conversational "messages"
format consumed by TRL's SFTTrainer.
"""
import os
from collections import defaultdict
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import trackio
import torch
# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------
# Hugging Face Hub identifiers of the base checkpoint and the training corpus.
MODEL_ID: str = "google/gemma-2-2b-it"
DATASET_ID: str = "talkmap/telecom-conversation-corpus"

# Local directory for checkpoints / final weights, and the Hub repo they go to.
OUTPUT_DIR: str = "./gemma-2b-it-telecom"
HUB_MODEL_ID: str = "ligaments-dev/gemma-2b-it-telecom"

# Token-length limit handed to the SFT configuration.
MAX_SEQ_LENGTH: int = 2048
# ------------------------------------------------------------------
# Trackio monitoring
# ------------------------------------------------------------------
# NOTE(review): with report_to=["trackio"] the Trainer's Trackio callback opens
# its own run, and (in the transformers/TRL versions whose docs describe it)
# takes the project name from the TRACKIO_PROJECT environment variable rather
# than from this explicit init call. Without the variable, training metrics
# would land in a default project. Exporting it keeps the Trainer's metrics
# and this run in the same project. A TRACKIO_PROJECT already set in the
# environment wins for both. Confirm against the installed transformers version.
trackio_project = os.environ.setdefault("TRACKIO_PROJECT", "gemma-telecom-finetune")
trackio.init(
    project=trackio_project,
    name="gemma-2b-it-full-sft",
)
# ------------------------------------------------------------------
# Load dataset
# ------------------------------------------------------------------
print("Loading dataset...")
ds = load_dataset(DATASET_ID, split="train")
print(f"Rows: {len(ds)}, Columns: {ds.column_names}")

# Group rows by conversation_id and order each conversation's turns by
# date_time. Reading whole columns once is much cheaper than iterating the
# Dataset row by row, which materialises a dict for every row.
print("Grouping conversations...")
conversations = defaultdict(list)
for conv_id, speaker, date_time, text in zip(
    ds["conversation_id"], ds["speaker"], ds["date_time"], ds["text"]
):
    conversations[conv_id].append((date_time, speaker, text))
for turns in conversations.values():
    # Stable sort on the timestamp only, so turns that share a timestamp keep
    # their dataset order. NOTE(review): assumes date_time values sort
    # chronologically (datetimes or ISO-8601 strings). Confirm for this corpus.
    turns.sort(key=lambda turn: turn[0])


def _to_messages(turns):
    """Convert time-ordered (date_time, speaker, text) turns into chat messages.

    Gemma-2's chat template raises unless roles strictly alternate
    user/assistant starting with a user message. Raw call transcripts can
    contain several consecutive turns by the same speaker or open with an
    agent greeting, so this function:

    * maps speaker "client" to "user" and any other speaker to "assistant"
      (NOTE(review): assumes the only other speaker is the agent);
    * strips whitespace and skips empty/None texts;
    * drops assistant turns that precede the first user turn;
    * joins consecutive same-role turns with a newline.

    Returns the resulting list of {"role", "content"} dicts. The list may be
    empty or contain only a user turn when nothing usable remains.
    """
    messages = []
    for _, speaker, text in turns:
        content = (text or "").strip()
        if not content:
            continue
        role = "user" if speaker == "client" else "assistant"
        if not messages and role == "assistant":
            continue  # the template requires the first message to be the user's
        if messages and messages[-1]["role"] == role:
            messages[-1]["content"] += "\n" + content
        else:
            messages.append({"role": role, "content": content})
    return messages


# Convert each conversation into messages format
print("Converting to messages format...")
messages_data = []
for turns in conversations.values():
    messages = _to_messages(turns)
    # Messages start with "user" and alternate, so two or more entries
    # guarantee at least one assistant reply to learn from.
    if len(messages) >= 2:
        messages_data.append({"messages": messages})
train_dataset = Dataset.from_list(messages_data)
print(f"Total conversations: {len(train_dataset)}")
print(
    f"Skipped conversations without a user->assistant exchange: "
    f"{len(conversations) - len(messages_data)}"
)
# ------------------------------------------------------------------
# Tokenizer
# ------------------------------------------------------------------
print("Loading tokenizer...")
# Gemma-2 is a built-in transformers architecture, so trust_remote_code (which
# would execute Python shipped inside the model repo) is not needed.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Fall back to EOS for padding if the checkpoint defines no pad token.
# Assigning pad_token also updates pad_token_id, so no separate id assignment.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# ------------------------------------------------------------------
# Model
# ------------------------------------------------------------------
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    # transformers warns that Gemma-2 should be trained with the eager
    # attention implementation: its attention-logit soft-capping is not
    # applied by SDPA.
    attn_implementation="eager",
    # NOTE(review): device_map="auto" suits a single-process run; it generally
    # conflicts with multi-process (DDP / accelerate launch) training.
    # Confirm how this script is launched.
    device_map="auto",
    # trust_remote_code is omitted: Gemma-2 is natively supported, so no
    # repository code needs to be executed.
)
# Recompute activations in the backward pass to save memory
# (SFTConfig also sets gradient_checkpointing=True).
model.gradient_checkpointing_enable()
# ------------------------------------------------------------------
# Training arguments
# ------------------------------------------------------------------
# Full-parameter SFT. Per device, each optimizer step sees
# per_device_train_batch_size * gradient_accumulation_steps = 4 sequences.
args = SFTConfig(
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    # TRL renamed SFTConfig's `max_seq_length` to `max_length`. The old name
    # was deprecated and then removed, so passing it raises TypeError on
    # current TRL releases.
    max_length=MAX_SEQ_LENGTH,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,  # plain log lines instead of progress bars
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    report_to=["trackio"],
    # Name given to the run created by the Trainer's Trackio callback; kept
    # identical to the name passed to trackio.init.
    run_name="gemma-2b-it-full-sft",
    remove_unused_columns=False,
)
# ------------------------------------------------------------------
# Trainer
# ------------------------------------------------------------------
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    # Tokenizer (and its chat template) used to render the "messages" column.
    processing_class=tokenizer,
)
# ------------------------------------------------------------------
# Train
# ------------------------------------------------------------------
print("Starting training...")
trainer.train()
# ------------------------------------------------------------------
# Save & Push
# ------------------------------------------------------------------
print("Saving and pushing to hub...")
# Trainer.push_to_hub() first saves the final model and tokenizer to
# args.output_dir (OUTPUT_DIR here), then uploads them with a generated model
# card. A separate save_model(OUTPUT_DIR) call is unnecessary: with
# push_to_hub=True in the SFTConfig, save_model() would itself upload, so the
# model was being saved and pushed twice.
trainer.push_to_hub()
print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")