| """ |
| Full fine-tuning script with aggressive memory optimizations: |
| Model: google/gemma-2-2b-it |
| Dataset: talkmap/telecom-conversation-corpus |
| """ |
| import os |
| from collections import defaultdict |
|
|
| from datasets import load_dataset, Dataset |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from trl import SFTTrainer, SFTConfig |
| import trackio |
| import torch |
|
|
| |
| |
| |
# Base checkpoint and training corpus on the Hugging Face Hub.
MODEL_ID: str = "google/gemma-2-2b-it"
DATASET_ID: str = "talkmap/telecom-conversation-corpus"

# Where the fine-tuned model is written locally and which Hub repo it is pushed to.
OUTPUT_DIR: str = "./gemma-2b-it-telecom"
HUB_MODEL_ID: str = "ligaments-dev/gemma-2b-it-telecom"

# Data / sequence budget.
MAX_CONVERSATIONS: int = 10_000  # upper bound on conversations kept for training
MAX_TURNS: int = 6  # raw dataset turns considered per conversation
MAX_SEQ_LENGTH: int = 512  # sequence-length limit handed to the trainer config
|
|
| |
| |
| |
# Experiment tracking.
# NOTE(review): with report_to=["trackio"], transformers' Trackio callback calls
# trackio.init() again when training starts and takes the project name from the
# TRACKIO_PROJECT environment variable (default "huggingface"). Without this,
# the Trainer's metrics end up in a different project from the run created
# below. Confirm against the installed transformers version.
os.environ["TRACKIO_PROJECT"] = "gemma-telecom-finetune"
trackio.init(
    project="gemma-telecom-finetune",
    name="gemma-2b-it-full-sft",
)
|
|
| |
| |
| |
print("Loading dataset...")
ds = load_dataset(DATASET_ID, split="train")
print(f"Rows: {len(ds)}, Columns: {ds.column_names}")

print("Grouping conversations...")
# conversation_id -> list of turn dicts with "speaker", "date_time" and "text".
conversations = defaultdict(list)
for record in ds:
    raw_text = record["text"]
    turn = {
        "speaker": record["speaker"],
        "date_time": record["date_time"],
        # Missing text is normalised to "" so later string handling never sees None.
        "text": "" if raw_text is None else raw_text,
    }
    conversations[record["conversation_id"]].append(turn)

# Order each conversation's turns by their date_time value. list.sort is stable,
# so turns with equal timestamps keep dataset order. NOTE(review): if date_time
# is a string, this ordering is lexicographic; confirm the format sorts
# chronologically.
for conversation_turns in conversations.values():
    conversation_turns.sort(key=lambda t: t["date_time"])
|
|
print("Converting to messages format...")


def _turns_to_messages(turns, max_turns):
    """Turn one conversation's ordered turns into a chat ``messages`` list.

    Only the first ``max_turns`` raw turns are used. Turns with blank text are
    skipped. Consecutive turns from the same side are merged, joined with
    newlines. Speaker ``"client"`` maps to ``"user"`` and every other speaker
    maps to ``"assistant"``, so the merged list strictly alternates roles.

    The window is then trimmed rather than discarded:
    - a leading assistant message is dropped, so the conversation opens with
      the user;
    - a trailing user message is dropped, so it ends on an assistant reply.

    Args:
        turns: turn dicts with "speaker" and "text" keys, already in order.
        max_turns: number of raw turns to consider.

    Returns:
        The trimmed list of ``{"role", "content"}`` dicts. It is empty when no
        user -> assistant exchange survives.
    """
    messages = []
    for turn in turns[:max_turns]:
        text = turn["text"]
        if not text.strip():
            # Blank turns would become empty (or "\n"-padded) training targets.
            continue
        role = "user" if turn["speaker"] == "client" else "assistant"
        if messages and messages[-1]["role"] == role:
            messages[-1]["content"] += "\n" + text
        else:
            messages.append({"role": role, "content": text})

    # Merging guarantees strict alternation, so at most one message needs
    # trimming at each end.
    if messages and messages[0]["role"] == "assistant":
        messages.pop(0)
    if messages and messages[-1]["role"] == "user":
        messages.pop()
    return messages


