ligaments-dev committed on
Commit
745965e
·
verified ·
1 Parent(s): 0ce3e14

Pre-render conversations to text with apply_chat_template (tokenize=False) and cap sequence length at 512 via max_seq_length to fix OOM

Browse files
Files changed (1) hide show
  1. train.py +29 -14
train.py CHANGED
@@ -3,6 +3,7 @@ Full fine-tuning script:
3
  Model: google/gemma-2-2b-it
4
  Dataset: talkmap/telecom-conversation-corpus
5
  Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
 
6
  """
7
  import os
8
  from collections import defaultdict
@@ -20,8 +21,9 @@ MODEL_ID = "google/gemma-2-2b-it"
20
  DATASET_ID = "talkmap/telecom-conversation-corpus"
21
  OUTPUT_DIR = "./gemma-2b-it-telecom"
22
  HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
23
- MAX_CONVERSATIONS = 5000
24
- MAX_TURNS = 8 # cap conversation length to reduce activation memory
 
25
 
26
  # ------------------------------------------------------------------
27
  # Trackio monitoring
@@ -53,11 +55,9 @@ for conv_id in conversations:
53
  conversations[conv_id].sort(key=lambda x: x["date_time"])
54
 
55
  # Convert each conversation into messages format with merged consecutive same-role turns
56
- # Gemma requires: NO system role, user first, alternating user/assistant
57
  print("Converting to messages format...")
58
  messages_data = []
59
  for conv_id, turns in conversations.items():
60
- # Cap turns to MAX_TURNS to keep sequences shorter
61
  turns = turns[:MAX_TURNS]
62
 
63
  messages = []
@@ -75,11 +75,9 @@ for conv_id, turns in conversations.items():
75
  if current_role is not None:
76
  messages.append({"role": current_role, "content": "\n".join(current_content)})
77
 
78
- # Gemma requires first turn to be user and alternating roles
79
  if not messages or messages[0]["role"] != "user":
80
  continue
81
 
82
- # Verify alternating roles
83
  valid = True
84
  for i, msg in enumerate(messages):
85
  expected_role = "user" if i % 2 == 0 else "assistant"
@@ -89,7 +87,6 @@ for conv_id, turns in conversations.items():
89
  if not valid:
90
  continue
91
 
92
- # Ensure conversation ends with assistant (complete pair)
93
  if messages[-1]["role"] != "assistant":
94
  continue
95
 
@@ -98,12 +95,7 @@ for conv_id, turns in conversations.items():
98
  if len(messages_data) >= MAX_CONVERSATIONS:
99
  break
100
 
101
- train_dataset = Dataset.from_list(messages_data)
102
- print(f"Total conversations: {len(train_dataset)}")
103
-
104
- if len(train_dataset) > 0:
105
- print("Sample conversation:")
106
- print(train_dataset[0])
107
 
108
  # ------------------------------------------------------------------
109
  # Tokenizer
@@ -114,6 +106,28 @@ if tokenizer.pad_token is None:
114
  tokenizer.pad_token = tokenizer.eos_token
115
  tokenizer.pad_token_id = tokenizer.eos_token_id
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  # ------------------------------------------------------------------
118
  # Model
119
  # ------------------------------------------------------------------
@@ -137,6 +151,7 @@ args = SFTConfig(
137
  per_device_train_batch_size=1,
138
  gradient_accumulation_steps=4,
139
  learning_rate=2e-5,
 
140
  logging_strategy="steps",
141
  logging_steps=10,
142
  logging_first_step=True,
@@ -155,7 +170,7 @@ print("Initializing SFTTrainer...")
155
  trainer = SFTTrainer(
156
  model=model,
157
  args=args,
158
- train_dataset=train_dataset,
159
  processing_class=tokenizer,
160
  )
161
 
 
3
  Model: google/gemma-2-2b-it
4
  Dataset: talkmap/telecom-conversation-corpus
5
  Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
6
+ Renders each conversation to chat-template text and caps sequence length via max_seq_length to avoid OOM.
7
  """
8
  import os
9
  from collections import defaultdict
 
21
  DATASET_ID = "talkmap/telecom-conversation-corpus"
22
  OUTPUT_DIR = "./gemma-2b-it-telecom"
23
  HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
24
+ MAX_CONVERSATIONS = 10000
25
+ MAX_TURNS = 6
26
+ MAX_SEQ_LENGTH = 512
27
 
28
  # ------------------------------------------------------------------
29
  # Trackio monitoring
 
55
  conversations[conv_id].sort(key=lambda x: x["date_time"])
56
 
57
  # Convert each conversation into messages format with merged consecutive same-role turns
 
58
  print("Converting to messages format...")
59
  messages_data = []
60
  for conv_id, turns in conversations.items():
 
61
  turns = turns[:MAX_TURNS]
62
 
63
  messages = []
 
75
  if current_role is not None:
76
  messages.append({"role": current_role, "content": "\n".join(current_content)})
77
 
 
78
  if not messages or messages[0]["role"] != "user":
79
  continue
80
 
 
81
  valid = True
82
  for i, msg in enumerate(messages):
83
  expected_role = "user" if i % 2 == 0 else "assistant"
 
87
  if not valid:
88
  continue
89
 
 
90
  if messages[-1]["role"] != "assistant":
91
  continue
92
 
 
95
  if len(messages_data) >= MAX_CONVERSATIONS:
96
  break
97
 
98
+ print(f"Total conversations: {len(messages_data)}")
 
 
 
 
 
99
 
100
  # ------------------------------------------------------------------
101
  # Tokenizer
 
106
  tokenizer.pad_token = tokenizer.eos_token
107
  tokenizer.pad_token_id = tokenizer.eos_token_id
108
 
109
+ # ------------------------------------------------------------------
110
+ # Render conversations to chat-template text (tokenize=False); length is capped by max_seq_length in SFTConfig
111
+ # ------------------------------------------------------------------
112
+ print("Pre-tokenizing dataset...")
113
+
114
def apply_and_tokenize(example):
    """Render one conversation into a single chat-template string.

    Despite the name, nothing is tokenized here: ``apply_chat_template`` is
    called with ``tokenize=False``, so the result is plain text.

    Args:
        example: Mapping with a ``"messages"`` entry holding the conversation.

    Returns:
        ``{"text": rendered}`` on success, or ``{"text": ""}`` when the chat
        template could not be applied (the empty string marks the failure).
    """
    try:
        text = tokenizer.apply_chat_template(
            example["messages"],
            tokenize=False,
            add_generation_prompt=False,  # training text: no trailing assistant prompt
        )
    except Exception as exc:
        # Best-effort: one bad conversation must not abort preprocessing,
        # but report it so dropped examples are not silently lost.
        print(f"Skipping conversation, chat template failed: {exc!r}")
        text = ""
    return {"text": text}
124
+
125
+ raw_dataset = Dataset.from_list(messages_data)
126
+ raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
127
+ raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
128
+
129
+ print(f"Dataset after filtering: {len(raw_dataset)}")
130
+
131
  # ------------------------------------------------------------------
132
  # Model
133
  # ------------------------------------------------------------------
 
151
  per_device_train_batch_size=1,
152
  gradient_accumulation_steps=4,
153
  learning_rate=2e-5,
154
+ max_seq_length=MAX_SEQ_LENGTH,
155
  logging_strategy="steps",
156
  logging_steps=10,
157
  logging_first_step=True,
 
170
  trainer = SFTTrainer(
171
  model=model,
172
  args=args,
173
+ train_dataset=raw_dataset,
174
  processing_class=tokenizer,
175
  )
176