Remove system role, verify alternating user/assistant for Gemma compatibility
Browse files
train.py
CHANGED
|
@@ -3,6 +3,7 @@ Full fine-tuning script:
|
|
| 3 |
Model: google/gemma-2-2b-it
|
| 4 |
Dataset: talkmap/telecom-conversation-corpus
|
| 5 |
Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
|
|
|
|
| 6 |
"""
|
| 7 |
import os
|
| 8 |
from collections import defaultdict
|
|
@@ -20,7 +21,7 @@ MODEL_ID = "google/gemma-2-2b-it"
|
|
| 20 |
DATASET_ID = "talkmap/telecom-conversation-corpus"
|
| 21 |
OUTPUT_DIR = "./gemma-2b-it-telecom"
|
| 22 |
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
|
| 23 |
-
MAX_CONVERSATIONS = 50000
|
| 24 |
|
| 25 |
# ------------------------------------------------------------------
|
| 26 |
# Trackio monitoring
|
|
@@ -52,7 +53,7 @@ for conv_id in conversations:
|
|
| 52 |
conversations[conv_id].sort(key=lambda x: x["date_time"])
|
| 53 |
|
| 54 |
# Convert each conversation into messages format with merged consecutive same-role turns
|
| 55 |
-
#
|
| 56 |
print("Converting to messages format...")
|
| 57 |
messages_data = []
|
| 58 |
for conv_id, turns in conversations.items():
|
|
@@ -71,15 +72,19 @@ for conv_id, turns in conversations.items():
|
|
| 71 |
if current_role is not None:
|
| 72 |
messages.append({"role": current_role, "content": "\n".join(current_content)})
|
| 73 |
|
| 74 |
-
#
|
| 75 |
-
if messages
|
| 76 |
continue
|
| 77 |
|
| 78 |
-
#
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
"
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
messages_data.append({"messages": messages})
|
| 85 |
|
|
@@ -89,6 +94,11 @@ for conv_id, turns in conversations.items():
|
|
| 89 |
train_dataset = Dataset.from_list(messages_data)
|
| 90 |
print(f"Total conversations: {len(train_dataset)}")
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
# ------------------------------------------------------------------
|
| 93 |
# Tokenizer
|
| 94 |
# ------------------------------------------------------------------
|
|
|
|
| 3 |
Model: google/gemma-2-2b-it
|
| 4 |
Dataset: talkmap/telecom-conversation-corpus
|
| 5 |
Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
|
| 6 |
+
NOTE: Gemma chat template does NOT support system role.
|
| 7 |
"""
|
| 8 |
import os
|
| 9 |
from collections import defaultdict
|
|
|
|
| 21 |
DATASET_ID = "talkmap/telecom-conversation-corpus"
|
| 22 |
OUTPUT_DIR = "./gemma-2b-it-telecom"
|
| 23 |
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
|
| 24 |
+
MAX_CONVERSATIONS = 50000
|
| 25 |
|
| 26 |
# ------------------------------------------------------------------
|
| 27 |
# Trackio monitoring
|
|
|
|
| 53 |
conversations[conv_id].sort(key=lambda x: x["date_time"])
|
| 54 |
|
| 55 |
# Convert each conversation into messages format with merged consecutive same-role turns
|
| 56 |
+
# Gemma requires: NO system role, user first, alternating user/assistant
|
| 57 |
print("Converting to messages format...")
|
| 58 |
messages_data = []
|
| 59 |
for conv_id, turns in conversations.items():
|
|
|
|
| 72 |
if current_role is not None:
|
| 73 |
messages.append({"role": current_role, "content": "\n".join(current_content)})
|
| 74 |
|
| 75 |
+
# Gemma requires first turn to be user and alternating roles
|
| 76 |
+
if not messages or messages[0]["role"] != "user":
|
| 77 |
continue
|
| 78 |
|
| 79 |
+
# Verify alternating roles
|
| 80 |
+
valid = True
|
| 81 |
+
for i, msg in enumerate(messages):
|
| 82 |
+
expected_role = "user" if i % 2 == 0 else "assistant"
|
| 83 |
+
if msg["role"] != expected_role:
|
| 84 |
+
valid = False
|
| 85 |
+
break
|
| 86 |
+
if not valid:
|
| 87 |
+
continue
|
| 88 |
|
| 89 |
messages_data.append({"messages": messages})
|
| 90 |
|
|
|
|
| 94 |
train_dataset = Dataset.from_list(messages_data)
|
| 95 |
print(f"Total conversations: {len(train_dataset)}")
|
| 96 |
|
| 97 |
+
# Print a sample for debugging
|
| 98 |
+
if len(train_dataset) > 0:
|
| 99 |
+
print("Sample conversation:")
|
| 100 |
+
print(train_dataset[0])
|
| 101 |
+
|
| 102 |
# ------------------------------------------------------------------
|
| 103 |
# Tokenizer
|
| 104 |
# ------------------------------------------------------------------
|