messages_data = []
for turns in conversations.values():
    messages = _turns_to_messages(turns, MAX_TURNS)
    if not messages:
        continue
    messages_data.append({"messages": messages})
    if len(messages_data) >= MAX_CONVERSATIONS:
        break

print(f"Total conversations: {len(messages_data)}")
|
|
| |
| |
| |
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = (
        tokenizer.eos_token,
        tokenizer.eos_token_id,
    )
|
|
| |
| |
| |
print("Pre-tokenizing dataset...")


def apply_and_tokenize(example):
    """Render one conversation's ``messages`` into a single training string.

    Despite the name, nothing is tokenized here. The chat template is applied
    with ``tokenize=False``, and the resulting ``text`` field is tokenized later
    by SFTTrainer.

    SFTTrainer tokenizes a plain ``text`` field with special tokens enabled, so
    it prepends its own BOS. Gemma's chat template also emits the BOS token. If
    the rendered string starts with the tokenizer's BOS token, that token is
    stripped to avoid a doubled ``<bos><bos>`` prefix.

    Returns:
        ``{"text": rendered}``, or ``{"text": ""}`` when the chat template
        rejects the conversation, so the example can be filtered out.
    """
    try:
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
    except Exception:
        # Best-effort: a conversation the template rejects becomes "" and is
        # filtered out (and counted) after map() instead of aborting the run.
        return {"text": ""}
    bos = tokenizer.bos_token
    if bos and text.startswith(bos):
        text = text[len(bos):]
    return {"text": text}


if not messages_data:
    # Dataset.from_list([]) has no "messages" column, so map() would otherwise
    # fail with an unrelated remove_columns error.
    raise RuntimeError(
        "No usable conversations were extracted from the dataset; "
        "check the conversation filtering and MAX_TURNS."
    )

raw_dataset = Dataset.from_list(messages_data)
raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
num_rendered = len(raw_dataset)
raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
num_failed = num_rendered - len(raw_dataset)
if num_failed:
    print(f"Dropped {num_failed} conversations rejected by the chat template")
print(f"Dataset after filtering: {len(raw_dataset)}")
|
|
| |
| |
| |
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
    # Gemma-2 uses attention-logit soft-capping, which the SDPA path does not
    # apply. transformers recommends (and warns otherwise) eager attention when
    # training Gemma-2.
    attn_implementation="eager",
)

# The KV cache is not used in training and is incompatible with gradient
# checkpointing; turn it off up front instead of relying on a runtime fallback.
model.config.use_cache = False
model.gradient_checkpointing_enable()

# Free garbage and cached CUDA blocks left over from loading before training.
import gc
gc.collect()
torch.cuda.empty_cache()
|
|
| |
| |
| |
# Training configuration. Each optimizer step per device sees
# per_device_train_batch_size * gradient_accumulation_steps = 1 * 8 sequences.
args = SFTConfig(
    output_dir=OUTPUT_DIR,
    # Same run name as the trackio.init() call earlier in this script, so the
    # Trainer's tracking callback reports under that run.
    run_name="gemma-2b-it-full-sft",
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    # TRL renamed `max_seq_length` to `max_length`; current releases reject the
    # old keyword when constructing SFTConfig.
    max_length=MAX_SEQ_LENGTH,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    save_strategy="epoch",
    bf16=True,
    gradient_checkpointing=True,
    report_to=["trackio"],
    remove_unused_columns=False,
)
|
|
| |
| |
| |
# Wire the model, training config, rendered text dataset and tokenizer into
# TRL's supervised fine-tuning trainer, then run training.
print("Initializing SFTTrainer...")
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=args,
    train_dataset=raw_dataset,
)

print("Starting training...")
trainer.train()
|
|
| |
| |
| |
print("Saving and pushing to hub...")
trainer.save_model(OUTPUT_DIR)
# When the training arguments have push_to_hub=True, a user call to
# Trainer.save_model already uploads the model to hub_model_id. Push explicitly
# only when that flag is off, to avoid a second identical upload and commit.
if not trainer.args.push_to_hub:
    trainer.push_to_hub()

print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")
|
|