Cap conversations at 8 turns, require the final message to be from the assistant, and use 5k conversations to fix OOM
Browse files
train.py
CHANGED
|
@@ -3,7 +3,6 @@ Full fine-tuning script:
|
|
| 3 |
Model: google/gemma-2-2b-it
|
| 4 |
Dataset: talkmap/telecom-conversation-corpus
|
| 5 |
Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
|
| 6 |
-
NOTE: Gemma chat template does NOT support system role.
|
| 7 |
"""
|
| 8 |
import os
|
| 9 |
from collections import defaultdict
|
|
@@ -21,7 +20,8 @@ MODEL_ID = "google/gemma-2-2b-it"
|
|
| 21 |
DATASET_ID = "talkmap/telecom-conversation-corpus"
|
| 22 |
OUTPUT_DIR = "./gemma-2b-it-telecom"
|
| 23 |
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
|
| 24 |
-
MAX_CONVERSATIONS =
|
|
|
|
| 25 |
|
| 26 |
# ------------------------------------------------------------------
|
| 27 |
# Trackio monitoring
|
|
@@ -57,6 +57,9 @@ for conv_id in conversations:
|
|
| 57 |
print("Converting to messages format...")
|
| 58 |
messages_data = []
|
| 59 |
for conv_id, turns in conversations.items():
|
|
|
|
|
|
|
|
|
|
| 60 |
messages = []
|
| 61 |
current_role = None
|
| 62 |
current_content = []
|
|
@@ -86,6 +89,10 @@ for conv_id, turns in conversations.items():
|
|
| 86 |
if not valid:
|
| 87 |
continue
|
| 88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
messages_data.append({"messages": messages})
|
| 90 |
|
| 91 |
if len(messages_data) >= MAX_CONVERSATIONS:
|
|
|
|
| 3 |
Model: google/gemma-2-2b-it
|
| 4 |
Dataset: talkmap/telecom-conversation-corpus
|
| 5 |
Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
|
|
|
|
| 6 |
"""
|
| 7 |
import os
|
| 8 |
from collections import defaultdict
|
|
|
|
| 20 |
DATASET_ID = "talkmap/telecom-conversation-corpus"
|
| 21 |
OUTPUT_DIR = "./gemma-2b-it-telecom"
|
| 22 |
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
|
| 23 |
+
MAX_CONVERSATIONS = 5000
|
| 24 |
+
MAX_TURNS = 8 # cap conversation length to reduce activation memory
|
| 25 |
|
| 26 |
# ------------------------------------------------------------------
|
| 27 |
# Trackio monitoring
|
|
|
|
| 57 |
print("Converting to messages format...")
|
| 58 |
messages_data = []
|
| 59 |
for conv_id, turns in conversations.items():
|
| 60 |
+
# Keep only the first MAX_TURNS turns of each conversation so sequences stay shorter
|
| 61 |
+
turns = turns[:MAX_TURNS]
|
| 62 |
+
|
| 63 |
messages = []
|
| 64 |
current_role = None
|
| 65 |
current_content = []
|
|
|
|
| 89 |
if not valid:
|
| 90 |
continue
|
| 91 |
|
| 92 |
+
# Skip conversations whose last message is not from the assistant (incomplete final pair)
|
| 93 |
+
if messages[-1]["role"] != "assistant":
|
| 94 |
+
continue
|
| 95 |
+
|
| 96 |
messages_data.append({"messages": messages})
|
| 97 |
|
| 98 |
if len(messages_data) >= MAX_CONVERSATIONS:
|