Michał Paliński committed on
Commit · f2411ad
Parent(s): baa5ebd

custom training loop — bypass SentenceTransformerTrainer

NVEmbedModel has a non-standard forward() incompatible with SentenceTransformerTrainer.
Uses model.encode() with torch.enable_grad() for differentiable
embeddings + manual MNRL loss + AdamW with warmup scheduler.
No sentence-transformers dependency needed for training.

Made-with: Cursor
- requirements.txt +0 -1
- run.py +109 -88
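
The commit message names the whole trick in one line: get differentiable embeddings out of model.encode() by calling it under torch.enable_grad(), then apply an in-batch Multiple Negatives Ranking Loss (MNRL) by hand. As a reading aid before the diff, a minimal self-contained sketch of that loss — mnrl_loss matches the function the run.py diff below adds, while the random tensors are stand-ins for NV-Embed-v2 encode() outputs:

    import torch
    import torch.nn.functional as F

    def mnrl_loss(anchor_emb, positive_emb, temperature=20.0):
        # Row i of `scores` holds anchor i against every positive in the
        # batch; the diagonal entry is the true pair, everything else is
        # an in-batch negative. Cross-entropy with labels 0..B-1 pushes
        # the diagonal up and the off-diagonals down.
        scores = torch.mm(anchor_emb, positive_emb.t()) * temperature
        labels = torch.arange(scores.size(0), device=scores.device)
        return F.cross_entropy(scores, labels)

    # Stand-ins for encode() outputs (batch of 8, toy dimension 16).
    a = F.normalize(torch.randn(8, 16, requires_grad=True), p=2, dim=1)
    p = F.normalize(torch.randn(8, 16), p=2, dim=1)

    loss = mnrl_loss(a, p)
    loss.backward()  # in the real loop, gradients reach only the LoRA params

Because the embeddings are L2-normalized, the scores are cosine similarities; the TEMPERATURE = 20.0 the diff introduces matches the sentence-transformers default scale for MNRL (equivalent to dividing by 0.05).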
requirements.txt CHANGED

@@ -1,5 +1,4 @@
 transformers==4.45.2
-sentence-transformers>=3.3.0,<3.4.0
 datasets>=2.14.0
 huggingface-hub>=0.20.0
 accelerate>=0.25.0
run.py CHANGED

@@ -1,12 +1,13 @@
 """
-NV-Embed-v2: Evaluate base → Fine-tune with LoRA → Evaluate fine-tuned.
-
+NV-Embed-v2: Evaluate base → Fine-tune with LoRA (custom loop) → Evaluate fine-tuned.
+Custom training loop because NVEmbedModel.forward() is incompatible with SentenceTransformerTrainer.
 """
 
 import json
 import random
 import logging
 import os
+import gc
 import numpy as np
 import pandas as pd
 import faiss

@@ -42,21 +43,22 @@ LORA_R = 16
 LORA_ALPHA = 32
 LORA_DROPOUT = 0.1
 LORA_TARGETS = ["q_proj", "v_proj", "k_proj", "o_proj"]
+TEMPERATURE = 20.0
 
 
 # ═══════════════════════════════════════════════════════════════════════════
-# EVALUATION (uses AutoModel.encode)
+# EVALUATION (uses AutoModel.encode — no gradients needed)
 # ═══════════════════════════════════════════════════════════════════════════
 
-def evaluate_with_automodel(model_name, token):
-    """Load model via AutoModel and run ESCO benchmark."""
+def evaluate_with_automodel(model_name_or_path, token, method_label=None):
     from transformers import AutoModel
 
-    logger.info(f"\n{'='*60}\n EVALUATING: {model_name}\n{'='*60}")
+    label = method_label or model_name_or_path.split("/")[-1]
+    logger.info(f"\n{'='*60}\n EVALUATING: {label}\n{'='*60}")
 
-    logger.info(f"Loading {model_name}...")
+    logger.info(f"Loading {model_name_or_path}...")
     model = AutoModel.from_pretrained(
-        model_name, trust_remote_code=True, token=token,
+        model_name_or_path, trust_remote_code=True, token=token,
         torch_dtype=torch.float16, device_map="auto",
     )
     model.eval()

@@ -111,7 +113,7 @@ def evaluate_with_automodel(model_name, token):
 
         met = {
            "test_set": test_name,
-           "method":
+           "method": label,
            "matchable_rows": matchable_count,
            "accuracy_top1": round(top1/matchable_count, 6),
            "accuracy_top3": round(top3/matchable_count, 6),

@@ -126,12 +128,13 @@ def evaluate_with_automodel(model_name, token):
         )
 
     del model
+    gc.collect()
     torch.cuda.empty_cache()
     return all_metrics
 
 
 # ═══════════════════════════════════════════════════════════════════════════
-# TRAINING (
+# TRAINING (custom loop — model.encode() with torch.enable_grad)
 # ═══════════════════════════════════════════════════════════════════════════
 
 def augment_with_context(sentences, prob=0.5):

@@ -144,98 +147,116 @@ def augment_with_context(sentences, prob=0.5):
     return augmented
 
 
+def mnrl_loss(anchor_emb, positive_emb, temperature=TEMPERATURE):
+    """Multiple Negatives Ranking Loss: in-batch contrastive."""
+    scores = torch.mm(anchor_emb, positive_emb.t()) * temperature
+    labels = torch.arange(scores.size(0), device=scores.device)
+    return F.cross_entropy(scores, labels)
+
+
 def train_model(token):
-    from
-    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
-    from sentence_transformers.losses import MultipleNegativesRankingLoss
-    from sentence_transformers.training_args import SentenceTransformerTrainingArguments
-    from sentence_transformers.evaluation import InformationRetrievalEvaluator
+    from transformers import AutoModel
     from peft import LoraConfig, get_peft_model
+    from datasets import load_dataset
 
-    logger.info(f"\n{'='*60}\n FINE-TUNING WITH LoRA\n{'='*60}")
+    logger.info(f"\n{'='*60}\n FINE-TUNING WITH LoRA (custom loop)\n{'='*60}")
 
+    # Load model
+    logger.info(f"Loading {MODEL_ID}...")
+    model = AutoModel.from_pretrained(
+        MODEL_ID, trust_remote_code=True, token=token,
+        torch_dtype=torch.bfloat16,
     )
 
+    # Apply LoRA
+    lora_config = LoraConfig(
+        r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=LORA_TARGETS,
+        lora_dropout=LORA_DROPOUT, bias="none",
     )
-    model
-    model
-    model[0].auto_model.print_trainable_parameters()
-    model.prompts = {"anchor": QUERY_INSTRUCTION, "positive": ""}
+    model = get_peft_model(model, lora_config)
+    model.print_trainable_parameters()
+    model.cuda()
+    model.train()
 
     # Dataset
     logger.info(f"Loading dataset: {DATASET_ID}")
     raw = load_dataset(DATASET_ID, split="train").shuffle(seed=42)
+    anchors_raw = raw["sentence"]
+    positives_raw = raw["skill"]
+    logger.info(f"Dataset: {len(anchors_raw)} pairs")
+
+    logger.info("Augmenting anchors...")
+    anchors = augment_with_context(anchors_raw, prob=AUGMENT_PROB)
+    positives = positives_raw
+
+    # Optimizer (only LoRA params)
+    trainable_params = [p for p in model.parameters() if p.requires_grad]
+    optimizer = torch.optim.AdamW(trainable_params, lr=LR, weight_decay=0.01)
+
+    total_micro_steps = len(anchors) // TRAIN_BATCH
+    total_optim_steps = total_micro_steps // GRAD_ACCUM
+    logger.info(f"Micro-steps: {total_micro_steps}, Optimizer steps: {total_optim_steps}")
+    logger.info(f"Warmup: {WARMUP_STEPS} steps, LR: {LR}")
+
+    # LR scheduler with warmup
+    from transformers import get_linear_schedule_with_warmup
+    scheduler = get_linear_schedule_with_warmup(
+        optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_optim_steps
     )
 
+    # Training loop
+    indices = list(range(len(anchors)))
+    random.shuffle(indices)
+
+    optimizer.zero_grad()
+    running_loss = 0.0
+    micro_step = 0
+
+    for i in range(0, len(indices) - TRAIN_BATCH + 1, TRAIN_BATCH):
+        batch_idx = indices[i:i+TRAIN_BATCH]
+        batch_anchors = [anchors[j] for j in batch_idx]
+        batch_positives = [positives[j] for j in batch_idx]
+
+        # Get embeddings WITH gradients via torch.enable_grad()
+        with torch.enable_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
+            anchor_emb = model.encode(batch_anchors, instruction=QUERY_INSTRUCTION, max_length=512)
+            positive_emb = model.encode(batch_positives, instruction="", max_length=512)
+
+        anchor_emb = F.normalize(anchor_emb, p=2, dim=1)
+        positive_emb = F.normalize(positive_emb, p=2, dim=1)
+
+        loss = mnrl_loss(anchor_emb, positive_emb) / GRAD_ACCUM
 
+        loss.backward()
+        running_loss += loss.item()
+        micro_step += 1
+
+        if micro_step % GRAD_ACCUM == 0:
+            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
+            optimizer.step()
+            scheduler.step()
+            optimizer.zero_grad()
+
+            optim_step = micro_step // GRAD_ACCUM
+            avg_loss = running_loss
+            running_loss = 0.0
+
+            if optim_step % 50 == 0:
+                lr_now = scheduler.get_last_lr()[0]
+                logger.info(f" step {optim_step}/{total_optim_steps} loss={avg_loss:.4f} lr={lr_now:.2e}")
 
+    # Save
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    logger.info(f"Saving to {OUTPUT_DIR}...")
+    model.save_pretrained(OUTPUT_DIR)
 
-    if push:
+    if HUB_MODEL_ID:
         logger.info(f"Pushing to Hub: {HUB_MODEL_ID}")
-        model.push_to_hub(HUB_MODEL_ID,
+        model.push_to_hub(HUB_MODEL_ID, token=token)
+        logger.info("Pushed.")
 
-    del model,
+    del model, optimizer, trainable_params
+    gc.collect()
     torch.cuda.empty_cache()
     logger.info("Training complete.")
 

@@ -256,14 +277,14 @@ def main():
     logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
 
     # Phase 1: Evaluate base model
-    base_metrics = evaluate_with_automodel(MODEL_ID, token)
+    base_metrics = evaluate_with_automodel(MODEL_ID, token, "nv-embed-v2")
 
     # Phase 2: Fine-tune with LoRA
     train_model(token)
 
     # Phase 3: Evaluate fine-tuned model
-    ft_metrics = evaluate_with_automodel(
+    ft_source = HUB_MODEL_ID if HUB_MODEL_ID else OUTPUT_DIR
+    ft_metrics = evaluate_with_automodel(ft_source, token, "nv-embed-v2-ft")
 
     # Summary
     all_metrics = base_metrics + ft_metrics