ligaments-dev committed on
Commit
aa1a19d
·
verified ·
1 Parent(s): 48301ee

Reduce to 1 epoch and 10k conversations for faster training on a larger GPU

Browse files
Files changed (1) hide show
  1. train.py +2 -3
train.py CHANGED
@@ -21,7 +21,7 @@ MODEL_ID = "google/gemma-2-2b-it"
21
  DATASET_ID = "talkmap/telecom-conversation-corpus"
22
  OUTPUT_DIR = "./gemma-2b-it-telecom"
23
  HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
24
- MAX_CONVERSATIONS = 50000
25
 
26
  # ------------------------------------------------------------------
27
  # Trackio monitoring
@@ -94,7 +94,6 @@ for conv_id, turns in conversations.items():
94
  train_dataset = Dataset.from_list(messages_data)
95
  print(f"Total conversations: {len(train_dataset)}")
96
 
97
- # Print a sample for debugging
98
  if len(train_dataset) > 0:
99
  print("Sample conversation:")
100
  print(train_dataset[0])
@@ -127,7 +126,7 @@ args = SFTConfig(
127
  output_dir=OUTPUT_DIR,
128
  hub_model_id=HUB_MODEL_ID,
129
  push_to_hub=True,
130
- num_train_epochs=3,
131
  per_device_train_batch_size=1,
132
  gradient_accumulation_steps=4,
133
  learning_rate=2e-5,
 
21
  DATASET_ID = "talkmap/telecom-conversation-corpus"
22
  OUTPUT_DIR = "./gemma-2b-it-telecom"
23
  HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
24
+ MAX_CONVERSATIONS = 10000 # cap for faster training
25
 
26
  # ------------------------------------------------------------------
27
  # Trackio monitoring
 
94
  train_dataset = Dataset.from_list(messages_data)
95
  print(f"Total conversations: {len(train_dataset)}")
96
 
 
97
  if len(train_dataset) > 0:
98
  print("Sample conversation:")
99
  print(train_dataset[0])
 
126
  output_dir=OUTPUT_DIR,
127
  hub_model_id=HUB_MODEL_ID,
128
  push_to_hub=True,
129
+ num_train_epochs=1,
130
  per_device_train_batch_size=1,
131
  gradient_accumulation_steps=4,
132
  learning_rate=2e-5,