| import os |
| import sys |
| from datasets import load_dataset, Dataset |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| from trl import SFTTrainer, SFTConfig |
| import trackio |
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------
MODEL_ID = "google/gemma-2b-it"                       # base checkpoint on the Hub
DATASET_ID = "talkmap/telecom-conversation-corpus"    # source conversations
OUTPUT_REPO = "ligaments-dev/gemma-2b-telecom-sft"    # Hub repo for the result
MAX_LENGTH = 2048                                     # passed to SFTConfig(max_length=...)
|
|
| |
| |
| |
# Name the experiment-tracking project for this run.
# report_to=["trackio"] (in the training arguments) makes the Trainer start its
# own Trackio run when training begins.
# NOTE(review): that integration is believed to take its project name from the
# TRACKIO_PROJECT environment variable (newer transformers releases also accept
# `project=` on the training arguments), not from this trackio.init call.
# Without the variable, training metrics may land in a default project.
# Confirm against the pinned transformers version.
# setdefault keeps any value already exported by the launching environment, and
# the same name is reused below so both runs agree.
TRACKIO_PROJECT = os.environ.setdefault("TRACKIO_PROJECT", "gemma-telecom-sft")
trackio.init(project=TRACKIO_PROJECT)
|
|
| |
| |
| |
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_ID)

# Batching needs a pad token. When the checkpoint does not define one,
# reuse the end-of-sequence token instead.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
|
|
| |
| |
| |
print("Loading dataset...")
# Only the "train" split is used; each row is one utterance of a conversation.
raw = load_dataset(path=DATASET_ID, split="train")
|
|
| |
| |
print("Grouping conversations...")
# Bucket utterances by conversation id, preserving dataset row order.
# The "client" speaker becomes the chat "user"; any other speaker becomes
# the "assistant".
conv_map = {}
for row in raw:
    if row["speaker"] == "client":
        speaker_role = "user"
    else:
        speaker_role = "assistant"
    conv_map.setdefault(row["conversation_id"], []).append(
        {"role": speaker_role, "content": row["text"]}
    )
|
|
| |
def _normalize_conversation(turns):
    """Return *turns* reshaped so a strict-alternation chat template accepts it.

    gemma-2b-it's chat template raises unless roles strictly alternate
    user/assistant starting with a user turn. Raw transcripts can open with
    an agent greeting or split one speaker's turn across consecutive rows.
    This helper therefore:

    - drops turns whose text is empty or ``None``;
    - merges consecutive turns from the same role, joined with newlines;
    - discards any leading assistant turn.

    Args:
        turns: list of ``{"role": ..., "content": ...}`` dicts in order.

    Returns:
        A new list of message dicts; the input is not modified.
    """
    merged = []
    for turn in turns:
        content = (turn["content"] or "").strip()
        if not content:
            continue
        if merged and merged[-1]["role"] == turn["role"]:
            merged[-1]["content"] += "\n" + content
        else:
            merged.append({"role": turn["role"], "content": content})
    # After merging, at most one assistant turn can precede the first user turn.
    while merged and merged[0]["role"] != "user":
        merged.pop(0)
    return merged


messages_list = []
skipped_conversations = 0
for msgs in conv_map.values():
    normalized = _normalize_conversation(msgs)
    # Need at least a user turn followed by an assistant reply; anything
    # shorter has no assistant text to learn from.
    if len(normalized) < 2:
        skipped_conversations += 1
        continue
    messages_list.append({"messages": normalized})

if not messages_list:
    raise RuntimeError(
        "No usable conversations after normalization; check the dataset's "
        "'speaker'/'text' columns."
    )

train_dataset = Dataset.from_list(messages_list)
print(f"Prepared {len(train_dataset)} conversation examples.")
print(f"Skipped {skipped_conversations} conversations without a user/assistant exchange.")
|
|
| |
# Persist the prepared conversations locally so they can be inspected or
# reloaded without re-running the grouping step.
processed_path = "/tmp/telecom_processed"
train_dataset.save_to_disk(dataset_path=processed_path)
print("Saved processed dataset to " + processed_path)
|
|
| |
| |
| |
print("Loading model...")
# torch_dtype="auto" takes the dtype from the checkpoint's config, and
# device_map="auto" lets transformers/accelerate place the weights on the
# available devices.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="auto",
)
|
|
| |
| |
| |
# Full-parameter SFT hyperparameters. Each optimizer step accumulates
# 2 (per-device batch) * 8 (accumulation steps) = 16 examples per device.
training_args = SFTConfig(
    # Output location, checkpointing, and Hub publishing.
    output_dir="/tmp/gemma-telecom-sft",
    save_strategy="epoch",
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_private_repo=False,
    # Optimization schedule.
    num_train_epochs=3,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    # Sequence handling, memory, and precision.
    max_length=MAX_LENGTH,
    packing=True,
    gradient_checkpointing=True,
    bf16=True,
    # Logging and reporting.
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to=["trackio"],
)
|
|
| |
| |
| |
print("Initializing SFTTrainer...")
# The tokenizer is passed as the processing class, so SFTTrainer uses it to
# tokenize the "messages" conversations.
trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    train_dataset=train_dataset,
)
|
|
| |
| |
| |
print("Starting training...")
# Runs the epochs configured in training_args. Because save_strategy="epoch"
# and push_to_hub=True, checkpoints are also saved (and pushed) during the run.
trainer.train()
|
|
| |
| |
| |
print("Pushing model to hub...")
# Final upload of the trained model to the repo configured via hub_model_id.
trainer.push_to_hub(
    commit_message="Full SFT on telecom conversation corpus",
)
print("Done!")
|
|