""" Full fine-tuning script with aggressive memory optimizations: Model: google/gemma-2-2b-it Dataset: talkmap/telecom-conversation-corpus """ import os from collections import defaultdict from datasets import load_dataset, Dataset from transformers import AutoTokenizer, AutoModelForCausalLM from trl import SFTTrainer, SFTConfig import trackio import torch # ------------------------------------------------------------------ # Config # ------------------------------------------------------------------ MODEL_ID = "google/gemma-2-2b-it" DATASET_ID = "talkmap/telecom-conversation-corpus" OUTPUT_DIR = "./gemma-2b-it-telecom" HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom" MAX_CONVERSATIONS = 10000 MAX_TURNS = 6 MAX_SEQ_LENGTH = 512 # ------------------------------------------------------------------ # Trackio monitoring # ------------------------------------------------------------------ trackio.init( project="gemma-telecom-finetune", name="gemma-2b-it-full-sft", ) # ------------------------------------------------------------------ # Load dataset # ------------------------------------------------------------------ print("Loading dataset...") ds = load_dataset(DATASET_ID, split="train") print(f"Rows: {len(ds)}, Columns: {ds.column_names}") print("Grouping conversations...") conversations = defaultdict(list) for row in ds: text = row["text"] if row["text"] is not None else "" conversations[row["conversation_id"]].append({ "speaker": row["speaker"], "date_time": row["date_time"], "text": text, }) for conv_id in conversations: conversations[conv_id].sort(key=lambda x: x["date_time"]) print("Converting to messages format...") messages_data = [] for conv_id, turns in conversations.items(): turns = turns[:MAX_TURNS] messages = [] current_role = None current_content = [] for turn in turns: role = "user" if turn["speaker"] == "client" else "assistant" if role == current_role: current_content.append(turn["text"]) else: if current_role is not None: messages.append({"role": 
current_role, "content": "\n".join(current_content)}) current_role = role current_content = [turn["text"]] if current_role is not None: messages.append({"role": current_role, "content": "\n".join(current_content)}) if not messages or messages[0]["role"] != "user": continue valid = True for i, msg in enumerate(messages): expected_role = "user" if i % 2 == 0 else "assistant" if msg["role"] != expected_role: valid = False break if not valid: continue if messages[-1]["role"] != "assistant": continue messages_data.append({"messages": messages}) if len(messages_data) >= MAX_CONVERSATIONS: break print(f"Total conversations: {len(messages_data)}") # ------------------------------------------------------------------ # Tokenizer # ------------------------------------------------------------------ print("Loading tokenizer...") tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.pad_token_id = tokenizer.eos_token_id # ------------------------------------------------------------------ # Pre-tokenize dataset with truncation # ------------------------------------------------------------------ print("Pre-tokenizing dataset...") def apply_and_tokenize(example): try: text = tokenizer.apply_chat_template( example["messages"], tokenize=False, add_generation_prompt=False, ) except Exception: text = "" return {"text": text} raw_dataset = Dataset.from_list(messages_data) raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"]) raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0) print(f"Dataset after filtering: {len(raw_dataset)}") # ------------------------------------------------------------------ # Model - load on CPU first to control placement # ------------------------------------------------------------------ print("Loading model...") model = AutoModelForCausalLM.from_pretrained( MODEL_ID, dtype=torch.bfloat16, device_map="auto", ) 
# Trade compute for memory: recompute activations during the backward pass.
model.gradient_checkpointing_enable()

# Empty cache to free up fragmentation
import gc

gc.collect()
torch.cuda.empty_cache()

# ------------------------------------------------------------------
# Training arguments
# ------------------------------------------------------------------
# Effective batch size per device is 1 * 8 = 8 sequences per optimizer step.
args = SFTConfig(
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    # `max_length` is the current SFTConfig name for the sequence-length cap;
    # the older `max_seq_length` spelling was deprecated and later removed.
    max_length=MAX_SEQ_LENGTH,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    report_to=["trackio"],
    remove_unused_columns=False,
)

# ------------------------------------------------------------------
# Trainer
# ------------------------------------------------------------------
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=raw_dataset,
    processing_class=tokenizer,
)

# ------------------------------------------------------------------
# Train
# ------------------------------------------------------------------
print("Starting training...")
trainer.train()

# ------------------------------------------------------------------
# Save & Push
# ------------------------------------------------------------------
print("Saving and pushing to hub...")
trainer.save_model(OUTPUT_DIR)
trainer.push_to_hub()

print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")