# File size: 4,276 Bytes
# 4e9803a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import sys
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import trackio

# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Checkpoint used for both the tokenizer and the causal-LM weights.
MODEL_ID = "google/gemma-2b-it"

# Hub dataset read with load_dataset(..., split="train").
DATASET_ID = "talkmap/telecom-conversation-corpus"

# Hub repository handed to SFTConfig(hub_model_id=...).
OUTPUT_REPO = "ligaments-dev/gemma-2b-telecom-sft"

# Forwarded to SFTConfig(max_length=...).
MAX_LENGTH = 2048

# ---------------------------------------------------------------------------
# Logging / tracking
# ---------------------------------------------------------------------------
# Start a trackio run before any training objects are created.
# NOTE(review): the SFTConfig in this script also sets report_to=["trackio"];
# confirm the trainer integration reuses this run rather than opening another.
trackio.init(
    project="gemma-telecom-sft",
)

# ---------------------------------------------------------------------------
# Load tokenizer
# ---------------------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Only fill in a padding token when the checkpoint does not define one;
# in that case the end-of-sequence token doubles as padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# ---------------------------------------------------------------------------
# Preprocess dataset: group rows by conversation_id into messages
# ---------------------------------------------------------------------------
def _normalize_turns(turns):
    """Return ``turns`` as a strictly alternating user/assistant message list.

    The Gemma chat template raises unless roles alternate
    user/assistant/user/... starting with a user turn, so this helper:

    * drops turns whose text is missing or blank,
    * drops assistant turns that come before the first user turn,
    * merges consecutive turns from the same role, joined with a newline.

    Args:
        turns: Messages of one conversation, in order. Each item is a dict
            with ``"role"`` (``"user"`` or ``"assistant"``) and ``"content"``.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts. The input is
        not modified. The result may be empty.
    """
    normalized = []
    for turn in turns:
        role = turn["role"]
        text = turn["content"]
        if text is None or not text.strip():
            continue  # nothing to learn from an empty utterance
        if not normalized and role != "user":
            continue  # the template requires the conversation to open with a user turn
        if normalized and normalized[-1]["role"] == role:
            # Same speaker twice in a row: fold into one message.
            normalized[-1]["content"] += "\n" + text
        else:
            normalized.append({"role": role, "content": text})
    return normalized


print("Loading dataset...")
raw = load_dataset(DATASET_ID, split="train")

# The dataset has columns: conversation_id, speaker, date_time, text
# We group by conversation_id and build a single messages list per conversation.
# Speaker "client" maps to the "user" role; every other speaker maps to "assistant".
# NOTE(review): rows are taken in the order the split yields them; date_time is
# not used for sorting. Confirm the split is chronological per conversation.
print("Grouping conversations...")
conv_map = {}
for ex in raw:
    role = "user" if ex["speaker"] == "client" else "assistant"
    conv_map.setdefault(ex["conversation_id"], []).append(
        {"role": role, "content": ex["text"]}
    )

# Build one training example per conversation, keeping only conversations that
# still contain at least one user turn followed by an assistant turn.
messages_list = []
for msgs in conv_map.values():
    turns = _normalize_turns(msgs)
    if len(turns) >= 2:
        messages_list.append({"messages": turns})

skipped = len(conv_map) - len(messages_list)
if skipped:
    print(f"Skipped {skipped} conversations without a user/assistant exchange.")
if not messages_list:
    raise ValueError(
        f"No usable conversations found in {DATASET_ID!r}; "
        "check the speaker/text columns."
    )

train_dataset = Dataset.from_list(messages_list)
print(f"Prepared {len(train_dataset)} conversation examples.")

# Persist the processed dataset for inspection or reuse. This script does not
# reload it: every run rebuilds the conversations from the raw split.
processed_path = "/tmp/telecom_processed"
train_dataset.save_to_disk(processed_path)
print(f"Saved processed dataset to {processed_path}")

# ---------------------------------------------------------------------------
# Load model
# ---------------------------------------------------------------------------
print("Loading model...")
# Both the weight dtype and the device placement are left on "auto".
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype="auto", device_map="auto"
)

# ---------------------------------------------------------------------------
# Training arguments
# ---------------------------------------------------------------------------
training_args = SFTConfig(
    # --- checkpoints and Hub publishing ---
    output_dir="/tmp/gemma-telecom-sft",
    save_strategy="epoch",
    push_to_hub=True,
    hub_model_id=OUTPUT_REPO,
    hub_private_repo=False,
    # --- optimisation schedule ---
    num_train_epochs=3,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    # --- batching: 2 per device x 8 accumulation steps = 16 per optimizer step ---
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    # --- sequence handling and memory ---
    max_length=MAX_LENGTH,
    packing=True,
    gradient_checkpointing=True,
    bf16=True,
    # --- logging ---
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    report_to=["trackio"],
)

# ---------------------------------------------------------------------------
# Trainer
# ---------------------------------------------------------------------------
print("Initializing SFTTrainer...")
# Wire together the loaded model, the SFT configuration, the grouped
# conversation dataset, and the tokenizer (passed as the processing class).
trainer = SFTTrainer(
    args=training_args,
    model=model,
    processing_class=tokenizer,
    train_dataset=train_dataset,
)

# ---------------------------------------------------------------------------
# Train
# ---------------------------------------------------------------------------
print("Starting training...")
trainer.train()

# ---------------------------------------------------------------------------
# Push to hub
# ---------------------------------------------------------------------------
# Explicit final upload once training has returned.
print("Pushing model to hub...")
trainer.push_to_hub(
    commit_message="Full SFT on telecom conversation corpus",
)
print("Done!")