dejanseo committed
Commit 80c7523 · verified · 1 Parent(s): c7f53f1

Upload train_stage_2.py

Files changed (1):
  train_stage_2.py  +163 -0  (ADDED)
#!/usr/bin/env python3
import os
import csv
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup
)
from peft import PeftModel
from torch.cuda.amp import autocast, GradScaler
from tqdm.auto import tqdm
from multiprocessing import freeze_support

class TripletDataset(Dataset):
    """Loads pre-tokenized (anchor, positive, negative) triplets from a CSV.

    Each *_ids / *_mask column holds space-separated integers; every row must
    be padded to the same length so collate_fn can stack the tensors."""
    def __init__(self, path):
        self.samples = []
        with open(path, newline="") as f:
            reader = csv.DictReader(f)
            for row in reader:
                a_ids = torch.tensor(list(map(int, row["a_ids"].split())), dtype=torch.long)
                a_mask = torch.tensor(list(map(int, row["a_mask"].split())), dtype=torch.long)
                p_ids = torch.tensor(list(map(int, row["p_ids"].split())), dtype=torch.long)
                p_mask = torch.tensor(list(map(int, row["p_mask"].split())), dtype=torch.long)
                n_ids = torch.tensor(list(map(int, row["n_ids"].split())), dtype=torch.long)
                n_mask = torch.tensor(list(map(int, row["n_mask"].split())), dtype=torch.long)
                self.samples.append((a_ids, a_mask, p_ids, p_mask, n_ids, n_mask))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

class GemmaTripletModel(nn.Module):
    """Gemma backbone (with Stage 1 LoRA) plus a small projection head that
    maps the pooled hidden state to a 256-d L2-normalized embedding."""
    def __init__(self, peft_model):
        super().__init__()
        self.peft = peft_model
        H = peft_model.config.hidden_size
        self.proj = nn.Sequential(
            nn.Linear(H, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
        )

    def forward(self, ids, mask):
        out = self.peft.base_model(  # PeftModel exposes the wrapped model as .base_model
            input_ids=ids,
            attention_mask=mask,
            output_hidden_states=True,
            return_dict=True
        )
        last = out.hidden_states[-1]  # (B, T, H)
        # Masked mean pooling: exclude padding tokens from the average.
        mask_f = mask.unsqueeze(-1).to(last.dtype)  # (B, T, 1)
        pooled = (last * mask_f).sum(dim=1) / mask_f.sum(dim=1).clamp_min(1e-6)
        z = self.proj(pooled)
        norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
        return z / norm

def collate_fn(batch):
    # Stack the six per-sample tensors into six (B, T) batch tensors.
    return tuple(torch.stack(x) for x in zip(*batch))

def main():
    # --- Config ---
    MODEL_NAME = "google/gemma-3-1b-pt"
    STAGE1_DIR = "stage1_simcse/final"
    TRAIN_FILE = "train.csv"
    VAL_FILE = "val.csv"
    BATCH_SIZE = 12
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 3
    MARGIN = 0.2
    OUTPUT_DIR = "phase2_triplet_amp"
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)

    # --- Tokenizer & PEFT Model (load Stage 1) ---
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
    # Load the Stage 1 LoRA weights; is_trainable=True keeps the adapter
    # updatable (from_pretrained loads adapters in inference mode by default).
    peft_model = PeftModel.from_pretrained(base, STAGE1_DIR, is_trainable=True)

    # --- Embed + Projector ---
    model = GemmaTripletModel(peft_model).to(device)

    # --- Datasets & Loaders ---
    train_ds = TripletDataset(TRAIN_FILE)
    val_ds = TripletDataset(VAL_FILE)
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

    # --- Optimizer, Scheduler, AMP ---
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(train_loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
    scaler = GradScaler()

    # --- Loss ---
    triplet_loss = nn.TripletMarginLoss(margin=MARGIN, p=2)

    # --- Training Loop ---
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        running_loss = 0.0
        for a_ids, a_mask, p_ids, p_mask, n_ids, n_mask in tqdm(train_loader, desc=f"Train {epoch}", unit="batch"):
            a_ids, a_mask = a_ids.to(device), a_mask.to(device)
            p_ids, p_mask = p_ids.to(device), p_mask.to(device)
            n_ids, n_mask = n_ids.to(device), n_mask.to(device)

            optimizer.zero_grad()
            with autocast():
                emb_a = model(a_ids, a_mask)
                emb_p = model(p_ids, p_mask)
                emb_n = model(n_ids, n_mask)
                loss = triplet_loss(emb_a, emb_p, emb_n)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            running_loss += loss.item()

        print(f"Epoch {epoch} Train Loss: {running_loss/len(train_loader):.6f}")

        # --- Validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for a_ids, a_mask, p_ids, p_mask, n_ids, n_mask in tqdm(val_loader, desc=f"Val {epoch}", unit="batch"):
                a_ids, a_mask = a_ids.to(device), a_mask.to(device)
                p_ids, p_mask = p_ids.to(device), p_mask.to(device)
                n_ids, n_mask = n_ids.to(device), n_mask.to(device)
                with autocast():
                    emb_a = model(a_ids, a_mask)
                    emb_p = model(p_ids, p_mask)
                    emb_n = model(n_ids, n_mask)
                val_loss += triplet_loss(emb_a, emb_p, emb_n).item()

        print(f"Epoch {epoch} Val Loss: {val_loss/len(val_loader):.6f}")

        # --- Checkpoint LoRA adapter, tokenizer, and projection head ---
        ckpt_dir = os.path.join(OUTPUT_DIR, f"epoch{epoch}")
        peft_model.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
        # The projection head is not part of the adapter, so save it separately.
        torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "proj_head.pt"))

    # --- Final Save ---
    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    peft_model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    torch.save(model.proj.state_dict(), os.path.join(final_dir, "proj_head.pt"))
    print("Phase 2 complete. Checkpoints in", OUTPUT_DIR)

if __name__ == "__main__":
    freeze_support()
    main()
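
The script assumes train.csv and val.csv already contain pre-tokenized, uniformly padded triplets: columns a_ids, a_mask, p_ids, p_mask, n_ids, n_mask, each holding space-separated integers of the same length so collate_fn can stack them. A minimal sketch of how such a file could be produced with the same tokenizer follows; the MAX_LEN value and the example triplet texts are illustrative assumptions, not part of this commit.

#!/usr/bin/env python3
# Hypothetical helper: writes (anchor, positive, negative) text triplets into
# the CSV layout train_stage_2.py expects. MAX_LEN and the sample texts are
# assumptions for illustration only.
import csv
from transformers import AutoTokenizer

MAX_LEN = 128  # assumed fixed length; every row must pad to the same T
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-pt", use_fast=True)

def encode(text):
    enc = tokenizer(text, truncation=True, padding="max_length", max_length=MAX_LEN)
    return (" ".join(map(str, enc["input_ids"])),
            " ".join(map(str, enc["attention_mask"])))

triplets = [  # placeholder (anchor, positive, negative) examples
    ("how to bake bread", "simple bread recipe", "car engine repair"),
]

with open("train.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["a_ids", "a_mask", "p_ids", "p_mask", "n_ids", "n_mask"])
    writer.writeheader()
    for a, p, n in triplets:
        a_ids, a_mask = encode(a)
        p_ids, p_mask = encode(p)
        n_ids, n_mask = encode(n)
        writer.writerow({"a_ids": a_ids, "a_mask": a_mask,
                         "p_ids": p_ids, "p_mask": p_mask,
                         "n_ids": n_ids, "n_mask": n_mask})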
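
For completeness, a sketch of how the saved artifacts could be reloaded for inference. This assumes the script is importable as train_stage_2 and that the proj_head.pt file matches the checkpointing code above; neither is confirmed by the commit itself.

# Hypothetical inference sketch: reload the final LoRA adapter plus projection
# head and embed a query with the same model class used for training.
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from train_stage_2 import GemmaTripletModel  # assumes this file is on the path

final_dir = "phase2_triplet_amp/final"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(final_dir)
base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt", attn_implementation="eager")
peft_model = PeftModel.from_pretrained(base, final_dir)  # inference mode by default

model = GemmaTripletModel(peft_model).to(device)
model.proj.load_state_dict(torch.load(os.path.join(final_dir, "proj_head.pt"), map_location=device))
model.eval()

enc = tokenizer("example query", return_tensors="pt", truncation=True).to(device)
with torch.no_grad():
    emb = model(enc["input_ids"], enc["attention_mask"])  # (1, 256), L2-normalized
print(emb.shape)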