"""
Full fine-tuning script:
    Model: google/gemma-2-2b-it
    Dataset: talkmap/telecom-conversation-corpus
Converts turn-based telecom dialogues -> conversational messages format for SFT.
"""
import dataclasses
import os
from collections import defaultdict

import torch
import trackio
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import SFTConfig, SFTTrainer

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_ID = "google/gemma-2-2b-it"  # base checkpoint to fine-tune
DATASET_ID = "talkmap/telecom-conversation-corpus"  # source dialogue corpus
OUTPUT_DIR = "./gemma-2b-it-telecom"  # local directory for checkpoints
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"  # Hub repo to push to
MAX_SEQ_LENGTH = 2048  # sequence-length limit handed to the SFT config

# ---------------------------------------------------------------------------
# Experiment tracking
# ---------------------------------------------------------------------------
# Open the trackio run before training starts; the training config below
# also lists "trackio" in report_to.
trackio.init(
    project="gemma-telecom-finetune",
    name="gemma-2b-it-full-sft",
)

# ---------------------------------------------------------------------------
# Load the raw corpus
# ---------------------------------------------------------------------------
print("Loading dataset...")
ds = load_dataset(DATASET_ID, split="train")
print(f"Rows: {len(ds)}, Columns: {ds.column_names}")

# Bucket the per-turn rows by conversation id.
print("Grouping conversations...")
conversations = defaultdict(list)
for row in ds:
    conversations[row["conversation_id"]].append(row)

# Order every conversation's turns by timestamp. list.sort is stable, so turns
# with equal timestamps keep the order they had in the dataset.
# NOTE(review): assumes "date_time" values compare chronologically (datetimes
# or ISO-8601-style strings) -- confirm against the dataset schema.
for turns in conversations.values():
    turns.sort(key=lambda turn: turn["date_time"])

def _turns_to_messages(turns):
    """Convert one conversation's ordered turns into chat messages.

    Speaker "client" maps to role "user"; every other speaker maps to
    "assistant". The Gemma chat template raises unless roles strictly
    alternate user/assistant starting with a user message, so this helper:

    * skips turns whose text is missing or blank,
    * drops assistant turns that occur before the first user turn,
    * merges consecutive turns of the same role into one message
      (joined with a newline).

    Conversations that already alternate and start with the client are
    returned unchanged.
    """
    messages = []
    for turn in turns:
        text = turn["text"]
        if text is None or not text.strip():
            continue
        role = "user" if turn["speaker"] == "client" else "assistant"
        if not messages:
            if role == "assistant":
                # Nothing for the model to respond to yet.
                continue
            messages.append({"role": role, "content": text})
        elif messages[-1]["role"] == role:
            messages[-1]["content"] += "\n" + text
        else:
            messages.append({"role": role, "content": text})
    return messages


print("Converting to messages format...")
messages_data = []
skipped_conversations = 0
for turns in conversations.values():
    messages = _turns_to_messages(turns)
    # Roles alternate starting with "user", so fewer than two messages means
    # there is no assistant reply to learn from.
    if len(messages) < 2:
        skipped_conversations += 1
        continue
    messages_data.append({"messages": messages})

if skipped_conversations:
    print(f"Skipped {skipped_conversations} conversations without a user/assistant exchange")

train_dataset = Dataset.from_list(messages_data)
print(f"Total conversations: {len(train_dataset)}")

# ---------------------------------------------------------------------------
# Tokenizer
# ---------------------------------------------------------------------------
print("Loading tokenizer...")
# Gemma 2 ships with transformers itself, so trust_remote_code (which allows
# executing code downloaded from the Hub repo) is deliberately not enabled.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
if tokenizer.pad_token is None:
    # Fall back to the EOS token for padding when the tokenizer defines none.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # transformers recommends the eager attention path when training Gemma 2
    # (the model uses attention logit soft-capping).
    attn_implementation="eager",
    # trust_remote_code is intentionally omitted: Gemma 2 is a built-in
    # architecture, so no Hub-hosted code needs to be executed.
)

# Trade compute for memory during full fine-tuning.
model.gradient_checkpointing_enable()
# The generation KV cache is not used for training and is incompatible with
# gradient checkpointing.
model.config.use_cache = False

# ---------------------------------------------------------------------------
# Training configuration
# ---------------------------------------------------------------------------
# TRL renamed SFTConfig's `max_seq_length` to `max_length`; newer releases no
# longer accept the old name. Pick whichever field the installed version has.
_sft_field_names = {field.name for field in dataclasses.fields(SFTConfig)}
_length_arg = "max_length" if "max_length" in _sft_field_names else "max_seq_length"

args = SFTConfig(
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    num_train_epochs=3,
    # Effective batch size per device: 1 * 4 accumulated steps.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    report_to=["trackio"],
    remove_unused_columns=False,
    **{_length_arg: MAX_SEQ_LENGTH},
)

# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
print("Initializing SFTTrainer...")
# The tokenizer is passed as processing_class; train_dataset holds one
# {"messages": [...]} record per conversation.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
)

# ---------------------------------------------------------------------------
# Train
# ---------------------------------------------------------------------------
print("Starting training...")
trainer.train()

# ---------------------------------------------------------------------------
# Save and publish
# ---------------------------------------------------------------------------
print("Saving and pushing to hub...")
trainer.save_model(OUTPUT_DIR)  # write the final model to the local output dir
trainer.push_to_hub()  # upload to the Hub repo configured via hub_model_id

print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")