Reduce training to 1 epoch and cap the dataset at 10k conversations for faster training on a larger GPU
Browse files
train.py
CHANGED
|
@@ -21,7 +21,7 @@ MODEL_ID = "google/gemma-2-2b-it"
|
|
| 21 |
DATASET_ID = "talkmap/telecom-conversation-corpus"
|
| 22 |
OUTPUT_DIR = "./gemma-2b-it-telecom"
|
| 23 |
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
|
| 24 |
-
MAX_CONVERSATIONS =
|
| 25 |
|
| 26 |
# ------------------------------------------------------------------
|
| 27 |
# Trackio monitoring
|
|
@@ -94,7 +94,6 @@ for conv_id, turns in conversations.items():
|
|
| 94 |
train_dataset = Dataset.from_list(messages_data)
|
| 95 |
print(f"Total conversations: {len(train_dataset)}")
|
| 96 |
|
| 97 |
-
# Print a sample for debugging
|
| 98 |
if len(train_dataset) > 0:
|
| 99 |
print("Sample conversation:")
|
| 100 |
print(train_dataset[0])
|
|
@@ -127,7 +126,7 @@ args = SFTConfig(
|
|
| 127 |
output_dir=OUTPUT_DIR,
|
| 128 |
hub_model_id=HUB_MODEL_ID,
|
| 129 |
push_to_hub=True,
|
| 130 |
-
num_train_epochs=
|
| 131 |
per_device_train_batch_size=1,
|
| 132 |
gradient_accumulation_steps=4,
|
| 133 |
learning_rate=2e-5,
|
|
|
|
| 21 |
DATASET_ID = "talkmap/telecom-conversation-corpus"
|
| 22 |
OUTPUT_DIR = "./gemma-2b-it-telecom"
|
| 23 |
HUB_MODEL_ID = "ligaments-dev/gemma-2b-it-telecom"
|
| 24 |
+
MAX_CONVERSATIONS = 10000 # cap for faster training
|
| 25 |
|
| 26 |
# ------------------------------------------------------------------
|
| 27 |
# Trackio monitoring
|
|
|
|
| 94 |
train_dataset = Dataset.from_list(messages_data)
|
| 95 |
print(f"Total conversations: {len(train_dataset)}")
|
| 96 |
|
|
|
|
| 97 |
if len(train_dataset) > 0:
|
| 98 |
print("Sample conversation:")
|
| 99 |
print(train_dataset[0])
|
|
|
|
| 126 |
output_dir=OUTPUT_DIR,
|
| 127 |
hub_model_id=HUB_MODEL_ID,
|
| 128 |
push_to_hub=True,
|
| 129 |
+
num_train_epochs=1,
|
| 130 |
per_device_train_batch_size=1,
|
| 131 |
gradient_accumulation_steps=4,
|
| 132 |
learning_rate=2e-5,
|