Michał Paliński committed on
Commit f2411ad · 1 Parent(s): baa5ebd

custom training loop — bypass SentenceTransformerTrainer


NVEmbedModel has non-standard forward() incompatible with ST.
Uses model.encode() with torch.enable_grad() for differentiable
embeddings + manual MNRL loss + AdamW with warmup scheduler.
No sentence-transformers dependency needed for training.

Made-with: Cursor
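
In outline, the pattern the message describes looks like the sketch below. It is illustrative only: encoder, anchors, and positives are placeholders, and the real implementation is the train_model() added to run.py in this commit.

import torch
import torch.nn.functional as F

def mnrl_sketch(anchor_emb, positive_emb, temperature=20.0):
    # In-batch Multiple Negatives Ranking Loss: the i-th positive is the target
    # for the i-th anchor; every other positive in the batch acts as a negative.
    scores = anchor_emb @ positive_emb.t() * temperature
    labels = torch.arange(scores.size(0), device=scores.device)
    return F.cross_entropy(scores, labels)

def train_step_sketch(encoder, optimizer, scheduler, anchors, positives):
    # encoder stands for any model whose encode() returns torch tensors that
    # participate in autograd; in run.py this is the LoRA-wrapped NV-Embed model.
    with torch.enable_grad():  # keep encode() differentiable
        a = F.normalize(encoder.encode(anchors), p=2, dim=1)
        p = F.normalize(encoder.encode(positives), p=2, dim=1)
    loss = mnrl_sketch(a, p)
    loss.backward()
    optimizer.step()       # AdamW over the trainable (LoRA) parameters
    scheduler.step()       # linear schedule with warmup
    optimizer.zero_grad()
    return loss.item()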

Files changed (2)
  1. requirements.txt +0 -1
  2. run.py +109 -88
requirements.txt CHANGED
@@ -1,5 +1,4 @@
 transformers==4.45.2
-sentence-transformers>=3.3.0,<3.4.0
 datasets>=2.14.0
 huggingface-hub>=0.20.0
 accelerate>=0.25.0
run.py CHANGED
@@ -1,12 +1,13 @@
 """
-NV-Embed-v2: Evaluate base → Fine-tune with LoRA → Evaluate fine-tuned.
-All on same GPU in one run.
+NV-Embed-v2: Evaluate base → Fine-tune with LoRA (custom loop) → Evaluate fine-tuned.
+Custom training loop because NVEmbedModel.forward() is incompatible with SentenceTransformerTrainer.
 """
 
 import json
 import random
 import logging
 import os
+import gc
 import numpy as np
 import pandas as pd
 import faiss
@@ -42,21 +43,22 @@ LORA_R = 16
 LORA_ALPHA = 32
 LORA_DROPOUT = 0.1
 LORA_TARGETS = ["q_proj", "v_proj", "k_proj", "o_proj"]
+TEMPERATURE = 20.0
 
 
 # ═══════════════════════════════════════════════════════════════════════════
-# EVALUATION (uses AutoModel.encode directly — no sentence-transformers)
+# EVALUATION (uses AutoModel.encode — no gradients needed)
 # ═══════════════════════════════════════════════════════════════════════════
 
-def evaluate_with_automodel(model_name, token):
-    """Load model via AutoModel and run ESCO benchmark."""
+def evaluate_with_automodel(model_name_or_path, token, method_label=None):
     from transformers import AutoModel
 
-    logger.info(f"\n{'='*60}\n EVALUATING: {model_name}\n{'='*60}")
+    label = method_label or model_name_or_path.split("/")[-1]
+    logger.info(f"\n{'='*60}\n EVALUATING: {label}\n{'='*60}")
 
-    logger.info(f"Loading {model_name}...")
+    logger.info(f"Loading {model_name_or_path}...")
     model = AutoModel.from_pretrained(
-        model_name, trust_remote_code=True, token=token,
+        model_name_or_path, trust_remote_code=True, token=token,
         torch_dtype=torch.float16, device_map="auto",
     )
     model.eval()
@@ -111,7 +113,7 @@ def evaluate_with_automodel(model_name, token):
 
         met = {
             "test_set": test_name,
-            "method": model_name.split("/")[-1],
+            "method": label,
            "matchable_rows": matchable_count,
            "accuracy_top1": round(top1/matchable_count, 6),
            "accuracy_top3": round(top3/matchable_count, 6),
@@ -126,12 +128,13 @@ def evaluate_with_automodel(model_name, token):
        )
 
    del model
+    gc.collect()
    torch.cuda.empty_cache()
    return all_metrics
 
 
 # ═══════════════════════════════════════════════════════════════════════════
-# TRAINING (uses SentenceTransformer + LoRA)
+# TRAINING (custom loop model.encode() with torch.enable_grad)
 # ═══════════════════════════════════════════════════════════════════════════
 
 def augment_with_context(sentences, prob=0.5):
@@ -144,98 +147,116 @@ def augment_with_context(sentences, prob=0.5):
     return augmented
 
 
+def mnrl_loss(anchor_emb, positive_emb, temperature=TEMPERATURE):
+    """Multiple Negatives Ranking Loss: in-batch contrastive."""
+    scores = torch.mm(anchor_emb, positive_emb.t()) * temperature
+    labels = torch.arange(scores.size(0), device=scores.device)
+    return F.cross_entropy(scores, labels)
+
+
 def train_model(token):
-    from datasets import load_dataset, Dataset
-    from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
-    from sentence_transformers.losses import MultipleNegativesRankingLoss
-    from sentence_transformers.training_args import SentenceTransformerTrainingArguments
-    from sentence_transformers.evaluation import InformationRetrievalEvaluator
+    from transformers import AutoModel
     from peft import LoraConfig, get_peft_model
+    from datasets import load_dataset
 
-    logger.info(f"\n{'='*60}\n FINE-TUNING WITH LoRA\n{'='*60}")
+    logger.info(f"\n{'='*60}\n FINE-TUNING WITH LoRA (custom loop)\n{'='*60}")
 
-    lora_config = LoraConfig(
-        r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=LORA_TARGETS,
-        lora_dropout=LORA_DROPOUT, bias="none", task_type="FEATURE_EXTRACTION",
+    # Load model
+    logger.info(f"Loading {MODEL_ID}...")
+    model = AutoModel.from_pretrained(
+        MODEL_ID, trust_remote_code=True, token=token,
+        torch_dtype=torch.bfloat16,
     )
 
-    logger.info(f"Loading {MODEL_ID} via SentenceTransformer...")
-    model = SentenceTransformer(
-        MODEL_ID, trust_remote_code=True,
-        model_kwargs={"torch_dtype": torch.bfloat16},
+    # Apply LoRA
+    lora_config = LoraConfig(
+        r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=LORA_TARGETS,
+        lora_dropout=LORA_DROPOUT, bias="none",
    )
-    model.max_seq_length = 512
-
-    logger.info("Applying LoRA adapter...")
-    model[0].auto_model = get_peft_model(model[0].auto_model, lora_config)
-    model[0].auto_model.print_trainable_parameters()
-
-    model.prompts = {"anchor": QUERY_INSTRUCTION, "positive": ""}
+    model = get_peft_model(model, lora_config)
+    model.print_trainable_parameters()
+    model.cuda()
+    model.train()
 
     # Dataset
     logger.info(f"Loading dataset: {DATASET_ID}")
     raw = load_dataset(DATASET_ID, split="train").shuffle(seed=42)
-    split = raw.train_test_split(test_size=0.05, seed=42)
-
-    sentences = split["train"]["sentence"]
-    skills = split["train"]["skill"]
-    aug = augment_with_context(sentences, prob=AUGMENT_PROB)
-    train_dataset = Dataset.from_dict({"anchor": aug, "positive": skills})
-    eval_raw = split["test"]
-    logger.info(f"Train: {len(train_dataset)}, Eval: {len(eval_raw)}")
-
-    loss = MultipleNegativesRankingLoss(model)
-
-    sample = eval_raw.select(range(min(500, len(eval_raw))))
-    evaluator = InformationRetrievalEvaluator(
-        queries={str(i): row["sentence"] for i, row in enumerate(sample)},
-        corpus={s: s for s in set(sample["skill"])},
-        relevant_docs={str(i): {row["skill"]} for i, row in enumerate(sample)},
-        name="esco-eval",
-        score_functions={"cosine": lambda a, b: (a @ b.T)},
+    anchors_raw = raw["sentence"]
+    positives_raw = raw["skill"]
+    logger.info(f"Dataset: {len(anchors_raw)} pairs")
+
+    logger.info("Augmenting anchors...")
+    anchors = augment_with_context(anchors_raw, prob=AUGMENT_PROB)
+    positives = positives_raw
+
+    # Optimizer (only LoRA params)
+    trainable_params = [p for p in model.parameters() if p.requires_grad]
+    optimizer = torch.optim.AdamW(trainable_params, lr=LR, weight_decay=0.01)
+
+    total_micro_steps = len(anchors) // TRAIN_BATCH
+    total_optim_steps = total_micro_steps // GRAD_ACCUM
+    logger.info(f"Micro-steps: {total_micro_steps}, Optimizer steps: {total_optim_steps}")
+    logger.info(f"Warmup: {WARMUP_STEPS} steps, LR: {LR}")
+
+    # LR scheduler with warmup
+    from transformers import get_linear_schedule_with_warmup
+    scheduler = get_linear_schedule_with_warmup(
+        optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=total_optim_steps
    )
 
-    push = bool(HUB_MODEL_ID)
-    args_kwargs = dict(
-        output_dir=OUTPUT_DIR,
-        num_train_epochs=EPOCHS,
-        per_device_train_batch_size=TRAIN_BATCH,
-        per_device_eval_batch_size=TRAIN_BATCH,
-        gradient_accumulation_steps=GRAD_ACCUM,
-        learning_rate=LR,
-        warmup_steps=WARMUP_STEPS,
-        bf16=True,
-        eval_strategy="steps", eval_steps=500,
-        save_strategy="steps", save_steps=500,
-        save_total_limit=2,
-        load_best_model_at_end=True,
-        metric_for_best_model="esco-eval_cosine_ndcg@10",
-        logging_steps=50,
-        gradient_checkpointing=False,
-        dataloader_pin_memory=False,
-        push_to_hub=push,
-    )
-    if push:
-        args_kwargs["hub_model_id"] = HUB_MODEL_ID
-        args_kwargs["hub_strategy"] = "every_save"
-
-    trainer = SentenceTransformerTrainer(
-        model=model,
-        args=SentenceTransformerTrainingArguments(**args_kwargs),
-        train_dataset=train_dataset,
-        eval_dataset=train_dataset.select(range(500)),
-        loss=loss, evaluator=evaluator,
-    )
+    # Training loop
+    indices = list(range(len(anchors)))
+    random.shuffle(indices)
+
+    optimizer.zero_grad()
+    running_loss = 0.0
+    micro_step = 0
+
+    for i in range(0, len(indices) - TRAIN_BATCH + 1, TRAIN_BATCH):
+        batch_idx = indices[i:i+TRAIN_BATCH]
+        batch_anchors = [anchors[j] for j in batch_idx]
+        batch_positives = [positives[j] for j in batch_idx]
+
+        # Get embeddings WITH gradients via torch.enable_grad()
+        with torch.enable_grad(), torch.amp.autocast("cuda", dtype=torch.bfloat16):
+            anchor_emb = model.encode(batch_anchors, instruction=QUERY_INSTRUCTION, max_length=512)
+            positive_emb = model.encode(batch_positives, instruction="", max_length=512)
+
+        anchor_emb = F.normalize(anchor_emb, p=2, dim=1)
+        positive_emb = F.normalize(positive_emb, p=2, dim=1)
+
+        loss = mnrl_loss(anchor_emb, positive_emb) / GRAD_ACCUM
+
+        loss.backward()
+        running_loss += loss.item()
+        micro_step += 1
+
+        if micro_step % GRAD_ACCUM == 0:
+            torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
+            optimizer.step()
+            scheduler.step()
+            optimizer.zero_grad()
+
+            optim_step = micro_step // GRAD_ACCUM
+            avg_loss = running_loss
+            running_loss = 0.0
+
+            if optim_step % 50 == 0:
+                lr_now = scheduler.get_last_lr()[0]
+                logger.info(f" step {optim_step}/{total_optim_steps} loss={avg_loss:.4f} lr={lr_now:.2e}")
 
-    logger.info("Starting training...")
-    trainer.train()
+    # Save
+    os.makedirs(OUTPUT_DIR, exist_ok=True)
+    logger.info(f"Saving to {OUTPUT_DIR}...")
+    model.save_pretrained(OUTPUT_DIR)
 
-    model.save_pretrained(f"{OUTPUT_DIR}/final")
-    if push:
+    if HUB_MODEL_ID:
         logger.info(f"Pushing to Hub: {HUB_MODEL_ID}")
-        model.push_to_hub(HUB_MODEL_ID, exist_ok=True)
+        model.push_to_hub(HUB_MODEL_ID, token=token)
+        logger.info("Pushed.")
 
-    del model, trainer
+    del model, optimizer, trainable_params
+    gc.collect()
     torch.cuda.empty_cache()
     logger.info("Training complete.")
 
@@ -256,14 +277,14 @@ def main():
     logger.info(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
 
     # Phase 1: Evaluate base model
-    base_metrics = evaluate_with_automodel(MODEL_ID, token)
+    base_metrics = evaluate_with_automodel(MODEL_ID, token, "nv-embed-v2")
 
     # Phase 2: Fine-tune with LoRA
     train_model(token)
 
     # Phase 3: Evaluate fine-tuned model
-    ft_model_id = HUB_MODEL_ID if HUB_MODEL_ID else f"{OUTPUT_DIR}/final"
-    ft_metrics = evaluate_with_automodel(ft_model_id, token)
+    ft_source = HUB_MODEL_ID if HUB_MODEL_ID else OUTPUT_DIR
+    ft_metrics = evaluate_with_automodel(ft_source, token, "nv-embed-v2-ft")
 
     # Summary
     all_metrics = base_metrics + ft_metrics
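
A minimal sketch of reloading the saved adapter for a later evaluation run, assuming OUTPUT_DIR contains a standard PEFT adapter written by model.save_pretrained() and that the base model is nvidia/NV-Embed-v2; the directory path below is illustrative only.

import torch
from transformers import AutoModel
from peft import PeftModel

# Load the base model the same way evaluate_with_automodel() does.
base = AutoModel.from_pretrained(
    "nvidia/NV-Embed-v2", trust_remote_code=True,
    torch_dtype=torch.float16, device_map="auto",
)
# Attach the LoRA adapter written by train_model(); "./output" is a placeholder path.
model = PeftModel.from_pretrained(base, "./output")
model.eval()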