dejanseo committed on
Commit
9b9f13c
·
verified ·
1 Parent(s): d7edc32

Upload train3.py

Browse files
Files changed (1) hide show
  1. train3.py +95 -14
train3.py CHANGED
@@ -7,43 +7,66 @@ import sys
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
 
10
  from torch.utils.data import TensorDataset, DataLoader
11
  from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
12
  from peft import PeftModel
13
  from torch.cuda.amp import GradScaler, autocast
14
  from tqdm.auto import tqdm
15
  from multiprocessing import freeze_support
 
 
16
 
17
  def main():
18
  # --- Config ---
19
  PRET_FILE = "pretokenized_queries.pt"
20
  MODEL_NAME = "google/gemma-3-1b-pt"
21
- LORA_DIR = "phase2_triplet_amp/final"
22
- BATCH_SIZE = 64
23
  LR = 1e-5
24
  WEIGHT_DECAY = 0.01
25
- NUM_EPOCHS = 1
26
  TEMP = 0.05
27
- OUTPUT_DIR = "phase3_self_contrast"
28
  GRAD_CLIP_NORM = 1.0
29
  SEED = 42
 
 
 
 
 
30
 
31
  os.makedirs(OUTPUT_DIR, exist_ok=True)
32
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
  torch.manual_seed(SEED)
34
 
 
 
 
 
 
 
 
 
 
 
 
35
  # --- Load pretokenized queries safely ---
 
36
  data = torch.load(PRET_FILE, weights_only=True)
37
  input_ids = data["input_ids"]
38
  attention_mask = data["attention_mask"]
39
  dataset = TensorDataset(input_ids, attention_mask)
40
  loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
 
41
 
42
- # --- Load base model + LoRA adapters ---
 
43
  base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
44
  peft = PeftModel.from_pretrained(base, LORA_DIR).to(device)
 
45
 
46
- # --- Projection head now outputs hidden_size instead of 256 ---
47
  class GemmaSelfContrast(nn.Module):
48
  def __init__(self, peft_model):
49
  super().__init__()
@@ -69,6 +92,9 @@ def main():
69
  return z / norm
70
 
71
  model = GemmaSelfContrast(peft).to(device)
 
 
 
72
 
73
  # --- Optimizer, scheduler, AMP scaler ---
74
  optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
@@ -79,12 +105,18 @@ def main():
79
  num_training_steps=total_steps
80
  )
81
  scaler = GradScaler()
 
 
 
 
82
 
83
  # --- Training loop ---
84
  model.train()
 
85
  for epoch in range(1, NUM_EPOCHS + 1):
86
  total_loss = 0.0
87
- for ids, mask in tqdm(loader, desc=f"Epoch {epoch}", unit="batch"):
 
88
  ids, mask = ids.to(device), mask.to(device)
89
 
90
  with autocast():
@@ -105,23 +137,72 @@ def main():
105
 
106
  optimizer.zero_grad()
107
  scaler.scale(loss).backward()
108
- scaler.unscale_(optimizer)
109
  torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
110
  scaler.step(optimizer)
111
  scaler.update()
112
  scheduler.step()
113
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  total_loss += loss.item()
115
 
116
  avg_loss = total_loss / len(loader)
117
- print(f"Epoch {epoch} avg loss: {avg_loss:.6f}")
118
-
119
- # --- Save only LoRA adapters ---
 
 
 
 
 
120
  final_dir = os.path.join(OUTPUT_DIR, "final")
121
  os.makedirs(final_dir, exist_ok=True)
 
 
122
  peft.save_pretrained(final_dir)
123
- print("Phase 3 complete. LoRA adapters saved to", final_dir)
 
 
 
 
 
 
 
 
124
 
125
  if __name__ == "__main__":
126
  freeze_support()
127
- main()
 
7
  import torch
8
  import torch.nn as nn
9
  import torch.nn.functional as F
10
+ import wandb
11
  from torch.utils.data import TensorDataset, DataLoader
12
  from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
13
  from peft import PeftModel
14
  from torch.cuda.amp import GradScaler, autocast
15
  from tqdm.auto import tqdm
16
  from multiprocessing import freeze_support
17
+ import shutil # Import shutil for removing old checkpoints
18
+ import collections # Import collections for deque
19
 
20
  def main():
21
  # --- Config ---
22
  PRET_FILE = "pretokenized_queries.pt"
23
  MODEL_NAME = "google/gemma-3-1b-pt"
24
+ LORA_DIR = "phase2_triplet_amp/final" # Adapters from previous stage
25
+ BATCH_SIZE = 200
26
  LR = 1e-5
27
  WEIGHT_DECAY = 0.01
28
+ NUM_EPOCHS = 1 # As per our discussion, 1 epoch is likely sufficient given fast convergence
29
  TEMP = 0.05
30
+ OUTPUT_DIR = "phase3_self_contrast_wandb"
31
  GRAD_CLIP_NORM = 1.0
32
  SEED = 42
33
+ WANDB_PROJECT = "query-encoder-phase3"
34
+
35
+ # --- Checkpointing Configuration ---
36
+ SAVE_INTERVAL = 1000 # Save a checkpoint every N steps
37
+ KEEP_LAST_CKPTS = 5 # Keep only the last N checkpoints (to save disk space)
38
 
39
  os.makedirs(OUTPUT_DIR, exist_ok=True)
40
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41
  torch.manual_seed(SEED)
42
 
43
+ # --- Initialize WandB ---
44
+ wandb.init(
45
+ project=WANDB_PROJECT,
46
+ config={
47
+ "model_name": MODEL_NAME, "lora_dir": LORA_DIR, "batch_size": BATCH_SIZE,
48
+ "lr": LR, "num_epochs": NUM_EPOCHS, "seed": SEED,
49
+ "save_interval_steps": SAVE_INTERVAL,
50
+ "keep_last_checkpoints": KEEP_LAST_CKPTS,
51
+ }
52
+ )
53
+
54
  # --- Load pretokenized queries safely ---
