ligaments-dev commited on
Commit
31b65fb
·
verified ·
1 Parent(s): 745965e

Increase grad_accum to 8, add gc.collect and empty_cache before training

Browse files
Files changed (1) hide show
  1. train.py +9 -9
train.py CHANGED
@@ -1,9 +1,7 @@
1
  """
2
- Full fine-tuning script:
3
  Model: google/gemma-2-2b-it
4
  Dataset: talkmap/telecom-conversation-corpus
5
- Converts turn-based telecom dialogues into alternating-role conversational messages for SFT.
6
- Manually tokenizes with truncation to control sequence length and avoid OOM.
7
  """
8
  import os
9
  from collections import defaultdict
@@ -40,7 +38,6 @@ print("Loading dataset...")
40
  ds = load_dataset(DATASET_ID, split="train")
41
  print(f"Rows: {len(ds)}, Columns: {ds.column_names}")
42
 
43
- # Group rows by conversation_id and sort by date_time
44
  print("Grouping conversations...")
45
  conversations = defaultdict(list)
46
  for row in ds:
@@ -54,7 +51,6 @@ for row in ds:
54
  for conv_id in conversations:
55
  conversations[conv_id].sort(key=lambda x: x["date_time"])
56
 
57
- # Convert each conversation into messages format with merged consecutive same-role turns
58
  print("Converting to messages format...")
59
  messages_data = []
60
  for conv_id, turns in conversations.items():
@@ -107,7 +103,7 @@ if tokenizer.pad_token is None:
107
  tokenizer.pad_token_id = tokenizer.eos_token_id
108
 
109
  # ------------------------------------------------------------------
110
- # Pre-tokenize dataset with truncation (avoids SFTTrainer auto-tokenization)
111
  # ------------------------------------------------------------------
112
  print("Pre-tokenizing dataset...")
113
 
@@ -125,11 +121,10 @@ def apply_and_tokenize(example):
125
  raw_dataset = Dataset.from_list(messages_data)
126
  raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
127
  raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
128
-
129
  print(f"Dataset after filtering: {len(raw_dataset)}")
130
 
131
  # ------------------------------------------------------------------
132
- # Model
133
  # ------------------------------------------------------------------
134
  print("Loading model...")
135
  model = AutoModelForCausalLM.from_pretrained(
@@ -140,6 +135,11 @@ model = AutoModelForCausalLM.from_pretrained(
140
 
141
  model.gradient_checkpointing_enable()
142
 
 
 
 
 
 
143
  # ------------------------------------------------------------------
144
  # Training arguments
145
  # ------------------------------------------------------------------
@@ -149,7 +149,7 @@ args = SFTConfig(
149
  push_to_hub=True,
150
  num_train_epochs=1,
151
  per_device_train_batch_size=1,
152
- gradient_accumulation_steps=4,
153
  learning_rate=2e-5,
154
  max_seq_length=MAX_SEQ_LENGTH,
155
  logging_strategy="steps",
 
1
  """
2
+ Full fine-tuning script with aggressive memory optimizations:
3
  Model: google/gemma-2-2b-it
4
  Dataset: talkmap/telecom-conversation-corpus
 
 
5
  """
6
  import os
7
  from collections import defaultdict
 
38
  ds = load_dataset(DATASET_ID, split="train")
39
  print(f"Rows: {len(ds)}, Columns: {ds.column_names}")
40
 
 
41
  print("Grouping conversations...")
42
  conversations = defaultdict(list)
43
  for row in ds:
 
51
  for conv_id in conversations:
52
  conversations[conv_id].sort(key=lambda x: x["date_time"])
53
 
 
54
  print("Converting to messages format...")
55
  messages_data = []
56
  for conv_id, turns in conversations.items():
 
103
  tokenizer.pad_token_id = tokenizer.eos_token_id
104
 
105
  # ------------------------------------------------------------------
106
+ # Pre-tokenize dataset with truncation
107
  # ------------------------------------------------------------------
108
  print("Pre-tokenizing dataset...")
109
 
 
121
  raw_dataset = Dataset.from_list(messages_data)
122
  raw_dataset = raw_dataset.map(apply_and_tokenize, remove_columns=["messages"])
123
  raw_dataset = raw_dataset.filter(lambda x: len(x["text"]) > 0)
 
124
  print(f"Dataset after filtering: {len(raw_dataset)}")
125
 
126
  # ------------------------------------------------------------------
127
+ # Model - load on CPU first to control placement
128
  # ------------------------------------------------------------------
129
  print("Loading model...")
130
  model = AutoModelForCausalLM.from_pretrained(
 
135
 
136
  model.gradient_checkpointing_enable()
137
 
138
+ # Empty cache to free up fragmentation
139
+ import gc
140
+ gc.collect()
141
+ torch.cuda.empty_cache()
142
+
143
  # ------------------------------------------------------------------
144
  # Training arguments
145
  # ------------------------------------------------------------------
 
149
  push_to_hub=True,
150
  num_train_epochs=1,
151
  per_device_train_batch_size=1,
152
+ gradient_accumulation_steps=8,
153
  learning_rate=2e-5,
154
  max_seq_length=MAX_SEQ_LENGTH,
155
  logging_strategy="steps",