from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch
import torch.nn.functional as F
from sentence_transformers import CrossEncoder
from torch.utils.data import DataLoader, Dataset


class PairwiseDataset(Dataset):
    """(query, positive passage, negative passage) triples read from a JSONL file."""

    def __init__(self, path: str):
        self.samples = []
        with open(path) as f:
            for line in f:
                if line.strip():
                    self.samples.append(json.loads(line))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item = self.samples[idx]
        return item["query"], item["pos_text"], item["neg_text"]


def pairwise_loss(model, tokenizer, batch, device, max_len: int):
    """Pairwise ranking loss for a batch of (query, pos, neg) triples.

    Returns None for an empty batch so the caller can skip the step.
    """
    queries, pos_texts, neg_texts = batch

    # Guard against null or missing fields in the data.
    queries = [q or "" for q in queries]
    pos_texts = [p or "" for p in pos_texts]
    neg_texts = [n or "" for n in neg_texts]

    # Encode each (query, passage) pair jointly, as the cross-encoder expects.
    enc_pos = tokenizer(
        list(queries),
        list(pos_texts),
        padding="max_length",
        truncation="longest_first",
        max_length=max_len,
        return_tensors="pt",
    ).to(device)
    enc_neg = tokenizer(
        list(queries),
        list(neg_texts),
        padding="max_length",
        truncation="longest_first",
        max_length=max_len,
        return_tensors="pt",
    ).to(device)

    scores_pos = model(**enc_pos).logits.view(-1)
    scores_neg = model(**enc_neg).logits.view(-1)

    # Clamp and sanitize the raw scores so a single degenerate sample cannot
    # produce an overflowing or NaN loss.
    scores_pos = torch.clamp(scores_pos, -20.0, 20.0)
    scores_neg = torch.clamp(scores_neg, -20.0, 20.0)
    scores_pos = torch.nan_to_num(scores_pos, nan=0.0, posinf=0.0, neginf=0.0)
    scores_neg = torch.nan_to_num(scores_neg, nan=0.0, posinf=0.0, neginf=0.0)

    diff = torch.clamp(scores_pos - scores_neg, -20.0, 20.0)
    if diff.numel() == 0:
        return None

    # BCE against an all-ones target is -log(sigmoid(diff)): the RankNet
    # pairwise objective, pushing positive scores above negative ones.
    return F.binary_cross_entropy_with_logits(diff, torch.ones_like(diff))
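
# Note: with batch size B both score tensors have shape (B,); pairs the model
# already ranks correctly (large positive diff) contribute a near-zero gradient,
# so training effort concentrates on the pairs it still gets wrong.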


def train(
    model_name: str,
    train_path: str,
    val_path: str,
    epochs: int = 1,
    lr: float = 1e-5,
    batch_size: int = 4,
    max_len: int = 256,
    output_dir: str = "models/reranker_crossenc/v0.1.0",
):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    ce = CrossEncoder(model_name, max_length=max_len, device=device)
    model = ce.model
    tokenizer = ce.tokenizer

    train_ds = PairwiseDataset(train_path)
    val_ds = PairwiseDataset(val_path) if Path(val_path).exists() else None
    # The default collate_fn turns a batch of string triples into three lists,
    # which is exactly what pairwise_loss unpacks.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size) if val_ds else None

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    best_loss = float("inf")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        steps = 0
        for batch in train_loader:
            optimizer.zero_grad()
            loss = pairwise_loss(model, tokenizer, batch, device, max_len)
            # Skip empty batches and non-finite losses rather than crash.
            if loss is None or not torch.isfinite(loss):
                continue
            loss.backward()
            # Clip gradients to keep updates stable on noisy pairs.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += loss.item()
            steps += 1
        avg_loss = total_loss / max(1, steps)

        val_loss = None
        if val_loader:
            model.eval()
            with torch.no_grad():
                vloss = 0.0
                vsteps = 0
                for batch in val_loader:
                    batch_loss = pairwise_loss(model, tokenizer, batch, device, max_len)
                    if batch_loss is None or not torch.isfinite(batch_loss):
                        continue
                    vloss += batch_loss.item()
                    vsteps += 1
                # Leave val_loss as None when every batch was skipped, instead
                # of reporting a spurious 0.0.
                val_loss = vloss / vsteps if vsteps else None

        msg = f"Epoch {epoch + 1}: train_loss={avg_loss:.4f}"
        if val_loss is not None:
            msg += f" val_loss={val_loss:.4f}"
        print(msg)

        # Checkpoint whenever validation loss improves.
        if val_loss is not None and val_loss < best_loss:
            best_loss = val_loss
            Path(output_dir).mkdir(parents=True, exist_ok=True)
            ce.save(output_dir)

    # If no best-on-validation checkpoint was ever written (no validation set,
    # or every validation batch was skipped), save the final weights. Otherwise
    # keep the best checkpoint instead of overwriting it with the last epoch's.
    if best_loss == float("inf"):
        Path(output_dir).mkdir(parents=True, exist_ok=True)
        ce.save(output_dir)

    # Record the run's hyperparameters next to the model for reproducibility.
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    with open(Path(output_dir) / "train_config.json", "w") as f:
        json.dump(
            {
                "model_name": model_name,
                "train_path": train_path,
                "val_path": val_path,
                "epochs": epochs,
                "lr": lr,
                "batch_size": batch_size,
                "max_len": max_len,
            },
            f,
            indent=2,
        )
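
# With the CLI defaults below, train_config.json would contain fields such as
# (illustrative excerpt; the actual file is indented JSON):
#   {"model_name": "cross-encoder/ms-marco-MiniLM-L-6-v2", "epochs": 1,
#    "lr": 1e-05, "batch_size": 4, "max_len": 256, ...}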


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune cross-encoder reranker (pairwise).")
    parser.add_argument("--model", default="cross-encoder/ms-marco-MiniLM-L-6-v2")
    parser.add_argument("--train", default="data/reranker/pairwise_train.jsonl")
    parser.add_argument("--val", default="data/reranker/pairwise_val.jsonl")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--max-len", type=int, default=256)
    parser.add_argument("--output-dir", default="models/reranker_crossenc/v0.1.0")
    args = parser.parse_args()

    train(
        model_name=args.model,
        train_path=args.train,
        val_path=args.val,
        epochs=args.epochs,
        lr=args.lr,
        batch_size=args.batch_size,
        max_len=args.max_len,
        output_dir=args.output_dir,
    )
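
# Example invocation (script filename assumed; adjust to this file's actual name):
#   python train_reranker.py --epochs 1 --batch-size 8 \
#       --train data/reranker/pairwise_train.jsonl \
#       --val data/reranker/pairwise_val.jsonl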