55
+ print(f"Loading pretokenized queries from {PRET_FILE}...")
56
  data = torch.load(PRET_FILE, weights_only=True)
57
  input_ids = data["input_ids"]
58
  attention_mask = data["attention_mask"]
59
  dataset = TensorDataset(input_ids, attention_mask)
60
  loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
61
+ print(f"Loaded {len(dataset)} samples.")
62
 
63
+ # --- Load base model + LoRA adapters from previous stage ---
64
+ print(f"Loading base model '{MODEL_NAME}' and LoRA adapters from '{LORA_DIR}'...")
65
  base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
66
  peft = PeftModel.from_pretrained(base, LORA_DIR).to(device)
67
+ print("LoRA adapters loaded.")
68
 
69
+ # --- Projection head now outputs hidden_size ---
70
  class GemmaSelfContrast(nn.Module):
71
  def __init__(self, peft_model):
72
  super().__init__()
 
92
  return z / norm
93
 
94
  model = GemmaSelfContrast(peft).to(device)
95
+ print("Encoder model (with projection head) initialized.")
96
+ # Watch the model with wandb (optional, can be slow, but good for tracking gradients)
97
+ # wandb.watch(model, log="all", log_freq=100) # Commented out due to potential slowdown
98
 
99
  # --- Optimizer, scheduler, AMP scaler ---
100
  optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
 
105
  num_training_steps=total_steps
106
  )
107
  scaler = GradScaler()
108
+ print(f"Training will run for {total_steps} steps.")
109
+
110
+ # Deque to manage checkpoint paths and enforce keeping only the last N
111
+ checkpoint_paths = collections.deque(maxlen=KEEP_LAST_CKPTS)
112
 
113
  # --- Training loop ---
114
  model.train()
115
+ global_step = 0
116
  for epoch in range(1, NUM_EPOCHS + 1):
117
  total_loss = 0.0
118
+ pbar = tqdm(loader, desc=f"Epoch {epoch}", unit="batch")
119
+ for ids, mask in pbar:
120
  ids, mask = ids.to(device), mask.to(device)
121
 
122
  with autocast():
 
137
 
138
  optimizer.zero_grad()
139
  scaler.scale(loss).backward()
140
+ scaler.unscale_(optimizer) # Unscale gradients before clipping
141
  torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
142
  scaler.step(optimizer)
143
  scaler.update()
144
  scheduler.step()
145
+
146
+ # --- Log metrics to WandB at every step ---
147
+ wandb.log({
148
+ "train/loss": loss.item(),
149
+ "train/lr": scheduler.get_last_lr()[0],
150
+ "train/epoch": epoch,
151
+ "train/global_step": global_step
152
+ }, step=global_step)
153
+
154
+ pbar.set_postfix({"loss": f"{loss.item():.4f}"})
155
+
156
+ # --- PERIODIC SAVING BLOCK ---
157
+ # Save checkpoint every SAVE_INTERVAL steps
158
+ if (global_step + 1) % SAVE_INTERVAL == 0:
159
+ # Create a unique directory for this checkpoint
160
+ ckpt_dir = os.path.join(OUTPUT_DIR, f"checkpoint-step-{global_step + 1}")
161
+ os.makedirs(ckpt_dir, exist_ok=True)
162
+
163
+ print(f"\nSaving checkpoint to {ckpt_dir}...")
164
+ # Save the PEFT adapters
165
+ peft.save_pretrained(ckpt_dir)
166
+ # Save the trained projection head's state dictionary
167
+ torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "encoder_proj.pth"))
168
+
169
+ # Manage old checkpoints
170
+ if len(checkpoint_paths) == KEEP_LAST_CKPTS:
171
+ oldest_ckpt = checkpoint_paths.popleft() # Remove the oldest path from deque
172
+ if os.path.isdir(oldest_ckpt):
173
+ print(f"Removing old checkpoint: {oldest_ckpt}")
174
+ shutil.rmtree(oldest_ckpt, ignore_errors=True) # Delete the directory
175
+ checkpoint_paths.append(ckpt_dir) # Add new checkpoint path
176
+ print("Checkpoint saved and old ones managed.")
177
+ # --- END PERIODIC SAVING ---
178
+
179
+ global_step += 1
180
  total_loss += loss.item()
181
 
182
  avg_loss = total_loss / len(loader)
183
+ print(f"Epoch {epoch} training complete. Avg loss: {avg_loss:.6f}")
184
+ # Log average epoch loss as well
185
+ wandb.log({"train/epoch_avg_loss": avg_loss, "epoch": epoch}, step=global_step)
186
+
187
+ # --- Final Save for the "final" directory ---
188
+ # This ensures that even if you stop mid-epoch (after a checkpoint)
189
+ # or don't stop, there's always a clear 'final' model.
190
+ print("\nTraining finished. Saving final model to 'final' directory...")
191
  final_dir = os.path.join(OUTPUT_DIR, "final")
192
  os.makedirs(final_dir, exist_ok=True)
193
+
194
+ # Save the LoRA adapters
195
  peft.save_pretrained(final_dir)
196
+
197
+ # Save the trained projection head's state dictionary
198
+ torch.save(model.proj.state_dict(), os.path.join(final_dir, "encoder_proj.pth"))
199
+
200
+ print(f"Phase 3 complete. LoRA adapters and projection head saved to {final_dir}")
201
+
202
+ # --- Finalize WandB run ---
203
+ wandb.finish()
204
+
205
 
206
  if __name__ == "__main__":
207
  freeze_support()
208
+ main()