kgrabko committed
Commit 968b1ff · verified · 1 Parent(s): cf8d4e9

Upload fine_tune_jit_with_validation_1b.py

Files changed (1)
  1. fine_tune_jit_with_validation_1b.py +48 -53

fine_tune_jit_with_validation_1b.py CHANGED
@@ -1,14 +1,8 @@
- # Copyright (c) 2025 CMS Manhattan
- # All rights reserved.
- #
- # This file is part of a project authored by CMS Manhattan. You may use, distribute, and modify
- # this code under the terms of the APACHE 2.0 license.
-
  import os
  import torch
  import torch.nn as nn
  import torch.optim as optim
- from torch.utils.data import Dataset, DataLoader
+ from torch.utils.data import IterableDataset, DataLoader
  from transformers import GPT2TokenizerFast
  from tqdm import tqdm
  import shutil
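
The import change above swaps the map-style Dataset base class for IterableDataset. As a minimal sketch (not part of the commit), the two base classes differ in how items are served: a map-style dataset exposes __len__ and __getitem__, while an iterable dataset streams items from __iter__:

import torch
from torch.utils.data import Dataset, IterableDataset

class MapStyle(Dataset):
    """Old style: random access by index via __getitem__."""
    def __init__(self):
        self.items = list(range(5))
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        return torch.tensor(self.items[idx])

class IterStyle(IterableDataset):
    """New style: items are produced on demand by __iter__."""
    def __iter__(self):
        for x in range(5):
            yield torch.tensor(x)

print([MapStyle()[i].item() for i in range(5)])  # [0, 1, 2, 3, 4]
print([t.item() for t in IterStyle()])           # [0, 1, 2, 3, 4]
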
@@ -66,55 +60,53 @@ DATASET_PATH = CLEAN_PATH
  OUTPUT_DIR = Path("build/fine_tuning_output")
  MODEL_SAVE_NAME = "gpt_finetuned.script.pt"

- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ device = torch.device("cpu")
  print(f"Using device: {device}")

- # ============================= DATASET =============================
+ # ============================= DATASET (LAZY) =============================

- class TextDataset(Dataset):
+ class LazyTextDataset(IterableDataset):
+     """Lazy memory-efficient dataset, splits on-the-fly into train and val."""
      def __init__(self, text_file, seq_len=TRAIN_SEQ_LEN, tokenizer_name="gpt2", split_type='train', val_ratio=VAL_SPLIT_RATIO):
          self.seq_len = seq_len
          self.tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_name)
          self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.text_file = text_file
          self.split_type = split_type
-
-         print(f"Loading text from {text_file} for {split_type} split...")
-         text = Path(text_file).read_text(encoding="utf-8")
-         tokens = self.tokenizer.encode(text)
-
-         if len(tokens) < seq_len * 2:
-             raise ValueError("Text too short!")
-
-         all_inputs = []
-         all_labels = []
-
-         for i in range(0, len(tokens) - seq_len, seq_len):
-             all_inputs.append(tokens[i:i + seq_len])
-             all_labels.append(tokens[i + 1:i + seq_len + 1])
-
-         total_sequences = len(all_inputs)
-         val_size = int(total_sequences * val_ratio)
-         train_size = total_sequences - val_size
-
-         if self.split_type == 'train':
-             self.inputs = all_inputs[:train_size]
-             self.labels = all_labels[:train_size]
-         elif self.split_type == 'val':
-             self.inputs = all_inputs[train_size:]
-             self.labels = all_labels[train_size:]
+         self.val_ratio = val_ratio
+
+         print(f"Loading and tokenizing text from {text_file}")
+         with open(text_file, "r", encoding="utf-8") as f:
+             self.data = f.read()
+         self.tokens = self.tokenizer.encode(self.data)
+
+         # Work out split indices
+         total_tokens = len(self.tokens) - 1 # because label sequence shifted
+         total_batches = total_tokens // seq_len
+         val_size = int(total_batches * self.val_ratio)
+         train_size = total_batches - val_size
+         if split_type == 'train':
+             self.start = 0
+             self.stop = train_size
+         elif split_type == 'val':
+             self.start = train_size
+             self.stop = train_size + val_size
          else:
-             raise ValueError("Invalid split_type. Must be 'train' or 'val'.")
-
-         print(f"Created {len(self.inputs):,} sequences for {self.split_type} split.")
+             raise ValueError(f"split_type should be 'train' or 'val', got {split_type}")
+         self.total_sequences = self.stop - self.start
+         print(f"Lazy dataset: {self.total_sequences:,} sequences for {split_type} split (from {total_batches:,} total)")
+
+     def __iter__(self):
+         for i in range(self.start * self.seq_len, self.stop * self.seq_len, self.seq_len):
+             # Make sure last batch fits
+             if i + self.seq_len + 1 > len(self.tokens):
+                 break
+             input_seq = torch.tensor(self.tokens[i : i + self.seq_len], dtype=torch.long)
+             label_seq = torch.tensor(self.tokens[i + 1 : i + self.seq_len + 1], dtype=torch.long)
+             yield input_seq, label_seq

      def __len__(self):
-         return len(self.inputs)
-
-     def __getitem__(self, idx):
-         return (
-             torch.tensor(self.inputs[idx], dtype=torch.long),
-             torch.tensor(self.labels[idx], dtype=torch.long)
-         )
+         return self.total_sequences

  # ============================= GET LOGITS UTIL =============================
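
To make the sliding-window behaviour of the new LazyTextDataset concrete, here is a small illustrative sketch (hypothetical toy tokens instead of GPT-2 tokenizer output, not part of the commit): each yielded pair is a seq_len-long window and the same window shifted right by one token.

import torch
from torch.utils.data import IterableDataset

class ToyWindowDataset(IterableDataset):
    """Toy stand-in for LazyTextDataset: yields (input, label) windows over a token list."""
    def __init__(self, tokens, seq_len):
        self.tokens = tokens
        self.seq_len = seq_len

    def __iter__(self):
        for i in range(0, len(self.tokens) - self.seq_len, self.seq_len):
            if i + self.seq_len + 1 > len(self.tokens):
                break  # the label window would run past the end
            yield (torch.tensor(self.tokens[i:i + self.seq_len], dtype=torch.long),
                   torch.tensor(self.tokens[i + 1:i + self.seq_len + 1], dtype=torch.long))

for x, y in ToyWindowDataset(list(range(13)), seq_len=4):
    print(x.tolist(), "->", y.tolist())
# [0, 1, 2, 3] -> [1, 2, 3, 4]
# [4, 5, 6, 7] -> [5, 6, 7, 8]
# [8, 9, 10, 11] -> [9, 10, 11, 12]
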
@@ -130,6 +122,7 @@ def get_logits_from_model(model, inputs):
  def evaluate(model, dataloader, criterion, device):
      model.eval()
      total_loss = 0.0
+     count = 0
      with torch.no_grad():
          for inputs, targets in dataloader:
              inputs, targets = inputs.to(device), targets.to(device)
@@ -138,7 +131,8 @@ def evaluate(model, dataloader, criterion, device):
              targets = targets.contiguous().view(-1)[:logits.shape[0]]
              loss = criterion(logits, targets)
              total_loss += loss.item()
-     avg_loss = total_loss / len(dataloader)
+             count += 1
+     avg_loss = total_loss / max(count, 1)
      model.train()
      return avg_loss
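
The logits/target alignment in evaluate() appears to assume the JIT model returns logits already flattened to (batch * seq_len, vocab_size); the shapes below are illustrative assumptions, since get_logits_from_model is not shown in this diff. The new count variable then simply records how many per-batch losses were accumulated before averaging.

import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 4, 11                  # assumed toy shapes
logits = torch.randn(batch * seq_len, vocab)      # (N, C), as CrossEntropyLoss expects
targets = torch.randint(vocab, (batch, seq_len))  # (batch, seq_len) class indices

criterion = nn.CrossEntropyLoss()
flat_targets = targets.contiguous().view(-1)[:logits.shape[0]]  # same trimming as evaluate()
loss = criterion(logits, flat_targets)
print(loss.item())  # a single scalar, added to total_loss
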
@@ -185,16 +179,17 @@ def train():
      except AttributeError:
          print("⚠️ Warning: model.gradient_checkpointing_enable() not found on JIT model. Training will proceed without GC.")

-     train_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
-     val_dataset = TextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)
+     train_dataset = LazyTextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='train', val_ratio=VAL_SPLIT_RATIO)
+     val_dataset = LazyTextDataset(DATASET_PATH, seq_len=TRAIN_SEQ_LEN, split_type='val', val_ratio=VAL_SPLIT_RATIO)

-     train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
-     val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
+     # IterableDataset: must use drop_last=True and shuffle=False, num_workers=0 on CPU
+     train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)
+     val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True, num_workers=0)

      optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
      criterion = nn.CrossEntropyLoss()

-     total_steps = len(train_dataloader) * EPOCHS
+     total_steps = (len(train_dataset) // BATCH_SIZE) * EPOCHS
      print(f"\n=== BEGINNING LONG-TERM TRAINING ===")
      print(f"Epochs: {EPOCHS} | Steps (Train): {total_steps} | Examples (Train): {len(train_dataset)}")
      print(f"Batch Size (Effective): {BATCH_SIZE} | Precision: FP32")
@@ -223,10 +218,10 @@ def train():
              pbar.set_postfix({
                  "loss": f"{loss_val:.3f}",
                  "ppl": f"{math.exp(min(loss_val, 10)):.1f}",
-                 "step": f"{global_step}/{total_steps}"
+                 "step": f"{global_step}"
              })

-         avg_train_loss = epoch_loss / len(train_dataloader)
+         avg_train_loss = epoch_loss / max(1, len(train_dataset) // BATCH_SIZE)
          print(f" [TRAIN] Average loss: {avg_train_loss:.3f} | PPL: {math.exp(avg_train_loss):.1f}")

          print(" [VALIDATION] Starting evaluation...")