Increase gradient_accumulation_steps to 8; call gc.collect() and torch.cuda.empty_cache() before training
Browse files
train.py
CHANGED
|
@@ -1,9 +1,7 @@
|
|
| 1 |
"""
|
| 2 |
-
Full fine-tuning script:
|
| 3 |
Model: google/gemma-2-2b-it
|
| 4 |
Dataset: talkmap/telecom-conversation-corpus
|
| 5 |
-
Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
|
| 6 |
-
Manually tokenizes with truncation to control sequence length and avoid OOM.
|
| 7 |
"""
|
| 8 |
import os
|
| 9 |
from collections import defaultdict
|
|
@@ -40,7 +38,6 @@ print("Loading dataset...")
|
|
| 40 |
ds = load_dataset(DATASET_ID, split="train")
|
| 41 |
print(f"Rows: {len(ds)}, Columns: {ds.column_names}")
|
| 42 |
|
| 43 |
-
# Group rows by conversation_id and sort by date_time
|
| 44 |
print("Grouping conversations...")
|
| 45 |
conversations = defaultdict(list)
|
| 46 |
for row in ds:
|
|
@@ -54,7 +51,6 @@ for row in ds:
|
|
| 54 |
for conv_id in conversations:
|
| 55 |
conversations[conv_id].sort(key=lambda x: x["date_time"])
|
| 56 |
|
| 57 |
-
# Convert each conversation into messages format with merged consecutive same-role turns
|
| 58 |
print("Converting to messages format...")
|
| 59 |
messages_data = []
|
| 60 |
for conv_id, turns in conversations.items():
|
|
@@ -107,7 +103,7 @@ if tokenizer.pad_token is None:
|
|
| 107 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 108 |
|
| 109 |
# ------------------------------------------------------------------
|
| 110 |
-
# Pre-tokenize dataset with truncation
|
| 111 |
# ------------------------------------------------------------------
|
| 112 |
print("Pre-tokenizing dataset...")
|
| 113 |
|
|
@@ -125,11 +121,10 @@ def apply_and_tokenize(example):
|
|
| 125 |
raw_dataset = Dataset.from_list(messages_data)
|
| 126 |
raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
|
| 127 |
raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
|
| 128 |
-
|
| 129 |
print(f"Dataset after filtering: {len(raw_dataset)}")
|
| 130 |
|
| 131 |
# ------------------------------------------------------------------
|
| 132 |
-
# Model
|
| 133 |
# ------------------------------------------------------------------
|
| 134 |
print("Loading model...")
|
| 135 |
model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -140,6 +135,11 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 140 |
|
| 141 |
model.gradient_checkpointing_enable()
|
| 142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
# ------------------------------------------------------------------
|
| 144 |
# Training arguments
|
| 145 |
# ------------------------------------------------------------------
|
|
@@ -149,7 +149,7 @@ args = SFTConfig(
|
|
| 149 |
push_to_hub=True,
|
| 150 |
num_train_epochs=1,
|
| 151 |
per_device_train_batch_size=1,
|
| 152 |
-
gradient_accumulation_steps=
|
| 153 |
learning_rate=2e-5,
|
| 154 |
max_seq_length=MAX_SEQ_LENGTH,
|
| 155 |
logging_strategy="steps",
|
|
|
|
| 1 |
"""
|
| 2 |
+
Full fine-tuning script with aggressive memory optimizations:
|
| 3 |
Model: google/gemma-2-2b-it
|
| 4 |
Dataset: talkmap/telecom-conversation-corpus
|
|
|
|
|
|
|
| 5 |
"""
|
| 6 |
import os
|
| 7 |
from collections import defaultdict
|
|
|
|
| 38 |
ds = load_dataset(DATASET_ID, split="train")
|
| 39 |
print(f"Rows: {len(ds)}, Columns: {ds.column_names}")
|
| 40 |
|
|
|
|
| 41 |
print("Grouping conversations...")
|
| 42 |
conversations = defaultdict(list)
|
| 43 |
for row in ds:
|
|
|
|
| 51 |
for conv_id in conversations:
|
| 52 |
conversations[conv_id].sort(key=lambda x: x["date_time"])
|
| 53 |
|
|
|
|
| 54 |
print("Converting to messages format...")
|
| 55 |
messages_data = []
|
| 56 |
for conv_id, turns in conversations.items():
|
|
|
|
| 103 |
tokenizer.pad_token_id = tokenizer.eos_token_id
|
| 104 |
|
| 105 |
# ------------------------------------------------------------------
|
| 106 |
+
# Pre-tokenize dataset with truncation
|
| 107 |
# ------------------------------------------------------------------
|
| 108 |
print("Pre-tokenizing dataset...")
|
| 109 |
|
|
|
|
| 121 |
raw_dataset = Dataset.from_list(messages_data)
|
| 122 |
raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
|
| 123 |
raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
|
|
|
|
| 124 |
print(f"Dataset after filtering: {len(raw_dataset)}")
|
| 125 |
|
| 126 |
# ------------------------------------------------------------------
|
| 127 |
+
# Model - load on CPU first to control placement
|
| 128 |
# ------------------------------------------------------------------
|
| 129 |
print("Loading model...")
|
| 130 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 135 |
|
| 136 |
model.gradient_checkpointing_enable()
|
| 137 |
|
| 138 |
+
# Run garbage collection and release cached CUDA memory to reduce fragmentation
|
| 139 |
+
import gc
|
| 140 |
+
gc.collect()
|
| 141 |
+
torch.cuda.empty_cache()
|
| 142 |
+
|
| 143 |
# ------------------------------------------------------------------
|
| 144 |
# Training arguments
|
| 145 |
# ------------------------------------------------------------------
|
|
|
|
| 149 |
push_to_hub=True,
|
| 150 |
num_train_epochs=1,
|
| 151 |
per_device_train_batch_size=1,
|
| 152 |
+
gradient_accumulation_steps=8,
|
| 153 |
learning_rate=2e-5,
|
| 154 |
max_seq_length=MAX_SEQ_LENGTH,
|
| 155 |
logging_strategy="steps",
|