orionweller committed · verified · Commit 04c021a · Parent: 621b6e6

Update README.md

Files changed (1): README.md (+651 -52)
README.md CHANGED
@@ -28,10 +28,11 @@ This model is part of the Ettin suite - the first collection of paired encoder-o

- [Decoder Models](#decoder-models)
- [Cross-Objective Models](#cross-objective-models)
- [Accessing Training Checkpoints](#accessing-training-checkpoints)
- [Research Applications](#research-applications)
- [Training Details](#training-details)
- [Model Architecture](#model-architecture)
- [Usage Examples](#usage-examples)
- [Fine-tuning Examples](#fine-tuning-examples)
- [Citation](#citation)

## 📊 Performance Highlights

@@ -53,7 +54,10 @@ This model is part of the Ettin suite - the first collection of paired encoder-o

### Installation
```bash
pip install torch>=1.9.0
# Encoders work with transformers>=4.48.0.
# Decoders require installing from main until the next pip release (transformers>=4.54.x will include them):
pip install git+https://github.com/huggingface/transformers.git
```
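
You can sanity-check the environment before running the examples below; this minimal snippet assumes only that `transformers` is importable (a main-branch install reports a dev version string):

```python
# Quick environment check: encoders need transformers>=4.48.0; decoders need a
# build from main until the next release (which reports a dev version string).
import transformers

print(transformers.__version__)
```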

### 30-Second Examples

@@ -198,6 +202,62 @@ model = AutoModelForCausalLM.from_pretrained(

This checkpoint availability enables detailed analysis of training dynamics, loss curves, and capability emergence across the complete 2T-token training process.
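
As a sketch of the kind of analysis this enables, the loop below tracks language-modeling loss on a fixed probe sentence across checkpoints. This is an illustrative sketch, not a script from the repo: swap in the checkpoint revision tags listed on each model's Hub page.

```python
# Illustrative sketch: trace LM loss on a probe sentence across training
# checkpoints. Replace the revision list with the checkpoint tags published
# on the model's Hugging Face page ("main" is the final model).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

repo = "jhu-clsp/ettin-decoder-17m"
tokenizer = AutoTokenizer.from_pretrained(repo)
probe = tokenizer("The capital of France is Paris.", return_tensors="pt")

for revision in ["main"]:  # e.g. add earlier checkpoint tags from the Hub here
    model = AutoModelForCausalLM.from_pretrained(repo, revision=revision)
    with torch.no_grad():
        loss = model(**probe, labels=probe["input_ids"]).loss
    print(f"{revision}: loss = {loss.item():.3f}")
```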

## 🔬 Research Applications

### What Makes Ettin Unique

Ettin provides the first **controlled comparison** of encoder vs. decoder architectures:

- **Identical Training Data**: Same 2T token mixture across all models
- **Matched Architectures**: Only attention patterns and objectives differ
- **Open Everything**: Training data, model weights, and batch-level training order
- **Multiple Scales**: Fair comparison from 17M to 1B parameters
- **250+ Checkpoints**: Complete training trajectory analysis

### Key Research Findings

1. **Architecture Specialization Persists**:
   - Encoders excel at classification/retrieval even vs. larger decoders
   - Decoders excel at generation even vs. larger encoders
   - A 400M encoder beats a 1B decoder on MNLI (89.2 vs. 88.2)

2. **Cross-Training Limitations**:
   - Converting decoder→encoder or encoder→decoder underperforms native training
   - 50B tokens of continued training is insufficient to close the gaps
   - The native training objective remains superior

3. **Scaling Insights**:
   - Performance gaps between architectures widen with size
   - Decoder-from-encoder adaptation scales particularly poorly

### Use Cases for Researchers

- **Architecture Studies**: Compare encoder vs. decoder capabilities fairly
- **Training Dynamics**: Analyze 250+ checkpoints with batch-level data ordering
- **Scaling Laws**: Study how architectural advantages change with scale
- **Transfer Learning**: Investigate cross-objective training effectiveness
- **Replication Studies**: First open replication of the ModernBERT training recipe

### Reproducibility

All training artifacts are publicly available:
- Training data with exact batch ordering
- Model checkpoints every 8.5B tokens
- Complete hyperparameter configurations
- Training code and evaluation scripts

## Training Details

**Data:** High-quality mixture including DCLM, Dolma v1.7, scientific papers, code, and curated sources totaling 2T+ tokens

**Architecture:** Transformer with RoPE, GLU activations, and prenorm layers

**Training Phases:**
- **Pre-training**: 1.7T tokens with a diverse data mixture
- **Mid-training**: 250B tokens with higher-quality filtered data and context extension to 8K
- **Decay phase**: 100B tokens with premium data sources

**Key Features:**
- Context length: Up to 8K tokens
- Vocabulary: 50,368 tokens (ModernBERT tokenizer)
- Deep but efficient architectures following MobileLLM principles

## Model Architecture

| Parameter | 17M | 32M | 68M | 150M | 400M | 1B |
|:----------|:----|:----|:----|:-----|:-----|:---|
| Layers | 7 | 10 | 19 | 22 | 28 | 28 |
| Hidden Size | 256 | 384 | 512 | 768 | 1024 | 1792 |
| Intermediate Size | 384 | 576 | 768 | 1152 | 2624 | 3840 |
| Attention Heads | 4 | 6 | 8 | 12 | 16 | 28 |
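
The table can be cross-checked against any released variant's config; a minimal sketch assuming only standard `transformers` config fields:

```python
# Verify a variant against the architecture table above.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("jhu-clsp/ettin-encoder-150m")
print(config.num_hidden_layers)    # 22 per the 150M column
print(config.hidden_size)          # 768
print(config.num_attention_heads)  # 12
```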

## Usage Examples

### Encoder: Masked Language Modeling

@@ -274,74 +334,613 @@ print(generated)

</details>

## Fine-tuning Examples

### Encoders

<details><summary>Click to see how to finetune this into a dense embedding model using Sentence Transformers</summary>

```python
import argparse

from datasets import load_dataset
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers


def main():
    # Parse the learning rate & model name
    parser = argparse.ArgumentParser()
    parser.add_argument("--lr", type=float, default=8e-5)
    parser.add_argument("--model_name", type=str, default="jhu-clsp/ettin-encoder-150m")
    args = parser.parse_args()
    lr = args.lr
    model_name = args.model_name
    model_shortname = model_name.split("/")[-1]

    # 1. Load a model to finetune
    model = SentenceTransformer(model_name)

    # 2. Load a dataset to finetune on
    dataset = load_dataset(
        "sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
        "triplet-hard",
        split="train",
    )
    dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"].select(range(1_250_000))
    eval_dataset = dataset_dict["test"]

    # 3. Define a loss function
    loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16)  # Increase mini_batch_size if you have enough VRAM

    run_name = f"{model_shortname}-DPR-{lr}"
    # 4. (Optional) Specify training arguments
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"output/{model_shortname}/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=512,
        per_device_eval_batch_size=512,
        warmup_ratio=0.05,
        fp16=False,  # Use FP16 only if your GPU can't run BF16
        bf16=True,  # Set to False if your GPU does not support BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # (Cached)MultipleNegativesRankingLoss benefits from no duplicates
        learning_rate=lr,
        # Optional tracking/debugging parameters:
        save_strategy="steps",
        save_steps=500,
        save_total_limit=2,
        logging_steps=500,
        run_name=run_name,  # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed
    )

    # 5. (Optional) Create an evaluator & evaluate the base model
    dev_evaluator = TripletEvaluator(
        anchors=eval_dataset["query"],
        positives=eval_dataset["positive"],
        negatives=eval_dataset["negative"],
        name="msmarco-co-condenser-dev",
    )
    dev_evaluator(model)

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
        evaluator=dev_evaluator,
    )
    trainer.train()

    # 7. (Optional) Evaluate the trained model on the evaluator after training
    dev_evaluator(model)

    # 8. Save the model
    model.save_pretrained(f"output/{model_shortname}/{run_name}/final")

    # 9. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=False)


if __name__ == "__main__":
    main()
```
</details>
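After training, the saved checkpoint works as a drop-in dense retriever. A hedged usage sketch; the path mirrors the script's output layout (with the default `--lr`, the run name renders as `ettin-encoder-150m-DPR-8e-05`):

```python
# Embed a query and candidate documents with the finetuned model, then rank
# by cosine similarity. The path follows the training script's output layout.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("output/ettin-encoder-150m/ettin-encoder-150m-DPR-8e-05/final")
query_emb = model.encode("What is the capital of France?")
doc_embs = model.encode(["Paris is the capital of France.", "Berlin is the capital of Germany."])
print(util.cos_sim(query_emb, doc_embs))
```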
 
 
<details><summary>Click to see how to finetune this into a multi-vector embedding model with PyLate</summary>

```python
from datasets import load_dataset
from pylate import losses, models, utils
from sentence_transformers import (
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)


def main():
    # Load the datasets required for knowledge distillation (train, queries, documents)
    train = load_dataset(
        path="lightonai/ms-marco-en-bge",
        name="train",
        split="train",
    )

    queries = load_dataset(
        path="lightonai/ms-marco-en-bge",
        name="queries",
        split="train",
    )

    documents = load_dataset(
        path="lightonai/ms-marco-en-bge",
        name="documents",
        split="train",
    )

    # Set the transformation to load the documents/queries texts using the corresponding ids on the fly
    train.set_transform(
        utils.KDProcessing(queries=queries, documents=documents).transform,
    )

    # Define the base model, training parameters, and output directory
    num_train_epochs = 1
    lr = 8e-5
    batch_size = 16
    accum_steps = 1
    model_name = "jhu-clsp/ettin-encoder-150m"
    model_shortname = model_name.split("/")[-1]

    # Set the run name for logging and output directory
    run_name = f"{model_shortname}-colbert-KD-{lr}"
    output_dir = f"output/{model_shortname}/{run_name}"

    # Initialize the ColBERT model from the base model
    model = models.ColBERT(model_name_or_path=model_name)

    # Configure the training arguments (e.g., epochs, batch size, learning rate)
    args = SentenceTransformerTrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=batch_size,
        fp16=False,  # Use FP16 only if your GPU can't run BF16
        bf16=True,  # Set to False if your GPU does not support BF16
        run_name=run_name,
        logging_steps=10,
        learning_rate=lr,
        gradient_accumulation_steps=accum_steps,
        warmup_ratio=0.05,
    )

    # Use the distillation loss function for training
    train_loss = losses.Distillation(model=model)

    # Initialize the trainer
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train,
        loss=train_loss,
        data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
    )

    # Start the training process
    trainer.train()

    model.save_pretrained(f"{output_dir}/final")


if __name__ == "__main__":
    main()
```
</details>
521
 
522
+ <details><summary>Click to see how to finetune this into a sparse retrieval model using Sentence Transformers</summary>
 
 
 
 
523
 
524
+ ```python
525
+ import logging
526
 
527
+ from datasets import load_dataset
 
 
 
 
528
 
529
+ from sentence_transformers import (
530
+ SparseEncoder,
531
+ SparseEncoderModelCardData,
532
+ SparseEncoderTrainer,
533
+ SparseEncoderTrainingArguments,
534
+ )
535
+ from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
536
+ from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
537
+ from sentence_transformers.training_args import BatchSamplers
538
+
539
+ logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
540
+
541
+ # 1. Load a model to finetune with 2. (Optional) model card data
542
+ model = SparseEncoder(
543
+ "jhu-clsp/ettin-encoder-150m",
544
+ model_card_data=SparseEncoderModelCardData(
545
+ language="en",
546
+ license="apache-2.0",
547
+ )
548
+ )
549
 
550
+ # 3. Load a dataset to finetune on
551
+ full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
552
+ dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
553
+ train_dataset = dataset_dict["train"]
554
+ eval_dataset = dataset_dict["test"]
555
+
556
+ # 4. Define a loss function
557
+ loss = SpladeLoss(
558
+ model=model,
559
+ loss=SparseMultipleNegativesRankingLoss(model=model),
560
+ query_regularizer_weight=5e-5,
561
+ document_regularizer_weight=3e-5,
562
+ )
563
 
564
+ # 5. (Optional) Specify training arguments
565
+ run_name = "splade-distilbert-base-uncased-nq"
566
+ args = SparseEncoderTrainingArguments(
567
+ # Required parameter:
568
+ output_dir=f"models/{run_name}",
569
+ # Optional training parameters:
570
+ num_train_epochs=1,
571
+ per_device_train_batch_size=16,
572
+ per_device_eval_batch_size=16,
573
+ learning_rate=2e-5,
574
+ warmup_ratio=0.1,
575
+ fp16=True, # Set to False if you get an error that your GPU can't run on FP16
576
+ bf16=False, # Set to True if you have a GPU that supports BF16
577
+ batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
578
+ # Optional tracking/debugging parameters:
579
+ eval_strategy="steps",
580
+ eval_steps=1000,
581
+ save_strategy="steps",
582
+ save_steps=1000,
583
+ save_total_limit=2,
584
+ logging_steps=200,
585
+ run_name=run_name, # Will be used in W&B if `wandb` is installed
586
+ )
587
 
588
+ # 6. (Optional) Create an evaluator & evaluate the base model
589
+ dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)
590
+
591
+ # 7. Create a trainer & train
592
+ trainer = SparseEncoderTrainer(
593
+ model=model,
594
+ args=args,
595
+ train_dataset=train_dataset,
596
+ eval_dataset=eval_dataset,
597
+ loss=loss,
598
+ evaluator=dev_evaluator,
599
+ )
600
+ trainer.train()
601
 
602
+ # 8. Evaluate the model performance again after training
603
+ dev_evaluator(model)
 
 
604
 
605
+ # 9. Save the trained model
606
+ model.save_pretrained(f"models/{run_name}/final")
607
 
608
+ # 10. (Optional) Push it to the Hugging Face Hub
609
+ model.push_to_hub(run_name)
610
+
611
+ ```
612
+ </details>
613
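
A hedged usage sketch for the trained sparse model; `encode_query`/`encode_document` and `similarity` follow the Sentence Transformers v5 `SparseEncoder` API, and the path follows the script above:

```python
# Score a query against documents with the trained sparse (SPLADE-style) model.
from sentence_transformers import SparseEncoder

model = SparseEncoder("models/splade-ettin-encoder-150m-nq/final")
query_emb = model.encode_query("who wrote the declaration of independence")
doc_embs = model.encode_document([
    "The Declaration of Independence was drafted by Thomas Jefferson.",
    "The Eiffel Tower is located in Paris.",
])
print(model.similarity(query_emb, doc_embs))  # higher = more relevant
```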

<details><summary>Click to see how to finetune this into a reranker model using Sentence Transformers</summary>

```python
import logging
import traceback

import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import (
    CrossEncoder,
    CrossEncoderModelCardData,
    CrossEncoderTrainer,
    CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import (
    CrossEncoderNanoBEIREvaluator,
    CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.evaluation import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives

# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)


def main():
    model_name = "jhu-clsp/ettin-encoder-150m"

    train_batch_size = 64
    num_epochs = 1
    num_hard_negatives = 5  # How many hard negatives should be mined for each question-answer pair

    # 1a. Load a model to finetune with 1b. (Optional) model card data
    model = CrossEncoder(
        model_name,
        model_card_data=CrossEncoderModelCardData(
            language="en",
            license="apache-2.0",
        ),
    )
    print("Model max length:", model.max_length)
    print("Model num labels:", model.num_labels)

    # 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
    logging.info("Read the gooaq training dataset")
    full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
    dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
    train_dataset = dataset_dict["train"]
    eval_dataset = dataset_dict["test"]
    logging.info(train_dataset)
    logging.info(eval_dataset)

    # 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
    embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
    hard_train_dataset = mine_hard_negatives(
        train_dataset,
        embedding_model,
        num_negatives=num_hard_negatives,  # How many negatives per question-answer pair
        margin=0,  # Similarity between query and negative samples should be x lower than query-positive similarity
        range_min=0,  # Skip the x most similar samples
        range_max=100,  # Consider only the x most similar samples
        sampling_strategy="top",  # Sample the top negatives from the range
        batch_size=4096,  # Use a batch size of 4096 for the embedding model
        output_format="labeled-pair",  # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
        use_faiss=True,
    )
    logging.info(hard_train_dataset)

    # 2c. (Optionally) Save the hard training dataset to disk
    # hard_train_dataset.save_to_disk("gooaq-hard-train")
    # Load again with:
    # hard_train_dataset = load_from_disk("gooaq-hard-train")

    # 3. Define our training loss.
    # pos_weight is recommended to be set as the ratio between positives to negatives, a.k.a. `num_hard_negatives`
    loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))

    # 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
    nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
        dataset_names=["msmarco", "nfcorpus", "nq"],
        batch_size=train_batch_size,
    )

    # 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
    # We include the positive answer in the list of negatives, so the evaluator can use the performance of the
    # embedding model as a baseline.
    hard_eval_dataset = mine_hard_negatives(
        eval_dataset,
        embedding_model,
        corpus=full_dataset["answer"],  # Use the full dataset as the corpus
        num_negatives=30,  # How many documents to rerank
        batch_size=4096,
        include_positives=True,
        output_format="n-tuple",
        use_faiss=True,
    )
    logging.info(hard_eval_dataset)
    reranking_evaluator = CrossEncoderRerankingEvaluator(
        samples=[
            {
                "query": sample["question"],
                "positive": [sample["answer"]],
                "documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
            }
            for sample in hard_eval_dataset
        ],
        batch_size=train_batch_size,
        name="gooaq-dev",
        # Realistic setting: only rerank the positives that the retriever found
        # Set to True to rerank *all* positives
        always_rerank_positives=False,
    )

    # 4c. Combine the evaluators & run the base model on them
    evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
    evaluator(model)

    # 5. Define the training arguments
    short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
    run_name = f"reranker-{short_model_name}-gooaq-bce"
    args = CrossEncoderTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=num_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=train_batch_size,
        learning_rate=2e-5,
        warmup_ratio=0.1,
        fp16=False,  # Use FP16 only if your GPU can't run BF16
        bf16=True,  # Set to False if your GPU does not support BF16
        dataloader_num_workers=4,
        load_best_model_at_end=True,
        metric_for_best_model="eval_gooaq-dev_ndcg@10",
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=200,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
        seed=12,
    )

    # 6. Create the trainer & start training
    trainer = CrossEncoderTrainer(
        model=model,
        args=args,
        train_dataset=hard_train_dataset,
        loss=loss,
        evaluator=evaluator,
    )
    trainer.train()

    # 7. Evaluate the final model, useful to include these in the model card
    evaluator(model)

    # 8. Save the final model
    final_output_dir = f"models/{run_name}/final"
    model.save_pretrained(final_output_dir)

    # 9. (Optional) save the model to the Hugging Face Hub!
    # It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
    try:
        model.push_to_hub(run_name)
    except Exception:
        logging.error(
            f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
            f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
            f"and saving it using `model.push_to_hub('{run_name}')`."
        )


if __name__ == "__main__":
    main()
```
</details>
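
For inference, the trained reranker scores (query, document) pairs directly via the standard `CrossEncoder.predict` API; the path follows the script above:

```python
# Score candidate answers for one query with the finetuned reranker.
from sentence_transformers.cross_encoder import CrossEncoder

model = CrossEncoder("models/reranker-ettin-encoder-150m-gooaq-bce/final")
query = "how many calories in an egg"
docs = [
    "A large egg contains roughly 72 calories.",
    "Eggs are laid by female animals of many different species.",
]
scores = model.predict([(query, doc) for doc in docs])
print(scores)  # higher score = more relevant
```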

### Decoders

<details>
<summary>Click to expand decoder training code</summary>

**Full training:**
```bash
python trl/scripts/sft.py \
    --model_name_or_path jhu-clsp/ettin-decoder-17m \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-5 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --eos_token '<|im_end|>' \
    --eval_strategy steps \
    --eval_steps 100 \
    --output_dir ettin-decoder-17m \
    --push_to_hub
```

**LoRA:**
```bash
python trl/scripts/sft.py \
    --model_name_or_path jhu-clsp/ettin-decoder-17m \
    --dataset_name trl-lib/Capybara \
    --learning_rate 2.0e-4 \
    --num_train_epochs 1 \
    --packing \
    --per_device_train_batch_size 2 \
    --gradient_accumulation_steps 8 \
    --gradient_checkpointing \
    --eos_token '<|im_end|>' \
    --eval_strategy steps \
    --eval_steps 100 \
    --use_peft \
    --lora_r 32 \
    --lora_alpha 16 \
    --output_dir ettin-decoder-17m \
    --push_to_hub
```

where `sft.py` is:
```python
import argparse

from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES

from trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    SFTTrainer,
    TrlParser,
    clone_chat_template,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)


def main(script_args, training_args, model_args):
    ################
    # Model init kwargs & Tokenizer
    ################
    quantization_config = get_quantization_config(model_args)
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=model_args.torch_dtype,
        use_cache=False if training_args.gradient_checkpointing else True,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    # Create model
    config = AutoConfig.from_pretrained(model_args.model_name_or_path)
    valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()

    if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
        from transformers import AutoModelForImageTextToText

        model_kwargs.pop("use_cache", None)  # Image models do not support cache
        model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)

    # Create tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
    )

    # Set default chat template if needed
    if tokenizer.chat_template is None:
        # TODO: source should be passed as an argument
        model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    ################
    # Training
    ################
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)


def make_parser(subparsers: argparse._SubParsersAction = None):
    dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
    if subparsers is not None:
        parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
    else:
        parser = TrlParser(dataclass_types)
    return parser


if __name__ == "__main__":
    parser = make_parser()
    # When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
    # To ensure that their parsing does not interfere with the script arguments, parse the arguments with
    # `return_remaining_strings=True`, then ignore the remaining strings.
    script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
    main(script_args, training_args, model_args)
```
</details>
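
Once fine-tuned, the decoder can be queried with its chat template (cloned from Qwen/Qwen3-0.6B in `sft.py`). A hedged sketch assuming the `--output_dir ettin-decoder-17m` used above:

```python
# Generate from the SFT'd decoder using the chat template set during training.
from transformers import AutoModelForCausalLM, AutoTokenizer

path = "ettin-decoder-17m"  # --output_dir from the training command above
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path)

messages = [{"role": "user", "content": "Write a haiku about paired encoders and decoders."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
outputs = model.generate(inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True))
```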

## Citation