"""Full supervised fine-tuning of Gemma-2b-it on the telecom conversation corpus.

Groups the per-turn rows of the source dataset into one chat conversation per
``conversation_id``, normalizes each conversation so it satisfies Gemma's chat
template, trains with TRL's ``SFTTrainer`` and pushes the result to the Hub.
"""
import os
import sys
from collections import defaultdict

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import trackio

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MODEL_ID = "google/gemma-2b-it"
DATASET_ID = "talkmap/telecom-conversation-corpus"
OUTPUT_REPO = "ligaments-dev/gemma-2b-telecom-sft"
MAX_LENGTH = 2048


def _to_alternating_messages(turns):
    """Normalize raw turns into a strictly alternating user/assistant list.

    Gemma's chat template raises an exception unless roles alternate
    user/assistant/user/... starting with "user". This helper therefore:
      * drops turns whose content is None or whitespace-only,
      * merges consecutive turns from the same role (joined with a newline),
      * drops assistant turns that occur before the first user turn,
      * drops a trailing user turn that never received an assistant reply.

    Args:
        turns: list of ``{"role": "user" | "assistant", "content": str | None}``
            dicts in conversation order. The input dicts are not mutated.

    Returns:
        A new list of message dicts that starts with "user", ends with
        "assistant" and strictly alternates; empty if the conversation has no
        user turn followed by an assistant turn.
    """
    merged = []
    for turn in turns:
        content = turn["content"]
        if content is None or not content.strip():
            continue
        if merged and merged[-1]["role"] == turn["role"]:
            # Same speaker twice in a row: fold into the previous message so
            # the template's alternation check passes.
            merged[-1]["content"] += "\n" + content
        elif merged or turn["role"] == "user":
            # Only start the conversation on a user turn.
            merged.append({"role": turn["role"], "content": content})
    if merged and merged[-1]["role"] == "user":
        # A final unanswered user turn contributes no assistant target.
        merged.pop()
    return merged


# ---------------------------------------------------------------------------
# Logging / tracking
# ---------------------------------------------------------------------------
trackio.init(project="gemma-telecom-sft")

# ---------------------------------------------------------------------------
# Load tokenizer
# ---------------------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ---------------------------------------------------------------------------
# Preprocess dataset: group rows by conversation_id into messages
# ---------------------------------------------------------------------------
print("Loading dataset...")
raw = load_dataset(DATASET_ID, split="train")

# The dataset has columns: conversation_id, speaker, date_time, text
# We group by conversation_id and build a single messages list per conversation.
# NOTE(review): turns are kept in dataset row order; date_time is not used for
# ordering. Confirm rows are already chronological within each conversation.
print("Grouping conversations...")
conv_map = defaultdict(list)
for ex in raw:
    # NOTE(review): any speaker value other than "client" is treated as the
    # assistant -- verify the dataset only has client/agent speakers.
    role = "user" if ex["speaker"] == "client" else "assistant"
    conv_map[ex["conversation_id"]].append({"role": role, "content": ex["text"]})

# Build a huggingface Dataset from the grouped conversations, keeping only
# conversations that can be rendered by Gemma's user/assistant chat template.
messages_list = []
skipped = 0
for msgs in conv_map.values():
    cleaned = _to_alternating_messages(msgs)
    if cleaned:
        messages_list.append({"messages": cleaned})
    else:
        skipped += 1

if not messages_list:
    sys.exit("No usable conversations after normalization; aborting.")

train_dataset = Dataset.from_list(messages_list)
print(f"Prepared {len(train_dataset)} conversation examples.")
print(f"Skipped {skipped} conversations without a user->assistant exchange.")

# Persist the processed dataset for inspection/reuse. Note: this script does
# not read it back -- it always rebuilds the dataset from DATASET_ID.
processed_path = "/tmp/telecom_processed"
train_dataset.save_to_disk(processed_path)
print(f"Saved processed dataset to {processed_path}")

# ---------------------------------------------------------------------------
# Load model
# ---------------------------------------------------------------------------
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)

# ---------------------------------------------------------------------------
# Training arguments
# ---------------------------------------------------------------------------
training_args = SFTConfig(
    output_dir="/tmp/gemma-telecom-sft",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    max_length=MAX_LENGTH,
    # NOTE(review): packing concatenates several conversations into one
    # sequence; whether attention is isolated per conversation depends on the
    # attention implementation selected for Gemma -- confirm.
    packing=True,
    gradient_checkpointing=True,
    bf16=True,
    logging_steps=10,
    save_strategy="epoch",
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_private_repo=False,
    disable_tqdm=True,
    logging_strategy="steps",
    logging_first_step=True,
    report_to=["trackio"],
)

# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)

# ---------------------------------------------------------------------------
# Train
# ---------------------------------------------------------------------------
print("Starting training...")
trainer.train()

# ---------------------------------------------------------------------------
# Push to hub
# ---------------------------------------------------------------------------
print("Pushing model to hub...")
trainer.push_to_hub(commit_message="Full SFT on telecom conversation corpus")
print("Done!")