import os

# Debugging aid: synchronous kernel launches make CUDA errors surface at the
# offending call. Set before importing torch so it takes effect; remove for
# full-speed training.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from tqdm.auto import tqdm
from multiprocessing import freeze_support


def main():
    # ----- Configuration -----
    MODEL_NAME = "google/gemma-3-1b-pt"
    DATA_FILE = "text.txt"
    BATCH_SIZE = 12
    MAX_LENGTH = 128
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    VAL_RATIO = 0.1
    LORA_R = 8
    LORA_ALPHA = 16
    # NOTE: unsupervised SimCSE relies on dropout to make the two forward
    # passes over a batch differ; with lora_dropout at 0.0 the positive pair
    # only differs through whatever other stochastic layers the model has.
    LORA_DROPOUT = 0.0
    PROJ_HIDDEN = 512
    TEMP = 0.05
    OUTPUT_DIR = "stage1_simcse"
    GRAD_CLIP_NORM = 1.0
    SIM_CLAMP_MIN = -10.0
    SIM_CLAMP_MAX = 10.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if device.type == "cuda":
        # TF32 matmuls and cuDNN autotuning trade a little precision for
        # speed on Ampere-and-newer GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager",
    )

    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "v_proj"],
    )
    model_lora = get_peft_model(base_model, lora_cfg)
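
    # Sentence-encoder head over the causal LM: pool the final hidden layer,
    # project through a small MLP, and L2-normalize the result.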
    class GemmaSimCSE(nn.Module):
        def __init__(self, base):
            super().__init__()
            self.base = base
            hs = base.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, PROJ_HIDDEN),
                nn.ReLU(),
                nn.Linear(PROJ_HIDDEN, hs),
            )

        def forward(self, input_ids, attention_mask):
            out = self.base(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True,
            )
            hidden = out.hidden_states[-1]
            # Masked mean pooling: exclude padded positions so padding length
            # does not skew the sentence embedding.
            m = attention_mask.unsqueeze(-1).to(hidden.dtype)
            emb = (hidden * m).sum(dim=1) / m.sum(dim=1).clamp_min(1e-6)
            emb = torch.nan_to_num(emb, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(emb)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            # L2-normalize so the dot products in the loss are cosine similarities.
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSimCSE(model_lora).to(device)
    # Debugging aid: anomaly detection adds substantial overhead; disable once
    # the run is numerically stable.
    torch.autograd.set_detect_anomaly(True)
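
    # Load the raw text corpus, drop empty lines, and split off a validation set.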
    raw = load_dataset("text", data_files={"train": DATA_FILE}, split="train")
    raw = raw.filter(lambda x: x["text"].strip() != "")
    split = raw.train_test_split(test_size=VAL_RATIO, seed=SEED)
    train_ds = split["train"]
    val_ds = split["test"]

    def tokenize_fn(batch):
        toks = tokenizer(
            batch["text"],
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length",
        )
        return {"input_ids": toks["input_ids"], "attention_mask": toks["attention_mask"]}
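
    # Tokenize up front; fixed-length padding keeps every batch the same shape.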
    train_ds = train_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"],
    )
    val_ds = val_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"],
    )

    train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
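
    # Optimizer and LR schedule: AdamW with linear decay after a 10% warmup.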
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    total_steps = len(train_loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )
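
    # ----- Training: unsupervised SimCSE with in-batch negatives (InfoNCE) -----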
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            # Two forward passes over the same inputs; dropout noise is what
            # makes the two views of each sentence differ (the positive pair).
            emb1 = model(ids, mask)
            emb2 = model(ids, mask)
            emb = torch.cat([emb1, emb2], dim=0)
            # Temperature-scaled cosine-similarity matrix, clamped for
            # numerical stability; the diagonal is masked so an example
            # cannot be its own positive.
            sim = (emb @ emb.T) / TEMP
            sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
            sim.fill_diagonal_(-1e9)

            # Row i in the first half has its positive at i + B, and vice versa.
            B = emb1.size(0)
            labels = torch.cat([
                torch.arange(B, device=device) + B,
                torch.arange(B, device=device),
            ], dim=0)

            loss = F.cross_entropy(sim, labels)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()
            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch} training complete. avg train loss: {avg_train_loss:.6f}")
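
        # Validation mirrors the training objective; note that eval() disables
        # dropout, so the two forward passes produce identical embeddings and
        # the loss mainly reflects separation from in-batch negatives.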
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validate Epoch {epoch}", unit="batch"):
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)

                emb1 = model(ids, mask)
                emb2 = model(ids, mask)
                emb = torch.cat([emb1, emb2], dim=0)
                sim = (emb @ emb.T) / TEMP
                sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
                sim.fill_diagonal_(-1e9)

                B = emb1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device),
                ], dim=0)

                loss = F.cross_entropy(sim, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch} validation complete. avg val loss: {avg_val_loss:.6f}")
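
        # Per-epoch checkpoint: save_pretrained stores only the LoRA adapter
        # and tokenizer files, so the projection head is written separately
        # (the "proj_head.pt" file name is this script's own convention).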
        ckpt_dir = os.path.join(OUTPUT_DIR, f"epoch{epoch}")
        model_lora.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
        torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "proj_head.pt"))

    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model_lora.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
    torch.save(model.proj.state_dict(), os.path.join(final_dir, "proj_head.pt"))
    print("Training and validation complete. Final model saved to", final_dir)


if __name__ == "__main__":
    # Required on Windows because datasets.map spawns workers (num_proc=4);
    # a no-op elsewhere.
    freeze_support()
    main()
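
# Reloading the trained encoder later, as a minimal sketch (assumes the
# checkpoint layout written above; "proj_head.pt" is this script's own
# convention, and GemmaSimCSE would need to live at module scope to be
# importable):
#
#     from peft import PeftModel
#     base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt")
#     lora = PeftModel.from_pretrained(base, "stage1_simcse/final")
#     encoder = GemmaSimCSE(lora)
#     encoder.proj.load_state_dict(torch.load("stage1_simcse/final/proj_head.pt"))
#     encoder.eval()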