import os

# Silence the Hugging Face tokenizers parallelism warning when worker processes fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
import shutil
import collections
from multiprocessing import freeze_support

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from torch.utils.data import TensorDataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
from transformers import AutoModelForCausalLM, get_linear_schedule_with_warmup
from peft import PeftModel
from tqdm.auto import tqdm

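# Phase 3 of the query-encoder pipeline: self-contrastive fine-tuning of the
# Phase 2 LoRA adapters plus a small projection head, with mixed precision and
# Weights & Biases logging.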
def main():
    # --- Configuration ---
    PRET_FILE = "pretokenized_queries.pt"
    MODEL_NAME = "google/gemma-3-1b-pt"
    LORA_DIR = "phase2_triplet_amp/final"
    BATCH_SIZE = 200
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    TEMP = 0.05
    OUTPUT_DIR = "phase3_self_contrast_wandb"
    GRAD_CLIP_NORM = 1.0
    SEED = 42
    WANDB_PROJECT = "query-encoder-phase3"

    # Checkpointing: save every SAVE_INTERVAL steps and keep only the newest KEEP_LAST_CKPTS.
    SAVE_INTERVAL = 1000
    KEEP_LAST_CKPTS = 5

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.manual_seed(SEED)

    wandb.init(
        project=WANDB_PROJECT,
        config={
            "model_name": MODEL_NAME, "lora_dir": LORA_DIR, "batch_size": BATCH_SIZE,
            "lr": LR, "num_epochs": NUM_EPOCHS, "seed": SEED,
            "save_interval_steps": SAVE_INTERVAL,
            "keep_last_checkpoints": KEEP_LAST_CKPTS,
        },
    )

print(f"Loading pretokenized queries from {PRET_FILE}...")
|
|
|
data = torch.load(PRET_FILE, weights_only=True)
|
|
|
input_ids = data["input_ids"]
|
|
|
attention_mask = data["attention_mask"]
|
|
|
dataset = TensorDataset(input_ids, attention_mask)
|
|
|
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
|
|
|
print(f"Loaded {len(dataset)} samples.")
|
|
|
|
|
|
|
|
|
print(f"Loading base model '{MODEL_NAME}' and LoRA adapters from '{LORA_DIR}'...")
|
|
|
base = AutoModelForCausalLM.from_pretrained(MODEL_NAME, attn_implementation="eager")
|
|
|
peft = PeftModel.from_pretrained(base, LORA_DIR).to(device)
|
|
|
print("LoRA adapters loaded.")
|
|
|
|
|
|
|
|
|
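    # Encoder wrapper: mean-pools the last hidden layer over the sequence, maps it
    # through a small MLP projection head, and L2-normalizes the result. The
    # nan_to_num calls and the clamped norm guard against NaN/inf values under fp16.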
    class GemmaSelfContrast(nn.Module):
        def __init__(self, peft_model):
            super().__init__()
            self.peft = peft_model
            hs = peft_model.base_model.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, 512),
                nn.ReLU(),
                nn.Linear(512, hs),
            )

        def forward(self, ids, mask):
            out = self.peft.base_model(
                input_ids=ids,
                attention_mask=mask,
                output_hidden_states=True,
                return_dict=True,
            )
            h = out.hidden_states[-1].mean(dim=1)
            h = torch.nan_to_num(h, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(h)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSelfContrast(peft).to(device)
    print("Encoder model (with projection head) initialized.")

    # Optimizer, linear LR schedule with 10% warmup, and AMP gradient scaler.
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = len(loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )
    scaler = GradScaler()
    print(f"Training will run for {total_steps} steps.")

    # Rolling window of recent checkpoint directories, used for cleanup below.
    checkpoint_paths = collections.deque(maxlen=KEEP_LAST_CKPTS)

    # --- Training loop ---
    model.train()
    global_step = 0
    for epoch in range(1, NUM_EPOCHS + 1):
        total_loss = 0.0
        pbar = tqdm(loader, desc=f"Epoch {epoch}", unit="batch")
        for ids, mask in pbar:
            ids, mask = ids.to(device), mask.to(device)

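            # Self-contrast: encode the same batch twice; the two views differ only
            # through stochastic layers such as dropout (SimCSE-style). Each view's
            # positive is its counterpart; every other row in the 2B batch is a negative.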
            with autocast():
                e1 = model(ids, mask)
                e2 = model(ids, mask)
                emb = torch.cat([e1, e2], dim=0)   # (2B, hidden)
                sim = (emb @ emb.T) / TEMP         # temperature-scaled cosine similarity

                # Exclude self-similarity from the softmax.
                mask_eye = torch.eye(sim.size(0), device=device, dtype=torch.bool)
                sim = sim.masked_fill(mask_eye, float('-inf'))

                # Row i's positive sits at index i + B (and vice versa): NT-Xent loss.
                B = e1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device),
                ], dim=0)
                loss = F.cross_entropy(sim, labels)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)  # unscale before clipping so the threshold applies to true gradient norms
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            wandb.log({
                "train/loss": loss.item(),
                "train/lr": scheduler.get_last_lr()[0],
                "train/epoch": epoch,
                "train/global_step": global_step,
            }, step=global_step)
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

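            # Every SAVE_INTERVAL optimizer steps, snapshot the LoRA adapters and the
            # projection head, pruning checkpoints beyond the retention limit.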
            if (global_step + 1) % SAVE_INTERVAL == 0:
                ckpt_dir = os.path.join(OUTPUT_DIR, f"checkpoint-step-{global_step + 1}")
                os.makedirs(ckpt_dir, exist_ok=True)
                print(f"\nSaving checkpoint to {ckpt_dir}...")

                # LoRA adapter weights and the projection head are saved separately.
                peft.save_pretrained(ckpt_dir)
                torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "encoder_proj.pth"))

                # Keep only the newest KEEP_LAST_CKPTS checkpoint directories on disk.
                if len(checkpoint_paths) == KEEP_LAST_CKPTS:
                    oldest_ckpt = checkpoint_paths.popleft()
                    if os.path.isdir(oldest_ckpt):
                        print(f"Removing old checkpoint: {oldest_ckpt}")
                        shutil.rmtree(oldest_ckpt, ignore_errors=True)
                checkpoint_paths.append(ckpt_dir)
                print("Checkpoint saved; older checkpoints pruned.")

            global_step += 1
            total_loss += loss.item()

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch} training complete. Avg loss: {avg_loss:.6f}")
        wandb.log({"train/epoch_avg_loss": avg_loss, "epoch": epoch}, step=global_step)

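    # The final artifacts mirror the per-step checkpoints: LoRA adapter weights via
    # save_pretrained() plus the projection head weights in encoder_proj.pth.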
print("\nTraining finished. Saving final model to 'final' directory...")
|
|
|
final_dir = os.path.join(OUTPUT_DIR, "final")
|
|
|
os.makedirs(final_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
peft.save_pretrained(final_dir)
|
|
|
|
|
|
|
|
|
torch.save(model.proj.state_dict(), os.path.join(final_dir, "encoder_proj.pth"))
|
|
|
|
|
|
print(f"Phase 3 complete. LoRA adapters and projection head saved to {final_dir}")
|
|
|
|
|
|
|
|
|
wandb.finish()
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # freeze_support() only matters for frozen Windows executables that use
    # multiprocessing; elsewhere it is a harmless no-op.
    freeze_support()
    main()
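
# Reload sketch (assumes GemmaSelfContrast is importable, e.g. moved to module scope):
#   base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt", attn_implementation="eager")
#   peft = PeftModel.from_pretrained(base, "phase3_self_contrast_wandb/final")
#   encoder = GemmaSelfContrast(peft)
#   encoder.proj.load_state_dict(torch.load("phase3_self_contrast_wandb/final/encoder_proj.pth"))
#   encoder.eval()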