ligaments-dev commited on
Commit
0ce3e14
·
verified ·
1 Parent(s): aa1a19d

Cap turns to 8, require assistant end, 5k conversations for OOM fix

Browse files
Files changed (1) hide show
  1. train.py +9 -2
train.py CHANGED
@@ -3,7 +3,6 @@ Full fine-tuning script:
3
  Model: google/gemma-2-2b-it
4
  Dataset: talkmap/telecom-conversation-corpus
5
  Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
6
- NOTE: Gemma chat template does NOT support system role.
7
  """
8
  import os
9
  from collections import defaultdict
@@ -21,7 +20,8 @@ MODEL_ID = "google/gemma-2-2b-it"
21
  DATASET_ID = "talkmap/telecom-conversation-corpus"
22
  OUTPUT_DIR = "./gemma-2b-it-telecom"
23
  HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
24
- MAX_CONVERSATIONS = 10000 # cap for faster training
 
25
 
26
  # ------------------------------------------------------------------
27
  # Trackio monitoring
@@ -57,6 +57,9 @@ for conv_id in conversations:
57
  print("Converting to messages format...")
58
  messages_data = []
59
  for conv_id, turns in conversations.items():
 
 
 
60
  messages = []
61
  current_role = None
62
  current_content = []
@@ -86,6 +89,10 @@ for conv_id, turns in conversations.items():
86
  if not valid:
87
  continue
88
 
 
 
 
 
89
  messages_data.append({"messages": messages})
90
 
91
  if len(messages_data) >= MAX_CONVERSATIONS:
 
3
  Model: google/gemma-2-2b-it
4
  Dataset: talkmap/telecom-conversation-corpus
5
  Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
 
6
  """
7
  import os
8
  from collections import defaultdict
 
20
  DATASET_ID = "talkmap/telecom-conversation-corpus"
21
  OUTPUT_DIR = "./gemma-2b-it-telecom"
22
  HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
23
+ MAX_CONVERSATIONS = 5000
24
+ MAX_TURNS = 8 # cap conversation length to reduce activation memory
25
 
26
  # ------------------------------------------------------------------
27
  # Trackio monitoring
 
57
  print("Converting to messages format...")
58
  messages_data = []
59
  for conv_id, turns in conversations.items():
60
+ # Cap turns to MAX_TURNS to keep sequences shorter
61
+ turns = turns[:MAX_TURNS]
62
+
63
  messages = []
64
  current_role = None
65
  current_content = []
 
89
  if not valid:
90
  continue
91
 
92
+ # Ensure conversation ends with assistant (complete pair)
93
+ if messages[-1]["role"] != "assistant":
94
+ continue
95
+
96
  messages_data.append({"messages": messages})
97
 
98
  if len(messages_data) >= MAX_CONVERSATIONS: