ligaments-dev commited on
Commit
48301ee
·
verified ·
1 Parent(s): 8c55461

Remove system role, verify alternating user/assistant for Gemma compatibility

Browse files
Files changed (1) hide show
  1. train.py +19 -9
train.py CHANGED
@@ -3,6 +3,7 @@ Full fine-tuning script:
3
  Model: google/gemma-2-2b-it
4
  Dataset: talkmap/telecom-conversation-corpus
5
  Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
 
6
  """
7
  import os
8
  from collections import defaultdict
@@ -20,7 +21,7 @@ MODEL_ID = "google/gemma-2-2b-it"
20
  DATASET_ID = "talkmap/telecom-conversation-corpus"
21
  OUTPUT_DIR = "./gemma-2b-it-telecom"
22
  HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
23
- MAX_CONVERSATIONS = 50000 # cap for faster training
24
 
25
  # ------------------------------------------------------------------
26
  # Trackio monitoring
@@ -52,7 +53,7 @@ for conv_id in conversations:
52
  conversations[conv_id].sort(key=lambda x: x["date_time"])
53
 
54
  # Convert each conversation into messages format with merged consecutive same-role turns
55
- # Also add a system prompt and ensure first turn is always user
56
  print("Converting to messages format...")
57
  messages_data = []
58
  for conv_id, turns in conversations.items():
@@ -71,15 +72,19 @@ for conv_id, turns in conversations.items():
71
  if current_role is not None:
72
  messages.append({"role": current_role, "content": "\n".join(current_content)})
73
 
74
- # Skip if first role is not user (Gemma requires user first)
75
- if messages and messages[0]["role"] != "user":
76
  continue
77
 
78
- # Add system prompt for telecom context
79
- messages.insert(0, {
80
- "role": "system",
81
- "content": "You are a helpful telecom customer service assistant. Help customers with their mobile, internet, and billing issues."
82
- })
 
 
 
 
83
 
84
  messages_data.append({"messages": messages})
85
 
@@ -89,6 +94,11 @@ for conv_id, turns in conversations.items():
89
  train_dataset = Dataset.from_list(messages_data)
90
  print(f"Total conversations: {len(train_dataset)}")
91
 
 
 
 
 
 
92
  # ------------------------------------------------------------------
93
  # Tokenizer
94
  # ------------------------------------------------------------------
 
3
  Model: google/gemma-2-2b-it
4
  Dataset: talkmap/telecom-conversation-corpus
5
  Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
6
+ NOTE: Gemma chat template does NOT support system role.
7
  """
8
  import os
9
  from collections import defaultdict
 
21
  DATASET_ID = "talkmap/telecom-conversation-corpus"
22
  OUTPUT_DIR = "./gemma-2b-it-telecom"
23
  HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
24
+ MAX_CONVERSATIONS = 50000
25
 
26
  # ------------------------------------------------------------------
27
  # Trackio monitoring
 
53
  conversations[conv_id].sort(key=lambda x: x["date_time"])
54
 
55
  # Convert each conversation into messages format with merged consecutive same-role turns
56
+ # Gemma requires: NO system role, user first, alternating user/assistant
57
  print("Converting to messages format...")
58
  messages_data = []
59
  for conv_id, turns in conversations.items():
 
72
  if current_role is not None:
73
  messages.append({"role": current_role, "content": "\n".join(current_content)})
74
 
75
+ # Gemma requires first turn to be user and alternating roles
76
+ if not messages or messages[0]["role"] != "user":
77
  continue
78
 
79
+ # Verify alternating roles
80
+ valid = True
81
+ for i, msg in enumerate(messages):
82
+ expected_role = "user" if i % 2 == 0 else "assistant"
83
+ if msg["role"] != expected_role:
84
+ valid = False
85
+ break
86
+ if not valid:
87
+ continue
88
 
89
  messages_data.append({"messages": messages})
90
 
 
94
  train_dataset = Dataset.from_list(messages_data)
95
  print(f"Total conversations: {len(train_dataset)}")
96
 
97
+ # Print a sample for debugging
98
+ if len(train_dataset) > 0:
99
+ print("Sample conversation:")
100
+ print(train_dataset[0])
101
+
102
  # ------------------------------------------------------------------
103
  # Tokenizer
104
  # ------------------------------------------------------------------