File size: 5,926 Bytes
1642b2a
31b65fb
1642b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745965e
 
 
1642b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4778444
 
 
 
 
 
1642b2a
 
 
 
 
 
 
0ce3e14
 
1642b2a
3686614
 
1642b2a
 
3686614
 
 
 
 
 
 
 
 
8c55461
48301ee
8c55461
 
48301ee
 
 
 
 
 
 
 
8c55461
0ce3e14
 
 
1642b2a
 
8c55461
 
 
745965e
48301ee
1642b2a
 
 
 
cb6f87d
1642b2a
 
 
 
745965e
31b65fb
745965e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1642b2a
31b65fb
1642b2a
 
 
 
938e32b
1642b2a
 
 
 
 
31b65fb
 
 
 
 
1642b2a
 
 
 
 
 
 
aa1a19d
1642b2a
31b65fb
1642b2a
745965e
1642b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
745965e
1642b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Full fine-tuning script with aggressive memory optimizations:
  Model: google/gemma-2-2b-it
  Dataset: talkmap/telecom-conversation-corpus
"""
import gc
import os
import warnings
from collections import defaultdict

from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTTrainer, SFTConfig
import trackio
import torch

# ------------------------------------------------------------------
# Config
# ------------------------------------------------------------------
# Hub id of the base model that gets fine-tuned.
MODEL_ID = "google/gemma-2-2b-it"
# Hub id of the dataset; only its "train" split is loaded.
DATASET_ID = "talkmap/telecom-conversation-corpus"
# Local directory used as the trainer's output_dir and as the final save path.
OUTPUT_DIR = "./gemma-2b-it-telecom"
# Hub repository the trained model is pushed to.
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
# Stop collecting training examples once this many conversations are kept.
MAX_CONVERSATIONS = 10000
# Only the first MAX_TURNS rows of each conversation (sorted by date_time) are used.
MAX_TURNS = 6
# Passed to SFTConfig as the maximum sequence length.
MAX_SEQ_LENGTH = 512

# ------------------------------------------------------------------
# Trackio monitoring
# ------------------------------------------------------------------
# Initialise Trackio with the project and run name for this job.
# NOTE(review): presumably this is the run that the trainer's
# report_to=["trackio"] setting logs into — confirm against Trackio docs.
trackio.init(
    project="gemma-telecom-finetune",
    name="gemma-2b-it-full-sft",
)

# ------------------------------------------------------------------
# Load dataset
# ------------------------------------------------------------------
print("Loading dataset...")
ds = load_dataset(DATASET_ID, split="train")
print(f"Rows: {len(ds)}, Columns: {ds.column_names}")

# Bucket every row under its conversation id, keeping only the fields the
# message-building step needs. A missing (None) text becomes an empty string.
print("Grouping conversations...")
conversations = defaultdict(list)
for record in ds:
    utterance = "" if record["text"] is None else record["text"]
    turn = {
        "speaker": record["speaker"],
        "date_time": record["date_time"],
        "text": utterance,
    }
    conversations[record["conversation_id"]].append(turn)

# Order each conversation's turns by their date_time value, in place.
for conv_turns in conversations.values():
    conv_turns.sort(key=lambda item: item["date_time"])

# Build chat-style training examples: one {"messages": [...]} dict per
# conversation that survives the filters below.
print("Converting to messages format...")
messages_data = []
for turns in conversations.values():
    # Collapse the first MAX_TURNS turns into messages, merging consecutive
    # turns from the same side into a single newline-joined message.
    messages = []
    for turn in turns[:MAX_TURNS]:
        # The "client" speaker maps to the user role; every other speaker
        # is treated as the assistant.
        role = "user" if turn["speaker"] == "client" else "assistant"
        if messages and messages[-1]["role"] == role:
            messages[-1]["content"] += "\n" + turn["text"]
        else:
            messages.append({"role": role, "content": turn["text"]})

    # Keep only conversations that open with the user and close with the
    # assistant. Adjacent same-role turns were merged above and there are
    # only two roles, so the roles in between already strictly alternate;
    # a separate per-message alternation check could never fail.
    if not messages:
        continue
    if messages[0]["role"] != "user" or messages[-1]["role"] != "assistant":
        continue

    messages_data.append({"messages": messages})

    # Cap the number of training examples.
    if len(messages_data) >= MAX_CONVERSATIONS:
        break

print(f"Total conversations: {len(messages_data)}")

# ------------------------------------------------------------------
# Tokenizer
# ------------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Fall back to the end-of-sequence token for padding when the tokenizer
# defines no pad token of its own.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# ------------------------------------------------------------------
# Render conversations to chat-template text (no tokenization or truncation here)
# ------------------------------------------------------------------
print("Pre-tokenizing dataset...")


def apply_and_tokenize(example):
    """Render one conversation to a single chat-template string.

    Despite the name, no tokenization happens here: the chat template is
    applied with ``tokenize=False``, so the result is plain text.

    Args:
        example: Mapping with a ``"messages"`` entry holding the list of
            ``{"role", "content"}`` dicts for one conversation.

    Returns:
        ``{"text": rendered}`` on success, or ``{"text": ""}`` when the
        template could not be applied, so the caller can filter the example
        out instead of aborting the whole run.
    """
    try:
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )
    except Exception as exc:
        # Best-effort by design: one bad conversation must not stop the job.
        # Report the failure rather than swallowing it, so a systematic
        # problem (every example failing) is visible instead of silently
        # producing an empty dataset.
        warnings.warn(
            f"Skipping conversation: chat template failed with {exc!r}",
            RuntimeWarning,
        )
        text = ""
    return {"text": text}

# Wrap the examples in a Dataset, replace the "messages" column with the
# rendered "text" column, then drop rows whose text came back empty.
rendered = Dataset.from_list(messages_data).map(
    apply_and_tokenize,
    remove_columns=["messages"],
)
raw_dataset = rendered.filter(lambda row: len(row["text"]) > 0)
print(f"Dataset after filtering: {len(raw_dataset)}")

# ------------------------------------------------------------------
# Model - load in bfloat16; device_map="auto" decides device placement
# ------------------------------------------------------------------
print("Loading model...")
# Load the weights in bfloat16 and let device_map="auto" choose where the
# model is placed.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    dtype=torch.bfloat16,
    device_map="auto",
)

# Gradient checkpointing: part of this script's memory-saving setup.
model.gradient_checkpointing_enable()

# Collect unreferenced Python objects and release cached CUDA memory left
# over from loading. (`gc` is imported with the other modules at the top of
# the file rather than mid-script.)
gc.collect()
torch.cuda.empty_cache()

# ------------------------------------------------------------------
# Training arguments
# ------------------------------------------------------------------
# Training configuration for a single-epoch full fine-tune.
args = SFTConfig(
    # Where checkpoints go locally, and which Hub repo receives the model.
    output_dir=OUTPUT_DIR,
    hub_model_id=HUB_MODEL_ID,
    push_to_hub=True,
    num_train_epochs=1,
    # One sequence per device per step, accumulated over 8 steps, i.e. 8
    # sequences per optimizer update on each device.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    learning_rate=2e-5,
    # Sequence-length cap. Recent TRL releases spell this `max_length`; the
    # older `max_seq_length` keyword was deprecated and then removed, so
    # passing it raises a TypeError there.
    max_length=MAX_SEQ_LENGTH,
    # Log every 10 steps (and the very first one); no progress bar.
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    disable_tqdm=True,
    save_strategy="epoch",
    # bfloat16 training with gradient checkpointing.
    bf16=True,
    gradient_checkpointing=True,
    report_to=["trackio"],
    remove_unused_columns=False,
)

# ------------------------------------------------------------------
# Trainer
# ------------------------------------------------------------------
print("Initializing SFTTrainer...")
# Wire together the loaded model, the training configuration, the
# rendered-text dataset, and the tokenizer (passed as the processing class).
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=raw_dataset,
    processing_class=tokenizer,
)

# ------------------------------------------------------------------
# Train
# ------------------------------------------------------------------
print("Starting training...")
# Run the training loop configured above.
trainer.train()

# ------------------------------------------------------------------
# Save & Push
# ------------------------------------------------------------------
print("Saving and pushing to hub...")
# Write the final model to OUTPUT_DIR, then upload to the Hub.
# NOTE(review): the training config already sets push_to_hub=True, so
# save_model may itself trigger an upload — confirm whether the explicit
# push_to_hub() call is still needed.
trainer.save_model(OUTPUT_DIR)
trainer.push_to_hub()

print(f"Done! Model at https://huggingface.co/{HUB_MODEL_ID}")