# Source: Hugging Face Hub upload by ligaments-dev -- commit 4e9803a (verified): "Add training script"
import os
import sys
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import trackio
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Hub id of the base checkpoint; used to load both the tokenizer and the model.
MODEL_ID: str = "google/gemma-2b-it"
# Hub id of the training corpus (its "train" split is loaded below).
DATASET_ID: str = "talkmap/telecom-conversation-corpus"
# Hub repo the fine-tuned model is pushed to (passed as `hub_model_id`).
OUTPUT_REPO: str = "ligaments-dev/gemma-2b-telecom-sft"
# Passed to SFTConfig as `max_length`.
MAX_LENGTH: int = 2048
# ---------------------------------------------------------------------------
# Logging / tracking
# ---------------------------------------------------------------------------
# Open the trackio run up front; the trainer is later configured with
# report_to=["trackio"], so training metrics are reported under this project.
_trackio_project = "gemma-telecom-sft"
trackio.init(project=_trackio_project)
# ---------------------------------------------------------------------------
# Load tokenizer
# ---------------------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Batching needs a pad token; when the checkpoint does not define one,
# reuse the end-of-sequence token for padding.
pad_token_missing = tokenizer.pad_token is None
if pad_token_missing:
    tokenizer.pad_token = tokenizer.eos_token
# ---------------------------------------------------------------------------
# Preprocess dataset: group rows by conversation_id into messages
# ---------------------------------------------------------------------------
def _to_alternating_turns(turns):
    """Return *turns* reshaped into strict user/assistant alternation.

    The Gemma chat template raises unless roles alternate
    user/assistant/user/... starting with ``user``.  To satisfy it:

    * turns whose content is empty (or ``None``) are dropped,
    * consecutive turns from the same role are merged, newline-joined,
    * assistant turns that precede the first user turn are discarded.

    Args:
        turns: list of ``{"role": ..., "content": ...}`` dicts in
            conversation order, with role ``"user"`` or ``"assistant"``.

    Returns:
        A new list of turn dicts, or an empty list when fewer than two
        turns remain (i.e. no user turn followed by an assistant reply).
    """
    merged = []
    for turn in turns:
        content = (turn["content"] or "").strip()
        if not content:
            continue
        if merged and merged[-1]["role"] == turn["role"]:
            # Same speaker twice in a row: fold into the previous turn.
            merged[-1]["content"] += "\n" + content
        else:
            merged.append({"role": turn["role"], "content": content})

    first_user = next(
        (i for i, turn in enumerate(merged) if turn["role"] == "user"), None
    )
    if first_user is None:
        return []
    merged = merged[first_user:]
    return merged if len(merged) >= 2 else []


print("Loading dataset...")
raw = load_dataset(DATASET_ID, split="train")

# The dataset has columns: conversation_id, speaker, date_time, text
# We group by conversation_id and build a single messages list per conversation.
# NOTE(review): rows are appended in dataset order; this assumes each
# conversation's rows are already chronological -- confirm against date_time.
print("Grouping conversations...")
conv_map = {}
for ex in raw:
    role = "user" if ex["speaker"] == "client" else "assistant"
    conv_map.setdefault(ex["conversation_id"], []).append(
        {"role": role, "content": ex["text"]}
    )

# Build a huggingface Dataset from the grouped conversations, normalising
# each one so the chat template accepts it and skipping those with nothing
# left to train on.
messages_list = []
skipped = 0
for msgs in conv_map.values():
    msgs = _to_alternating_turns(msgs)
    if msgs:
        messages_list.append({"messages": msgs})
    else:
        skipped += 1
if skipped:
    print(f"Skipped {skipped} conversations without a user/assistant exchange.")
if not messages_list:
    raise ValueError(
        f"No usable conversations found in {DATASET_ID}; nothing to train on."
    )

train_dataset = Dataset.from_list(messages_list)
print(f"Prepared {len(train_dataset)} conversation examples.")

# Persist the processed dataset so it can be inspected or reloaded later
# (e.g. with Dataset.load_from_disk). Nothing in this script reads it back.
processed_path = "/tmp/telecom_processed"
train_dataset.save_to_disk(processed_path)
print(f"Saved processed dataset to {processed_path}")
# ---------------------------------------------------------------------------
# Load model
# ---------------------------------------------------------------------------
print("Loading model...")
# Both options are left on "auto", deferring dtype selection and device
# placement to the library.
model_load_kwargs = {
    "torch_dtype": "auto",
    "device_map": "auto",
}
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, **model_load_kwargs)
# ---------------------------------------------------------------------------
# Training arguments
# ---------------------------------------------------------------------------
# Options are grouped by purpose and merged into a single SFTConfig call.

# Optimisation schedule: 3 epochs, batch 2 per device x 8 accumulation steps,
# cosine decay after a 10% warmup.
_schedule_opts = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 8,
    "learning_rate": 2e-5,
    "warmup_ratio": 0.1,
    "lr_scheduler_type": "cosine",
}

# Sequence handling and memory/precision settings.
_sequence_opts = {
    "max_length": MAX_LENGTH,
    "packing": True,
    "gradient_checkpointing": True,
    "bf16": True,
}

# Logging: every 10 steps (plus the first step), reported to trackio,
# with progress bars turned off.
_logging_opts = {
    "logging_steps": 10,
    "logging_strategy": "steps",
    "logging_first_step": True,
    "disable_tqdm": True,
    "report_to": ["trackio"],
}

# Checkpointing and Hub upload: save once per epoch, push to a public repo.
_hub_opts = {
    "save_strategy": "epoch",
    "push_to_hub": True,
    "hub_model_id": OUTPUT_REPO,
    "hub_private_repo": False,
}

training_args = SFTConfig(
    output_dir="/tmp/gemma-telecom-sft",
    **_schedule_opts,
    **_sequence_opts,
    **_logging_opts,
    **_hub_opts,
)
# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
print("Initializing SFTTrainer...")
# Wire together the loaded model, the SFT configuration, the grouped
# conversation dataset and the tokenizer (as the processing class).
_trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "processing_class": tokenizer,
}
trainer = SFTTrainer(**_trainer_kwargs)
# ---------------------------------------------------------------------------
# Train
# ---------------------------------------------------------------------------
print("Starting training...")
trainer.train()

# ---------------------------------------------------------------------------
# Push to hub
# ---------------------------------------------------------------------------
# Final explicit upload once training has finished, with a descriptive
# commit message.
print("Pushing model to hub...")
_final_commit_message = "Full SFT on telecom conversation corpus"
trainer.push_to_hub(commit_message=_final_commit_message)
print("Done!")