import os

# Debug aid: set before importing torch so the CUDA runtime picks it up;
# makes CUDA errors surface synchronously (slows execution; remove once stable)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
from tqdm.auto import tqdm
from multiprocessing import freeze_support

def main():
    # Hyperparameters
    MODEL_NAME = "google/gemma-3-1b-pt"
    DATA_FILE = "text.txt"
    BATCH_SIZE = 12
    MAX_LENGTH = 128
    LR = 1e-5
    WEIGHT_DECAY = 0.01
    NUM_EPOCHS = 1
    VAL_RATIO = 0.1
    LORA_R = 8
    LORA_ALPHA = 16
    LORA_DROPOUT = 0.1  # must be > 0: unsupervised SimCSE relies on dropout noise to make the two forward passes differ
    PROJ_HIDDEN = 512
    PROJ_OUT = 256
    TEMP = 0.05
    OUTPUT_DIR = "stage1_simcse"
    GRAD_CLIP_NORM = 1.0
    SIM_CLAMP_MIN = -10.0
    SIM_CLAMP_MAX = 10.0
    SEED = 42

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tokenizer and base model (eager attention for broad hardware compatibility)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        attn_implementation="eager"
    )

    # LoRA adapters on the attention query/value projections; the base model stays frozen
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        lora_dropout=LORA_DROPOUT,
        target_modules=["q_proj", "v_proj"],
    )
    model_lora = get_peft_model(base_model, lora_cfg)
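    # Optional sanity check: report how many parameters LoRA leaves trainable
    model_lora.print_trainable_parameters()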

    # SimCSE encoder: masked mean pooling over the last hidden state, an MLP
    # projection head, then L2 normalization so dot products are cosine similarities
    class GemmaSimCSE(nn.Module):
        def __init__(self, base):
            super().__init__()
            self.base = base
            hs = base.config.hidden_size
            self.proj = nn.Sequential(
                nn.Linear(hs, PROJ_HIDDEN),
                nn.ReLU(),
                nn.Linear(PROJ_HIDDEN, PROJ_OUT),
            )
        def forward(self, input_ids, attention_mask):
            out = self.base(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )
            hidden = out.hidden_states[-1]
            # Masked mean pooling: with padding="max_length", a plain mean(dim=1)
            # would average pad positions into the embedding, so mask them out
            m = attention_mask.unsqueeze(-1).to(hidden.dtype)
            emb = (hidden * m).sum(dim=1) / m.sum(dim=1).clamp_min(1e-6)
            emb = torch.nan_to_num(emb, nan=0.0, posinf=1e-6, neginf=-1e-6)
            z = self.proj(emb)
            z = torch.nan_to_num(z, nan=0.0, posinf=1e-6, neginf=-1e-6)
            norm = z.norm(p=2, dim=1, keepdim=True).clamp_min(1e-6)
            return z / norm

    model = GemmaSimCSE(model_lora).to(device)
    torch.autograd.set_detect_anomaly(True)  # debug aid: pinpoints NaN/inf ops in backward (very slow; disable once stable)

    # Dataset: one training example per non-empty line of DATA_FILE
    raw = load_dataset("text", data_files={"train": DATA_FILE}, split="train")
    raw = raw.filter(lambda x: x["text"].strip() != "")
    split = raw.train_test_split(test_size=VAL_RATIO, seed=SEED)
    train_ds = split["train"]
    val_ds = split["test"]

    def tokenize_fn(batch):
        toks = tokenizer(
            batch["text"],
            max_length=MAX_LENGTH,
            truncation=True,
            padding="max_length"
        )
        return {"input_ids": toks["input_ids"], "attention_mask": toks["attention_mask"]}

    train_ds = train_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"]
    )
    val_ds = val_ds.map(
        tokenize_fn,
        batched=True,
        batch_size=1000,
        num_proc=4,
        remove_columns=["text"]
    )

    train_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])
    val_ds.set_format(type="torch", columns=["input_ids", "attention_mask"])

    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    optimizer = torch.optim.AdamW(
        model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
    )
    total_steps = len(train_loader) * NUM_EPOCHS
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    # Training: unsupervised SimCSE. Each batch is encoded twice; dropout makes the
    # two passes differ, each example's second view is its positive, and the other
    # embeddings in the batch serve as negatives (InfoNCE / cross-entropy loss)
    for epoch in range(1, NUM_EPOCHS + 1):
        model.train()
        train_loss = 0.0
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch}", unit="batch"):
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)

            # Two stochastic views of the same batch
            emb1 = model(ids, mask)
            emb2 = model(ids, mask)
            emb = torch.cat([emb1, emb2], dim=0)
            sim = (emb @ emb.T) / TEMP
            sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)

            # Mask self-similarity so no example can pick itself as its positive
            sim.fill_diagonal_(-1e9)

            B = emb1.size(0)
            labels = torch.cat([
                torch.arange(B, device=device) + B,
                torch.arange(B, device=device)
            ], dim=0)
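            # Worked example: with B = 3, labels = [3, 4, 5, 0, 1, 2]; row i of `sim`
            # is scored against column (i + B) mod 2B, the paired view of the same
            # example, with the remaining 2B - 2 columns acting as in-batch negatives.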

            loss = F.cross_entropy(sim, labels)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()

        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch {epoch} training complete. avg train loss: {avg_train_loss:.6f}")

        # Validation uses the same objective, but model.eval() disables dropout, so
        # the two "views" coincide: the positive logit saturates at SIM_CLAMP_MAX and
        # the loss mainly tracks separation from the in-batch negatives
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Validate Epoch {epoch}", unit="batch"):
                ids = batch["input_ids"].to(device)
                mask = batch["attention_mask"].to(device)

                emb1 = model(ids, mask)
                emb2 = model(ids, mask)
                emb = torch.cat([emb1, emb2], dim=0)
                sim = (emb @ emb.T) / TEMP
                sim = sim.clamp(SIM_CLAMP_MIN, SIM_CLAMP_MAX)
                sim.fill_diagonal_(-1e9)

                B = emb1.size(0)
                labels = torch.cat([
                    torch.arange(B, device=device) + B,
                    torch.arange(B, device=device)
                ], dim=0)

                loss = F.cross_entropy(sim, labels)
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        print(f"Epoch {epoch} validation complete. avg val loss: {avg_val_loss:.6f}")

        ckpt_dir = os.path.join(OUTPUT_DIR, f"epoch{epoch}")
        model_lora.save_pretrained(ckpt_dir)
        tokenizer.save_pretrained(ckpt_dir)
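        # save_pretrained stores only the LoRA adapter weights, so the trained
        # projection head would otherwise be lost; persist it alongside the adapter
        # (the file name "proj_head.pt" is this script's own convention)
        torch.save(model.proj.state_dict(), os.path.join(ckpt_dir, "proj_head.pt"))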

    final_dir = os.path.join(OUTPUT_DIR, "final")
    os.makedirs(final_dir, exist_ok=True)
    model_lora.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)
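    # Keep the projection head with the final adapter as well (same convention as above)
    torch.save(model.proj.state_dict(), os.path.join(final_dir, "proj_head.pt"))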
    print("Training and validation complete. Final model saved to", final_dir)

if __name__ == "__main__":
    freeze_support()  # no-op outside frozen Windows executables; the __main__ guard itself is what num_proc=4 requires
    main()
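
# ---------------------------------------------------------------------------
# Usage sketch (not executed): reloading the Stage-1 artifacts for embedding
# inference. This assumes the "final" checkpoint written above, the
# "proj_head.pt" convention introduced in this script, and the same pooling
# and projection sizes (PROJ_HIDDEN=512, PROJ_OUT=256) used during training.
#
#   import torch
#   import torch.nn as nn
#   import torch.nn.functional as F
#   from transformers import AutoTokenizer, AutoModelForCausalLM
#   from peft import PeftModel
#
#   ckpt = "stage1_simcse/final"
#   tok = AutoTokenizer.from_pretrained(ckpt)
#   base = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-pt",
#                                               attn_implementation="eager")
#   enc = PeftModel.from_pretrained(base, ckpt).eval()
#
#   proj = nn.Sequential(nn.Linear(enc.config.hidden_size, 512),
#                        nn.ReLU(), nn.Linear(512, 256))
#   proj.load_state_dict(torch.load(f"{ckpt}/proj_head.pt"))
#   proj.eval()
#
#   batch = tok(["an example sentence"], return_tensors="pt",
#               truncation=True, max_length=128)
#   with torch.no_grad():
#       hidden = enc(**batch, output_hidden_states=True).hidden_states[-1]
#       m = batch["attention_mask"].unsqueeze(-1).float()
#       emb = (hidden * m).sum(1) / m.sum(1).clamp_min(1e-6)  # masked mean pooling
#       z = F.normalize(proj(emb), dim=1)                      # unit-norm embedding
# ---------------------------------------------------------------------------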