Manually pre-tokenize with apply_chat_template, cap seq length to 512 for OOM fix
Browse files
train.py
CHANGED
|
@@ -3,6 +3,7 @@ Full fine-tuning script:
|
|
| 3 |
Model: google/gemma-2-2b-it
|
| 4 |
Dataset: talkmap/telecom-conversation-corpus
|
| 5 |
Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
|
|
|
|
| 6 |
"""
|
| 7 |
import os
|
| 8 |
from collections import defaultdict
|
|
@@ -20,8 +21,9 @@ MODEL_ID = "google/gemma-2-2b-it"
|
|
| 20 |
DATASET_ID = "talkmap/telecom-conversation-corpus"
|
| 21 |
OUTPUT_DIR = "./gemma-2b-it-telecom"
|
| 22 |
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
|
| 23 |
-
MAX_CONVERSATIONS =
|
| 24 |
-
MAX_TURNS =
|
|
|
|
| 25 |
|
| 26 |
# ------------------------------------------------------------------
|
| 27 |
# Trackio monitoring
|
|
@@ -53,11 +55,9 @@ for conv_id in conversations:
|
|
| 53 |
conversations[conv_id].sort(key=lambda x: x["date_time"])
|
| 54 |
|
| 55 |
# Convert each conversation into messages format with merged consecutive same-role turns
|
| 56 |
-
# Gemma requires: NO system role, user first, alternating user/assistant
|
| 57 |
print("Converting to messages format...")
|
| 58 |
messages_data = []
|
| 59 |
for conv_id, turns in conversations.items():
|
| 60 |
-
# Cap turns to MAX_TURNS to keep sequences shorter
|
| 61 |
turns = turns[:MAX_TURNS]
|
| 62 |
|
| 63 |
messages = []
|
|
@@ -75,11 +75,9 @@ for conv_id, turns in conversations.items():
|
|
| 75 |
if current_role is not None:
|
| 76 |
messages.append({"role": current_role, "content": "\n".join(current_content)})
|
| 77 |
|
| 78 |
-
# Gemma requires first turn to be user and alternating roles
|
| 79 |
if not messages or messages[0]["role"] != "user":
|
| 80 |
continue
|
| 81 |
|
| 82 |
-
# Verify alternating roles
|
| 83 |
valid = True
|
| 84 |
for i, msg in enumerate(messages):
|
| 85 |
expected_role = "user" if i % 2 == 0 else "assistant"
|
|
@@ -89,7 +87,6 @@ for conv_id, turns in conversations.items():
|
|
| 89 |
if not valid:
|
| 90 |
continue
|
| 91 |
|
| 92 |
-
# Ensure conversation ends with assistant (complete pair)
|
| 93 |
if messages[-1]["role"] != "assistant":
|
| 94 |
continue
|
| 95 |
|
|
@@ -98,12 +95,7 @@ for conv_id, turns in conversations.items():
|
|
| 98 |
if len(messages_data) >= MAX_CONVERSATIONS:
|
| 99 |
break
|
| 100 |
|
| 101 |
-
|
| 102 |
-
print(f"Total conversations: {len(train_dataset)}")
|
| 103 |
-
|
| 104 |
-
if len(train_dataset) > 0:
|
| 105 |
-
print("Sample conversation:")
|
| 106 |
-
print(train_dataset[0])
|
| 107 |
|
| 108 |
# ------------------------------------------------------------------
|
| 109 |
# Tokenizer
|
|
@@ -114,6 +106,28 @@ if tokenizer.pad_token is None:
|
|
| 114 |
tokenizer.pad_token = tokenizer.eos_token
|
| 115 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
# ------------------------------------------------------------------
|
| 118 |
# Model
|
| 119 |
# ------------------------------------------------------------------
|
|
@@ -137,6 +151,7 @@ args = SFTConfig(
|
|
| 137 |
per_device_train_batch_size=1,
|
| 138 |
gradient_accumulation_steps=4,
|
| 139 |
learning_rate=2e-5,
|
|
|
|
| 140 |
logging_strategy="steps",
|
| 141 |
logging_steps=10,
|
| 142 |
logging_first_step=True,
|
|
@@ -155,7 +170,7 @@ print("Initializing SFTTrainer...")
|
|
| 155 |
trainer = SFTTrainer(
|
| 156 |
model=model,
|
| 157 |
args=args,
|
| 158 |
-
train_dataset=
|
| 159 |
processing_class=tokenizer,
|
| 160 |
)
|
| 161 |
|
|
|
|
| 3 |
Model: google/gemma-2-2b-it
|
| 4 |
Dataset: talkmap/telecom-conversation-corpus
|
| 5 |
Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
|
| 6 |
+
Manually tokenizes with truncation to control sequence length and avoid OOM.
|
| 7 |
"""
|
| 8 |
import os
|
| 9 |
from collections import defaultdict
|
|
|
|
| 21 |
DATASET_ID = "talkmap/telecom-conversation-corpus"
|
| 22 |
OUTPUT_DIR = "./gemma-2b-it-telecom"
|
| 23 |
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
|
| 24 |
+
MAX_CONVERSATIONS = 10000
|
| 25 |
+
MAX_TURNS = 6
|
| 26 |
+
MAX_SEQ_LENGTH = 512
|
| 27 |
|
| 28 |
# ------------------------------------------------------------------
|
| 29 |
# Trackio monitoring
|
|
|
|
| 55 |
conversations[conv_id].sort(key=lambda x: x["date_time"])
|
| 56 |
|
| 57 |
# Convert each conversation into messages format with merged consecutive same-role turns
|
|
|
|
| 58 |
print("Converting to messages format...")
|
| 59 |
messages_data = []
|
| 60 |
for conv_id, turns in conversations.items():
|
|
|
|
| 61 |
turns = turns[:MAX_TURNS]
|
| 62 |
|
| 63 |
messages = []
|
|
|
|
| 75 |
if current_role is not None:
|
| 76 |
messages.append({"role": current_role, "content": "\n".join(current_content)})
|
| 77 |
|
|
|
|
| 78 |
if not messages or messages[0]["role"] != "user":
|
| 79 |
continue
|
| 80 |
|
|
|
|
| 81 |
valid = True
|
| 82 |
for i, msg in enumerate(messages):
|
| 83 |
expected_role = "user" if i % 2 == 0 else "assistant"
|
|
|
|
| 87 |
if not valid:
|
| 88 |
continue
|
| 89 |
|
|
|
|
| 90 |
if messages[-1]["role"] != "assistant":
|
| 91 |
continue
|
| 92 |
|
|
|
|
| 95 |
if len(messages_data) >= MAX_CONVERSATIONS:
|
| 96 |
break
|
| 97 |
|
| 98 |
+
print(f"Total conversations: {len(messages_data)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
# ------------------------------------------------------------------
|
| 101 |
# Tokenizer
|
|
|
|
| 106 |
tokenizer.pad_token = tokenizer.eos_token
|
| 107 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 108 |
|
| 109 |
+
# ------------------------------------------------------------------
|
| 110 |
+
# Pre-tokenize dataset with truncation (avoids SFTTrainer auto-tokenization)
|
| 111 |
+
# ------------------------------------------------------------------
|
| 112 |
+
print("Pre-tokenizing dataset...")
|
| 113 |
+
|
| 114 |
+
def apply_and_tokenize(example):
|
| 115 |
+
try:
|
| 116 |
+
text = tokenizer.apply_chat_template(
|
| 117 |
+
example["messages"],
|
| 118 |
+
tokenize=False,
|
| 119 |
+
add_generation_prompt=False,
|
| 120 |
+
)
|
| 121 |
+
except Exception:
|
| 122 |
+
text = ""
|
| 123 |
+
return {"text": text}
|
| 124 |
+
|
| 125 |
+
raw_dataset = Dataset.from_list(messages_data)
|
| 126 |
+
raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
|
| 127 |
+
raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
|
| 128 |
+
|
| 129 |
+
print(f"Dataset after filtering: {len(raw_dataset)}")
|
| 130 |
+
|
| 131 |
# ------------------------------------------------------------------
|
| 132 |
# Model
|
| 133 |
# ------------------------------------------------------------------
|
|
|
|
| 151 |
per_device_train_batch_size=1,
|
| 152 |
gradient_accumulation_steps=4,
|
| 153 |
learning_rate=2e-5,
|
| 154 |
+
max_seq_length=MAX_SEQ_LENGTH,
|
| 155 |
logging_strategy="steps",
|
| 156 |
logging_steps=10,
|
| 157 |
logging_first_step=True,
|
|
|
|
| 170 |
trainer = SFTTrainer(
|
| 171 |
model=model,
|
| 172 |
args=args,
|
| 173 |
+
train_dataset=raw_dataset,
|
| 174 |
processing_class=tokenizer,
|
| 175 |
)
|
| 176 |
|