RetJEPA
Pure predictive JEPA prototype for evidence retrieval.
The first implementation intentionally trains without a contrastive ranking loss:
query -> BERT query encoder -> Transformer predictor -> CLS pool -> z_pred
positive passage -> frozen BERT target/context encoder -> CLS pool -> z_target
loss = mse(normalize(z_pred), normalize(stopgrad(z_target)))
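Because both vectors are L2-normalized, this MSE is a rescaled cosine distance. The short derivation below assumes that mse averages over the d embedding dimensions (the reduction is not stated above) and writes sg for stopgrad:

```latex
\begin{aligned}
\hat{z}_{\text{pred}} &= \frac{z_{\text{pred}}}{\lVert z_{\text{pred}} \rVert_2},
\qquad
\hat{z}_{\text{target}} = \frac{\operatorname{sg}(z_{\text{target}})}{\lVert \operatorname{sg}(z_{\text{target}}) \rVert_2} \\
\mathcal{L} &= \frac{1}{d} \bigl\lVert \hat{z}_{\text{pred}} - \hat{z}_{\text{target}} \bigr\rVert_2^2
= \frac{1}{d} \bigl( \lVert \hat{z}_{\text{pred}} \rVert_2^2 + \lVert \hat{z}_{\text{target}} \rVert_2^2 - 2\, \hat{z}_{\text{pred}}^{\top} \hat{z}_{\text{target}} \bigr) \\
&= \frac{2}{d} \bigl( 1 - \hat{z}_{\text{pred}}^{\top} \hat{z}_{\text{target}} \bigr)
\end{aligned}
```

Under that assumption, minimizing the loss maximizes the cosine similarity between the predicted and target embeddings, and gradients reach only the query branch because the target is wrapped in stopgrad.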
UV Setup
Use Python 3.10-3.12 for the ML stack. UV is the intended package manager so the resolved package set stays stable across machines.
uv sync
The repository includes uv.lock; run uv sync on the GPU machine before training so it installs the same locked package versions.
Smoke Validation
Before launching real training, run the local tiny-model smoke test:
bash scripts/smoke_tiny.sh
This creates a tiny BERT model and tiny pair/corpus data under .tmp/retjepa_smoke, runs one pure-JEPA training step with gradient accumulation, writes a checkpoint, and evaluates frozen_bert vs retjepa. It does not download BERT or BGE-M3.
Data Format
Training/evaluation pairs are JSONL, CSV, or TSV with at least:
{"query_id": "q1", "passage_id": "p1", "query": "question text", "passage": "positive evidence passage"}
Evaluation can also use a separate retrieval corpus:
{"passage_id": "p1", "passage": "candidate passage text"}
This matters: pair files define positive supervision and qrels, while the corpus file defines the candidate retrieval pool. If no corpus is configured, evaluation falls back to the unique positive passages from the pair file, which is useful only for smoke tests.
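The split between pair files and the corpus can be sketched in a few lines of Python. This is an illustration only, not the repository's loader: the helper names load_pairs, build_qrels, and candidate_pool are hypothetical, and only the four field names come from the format above.

```python
import json
from collections import defaultdict

REQUIRED_FIELDS = ("query_id", "passage_id", "query", "passage")


def load_pairs(lines):
    """Parse JSONL pair records and check the four required fields."""
    pairs = []
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)
        missing = [f for f in REQUIRED_FIELDS if f not in record]
        if missing:
            raise ValueError(f"line {line_no}: missing fields {missing}")
        pairs.append(record)
    return pairs


def build_qrels(pairs):
    """Map each query_id to the set of its positive passage_ids."""
    qrels = defaultdict(set)
    for pair in pairs:
        qrels[pair["query_id"]].add(pair["passage_id"])
    return dict(qrels)


def candidate_pool(pairs, corpus=None):
    """Return passage_id -> text: the corpus if given, else the unique positives."""
    if corpus is not None:
        return {doc["passage_id"]: doc["passage"] for doc in corpus}
    return {pair["passage_id"]: pair["passage"] for pair in pairs}


sample = [
    '{"query_id": "q1", "passage_id": "p1", "query": "question one", "passage": "evidence one"}',
    '{"query_id": "q2", "passage_id": "p1", "query": "question two", "passage": "evidence one"}',
    '{"query_id": "q2", "passage_id": "p2", "query": "question two", "passage": "evidence two"}',
]
pairs = load_pairs(sample)
qrels = build_qrels(pairs)
pool = candidate_pool(pairs)  # no corpus: falls back to the unique positives
print(qrels)
print(sorted(pool))
```

With the fallback pool every candidate is a positive for some query, so retrieval is much easier than against a real corpus. That is why the fallback suits smoke tests only.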
Convert a local table into the expected pair/corpus files:
bash scripts/prepare_pairs.sh \
--input raw_pairs.jsonl \
--query-field query \
--passage-field passage \
--query-id-field query_id \
--passage-id-field passage_id \
--output-pairs data/nq_small/train_pairs.jsonl \
--output-corpus data/nq_small/corpus.jsonl
The same converter can read HuggingFace datasets with --dataset, --config-name, and --split when the dataset already exposes query and positive-passage fields.
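As a rough picture of what such a conversion involves, the sketch below renames source columns to the pair schema and collects unique passages for the corpus file. It is not the code behind scripts/prepare_pairs.sh: convert_rows and to_jsonl are hypothetical helpers, and casting IDs to strings and keeping the first text seen per passage_id are assumptions.

```python
import json


def convert_rows(rows, query_field, passage_field, query_id_field, passage_id_field):
    """Rename source columns to the pair schema and collect unique passages."""
    pairs, corpus = [], {}
    for row in rows:
        pair = {
            "query_id": str(row[query_id_field]),  # assumption: IDs stored as strings
            "passage_id": str(row[passage_id_field]),
            "query": row[query_field],
            "passage": row[passage_field],
        }
        pairs.append(pair)
        # Assumption: keep the first text seen for each passage_id.
        corpus.setdefault(pair["passage_id"], pair["passage"])
    corpus_records = [{"passage_id": pid, "passage": text} for pid, text in corpus.items()]
    return pairs, corpus_records


def to_jsonl(records):
    """Serialize records as one JSON object per line."""
    return "\n".join(json.dumps(record, ensure_ascii=False) for record in records)


# Toy input with non-standard column names.
raw = [
    {"qid": 1, "question": "who wrote hamlet", "doc": "d7",
     "text": "Hamlet is a tragedy by William Shakespeare."},
    {"qid": 2, "question": "hamlet author", "doc": "d7",
     "text": "Hamlet is a tragedy by William Shakespeare."},
]
pairs, corpus_records = convert_rows(raw, "question", "text", "qid", "doc")
print(to_jsonl(pairs))
print(to_jsonl(corpus_records))
```

Here two queries share one positive passage, so the pair output has two records while the corpus output has one.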
Full NQ Preparation
For the first full NQ run, use ir_datasets through the dedicated script:
uv sync
bash scripts/prepare_nq_full.sh --output-dir data/nq_full
This writes:
data/nq_full/train_pairs.jsonl
data/nq_full/dev_pairs.jsonl
data/nq_full/corpus.jsonl
By default this uses the DPR Wikipedia-100 NQ retrieval setup:
dpr-w100/natural-questions/train
dpr-w100/natural-questions/dev
dpr-w100
This is the standard DPR-style full-Wikipedia NQ retrieval setup: around 59K train queries, 8.9M train qrels, 6.5K dev queries, 980K dev qrels, and about 21M corpus passages according to the ir_datasets DPR Wiki100 documentation. The preparation script writes positive-only pair files by default using relevance >= 1, so DPR hard negatives are not used by the first pure-prediction objective.
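The relevance >= 1 filter can be illustrated as follows. This is a sketch under assumptions, not the preparation script: Qrel is a stand-in with a subset of the fields of a TREC-style qrel record, positive_pairs is a hypothetical helper, and the relevance values in the toy data are made up.

```python
from collections import namedtuple

# Stand-in qrel record (hypothetical; real qrels would come from ir_datasets).
Qrel = namedtuple("Qrel", ["query_id", "doc_id", "relevance"])


def positive_pairs(qrels, queries, passages, min_relevance=1):
    """Return pair records for qrels whose relevance meets the threshold."""
    pairs = []
    for qrel in qrels:
        if qrel.relevance < min_relevance:
            continue  # skip anything judged below the threshold
        pairs.append({
            "query_id": qrel.query_id,
            "passage_id": qrel.doc_id,
            "query": queries[qrel.query_id],
            "passage": passages[qrel.doc_id],
        })
    return pairs


queries = {"q1": "who wrote hamlet"}
passages = {
    "d1": "Hamlet is a tragedy by William Shakespeare.",
    "d2": "Macbeth is set in Scotland.",
    "d3": "Hamlet (1996 film) was directed by Kenneth Branagh.",
}
# Toy judgments: one positive and two entries below the threshold.
qrels = [Qrel("q1", "d1", 1), Qrel("q1", "d2", 0), Qrel("q1", "d3", -1)]
pairs = positive_pairs(qrels, queries, passages)
print(pairs)
```

Only the d1 judgment survives, so the resulting pair file contains no negatives of any kind.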
The official Google NQ split is also supported, but it is not the default because the Google-hosted files can fail in some environments:
bash scripts/prepare_nq_full.sh --preset official --output-dir data/nq_official
To prepare only the pair files without rewriting the passage corpus (about 21M passages in the default DPR setup):
bash scripts/prepare_nq_full.sh --output-dir data/nq_full --skip-corpus
Train
bash scripts/train_pure_jepa.sh configs/pure_jepa_bert_nq_small.yaml
The trainer supports:
- gradient accumulation via training.grad_accum_steps
- bf16/fp16 autocast via training.precision
- TF32 on CUDA via training.tf32
- gradient clipping via training.grad_clip
- scheduler warmup via training.scheduler and training.warmup_steps
- checkpoint cadence via training.checkpoint_every_steps
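To make the interaction between accumulation and warmup concrete, here is a small framework-free sketch. It is not the trainer's code: the helper names are hypothetical, a single device is assumed for the effective batch size, and a linear ramp is only one plausible warmup shape since the actual scheduler is set by training.scheduler.

```python
def is_optimizer_step(micro_step, grad_accum_steps):
    """True when accumulated gradients should be applied (0-based micro_step)."""
    return (micro_step + 1) % grad_accum_steps == 0


def effective_batch_size(micro_batch_size, grad_accum_steps):
    """Examples contributing to one optimizer update on a single device."""
    return micro_batch_size * grad_accum_steps


def linear_warmup_factor(optimizer_step, warmup_steps):
    """LR multiplier that ramps linearly to 1.0 over warmup_steps (0-based step)."""
    if warmup_steps <= 0:
        return 1.0
    return min(1.0, (optimizer_step + 1) / warmup_steps)


grad_accum_steps = 4
update_steps = [s for s in range(12) if is_optimizer_step(s, grad_accum_steps)]
print(update_steps)  # [3, 7, 11]
print(effective_batch_size(8, grad_accum_steps))  # 32
print([linear_warmup_factor(s, 4) for s in range(6)])  # [0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
```

In this sketch warmup is counted in optimizer updates rather than micro-batches. Check the trainer before assuming that training.warmup_steps and training.checkpoint_every_steps follow the same convention.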
Evaluate
bash scripts/eval_pure_jepa.sh \
configs/pure_jepa_bert_nq_small.yaml \
outputs/pure_jepa_bert_nq_small/checkpoint_last.pt
Evaluation compares:
- frozen_bert: frozen BERT dense baseline
- retjepa: predicted evidence embedding from the trained model
- bge_m3: BAAI/bge-m3, a strong external retrieval baseline
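All three systems are dense retrievers, so they can be scored the same way: embed queries and candidate passages, rank passages by similarity, and check the ranking against the qrels. The sketch below shows one such metric, top-k accuracy. Both the choice of metric and the helper names are assumptions, since the metrics reported by the evaluation script are not listed here.

```python
def dot(a, b):
    """Dot product of two equal-length vectors."""
    return sum(x * y for x, y in zip(a, b))


def top_k(query_vec, passage_vecs, k):
    """Rank passage ids by dot product with the query embedding, best first."""
    ranked = sorted(
        passage_vecs,
        key=lambda pid: dot(query_vec, passage_vecs[pid]),
        reverse=True,
    )
    return ranked[:k]


def top_k_accuracy(query_vecs, passage_vecs, qrels, k):
    """Fraction of queries with at least one positive passage in the top k."""
    hits = 0
    for qid, qvec in query_vecs.items():
        if set(top_k(qvec, passage_vecs, k)) & qrels[qid]:
            hits += 1
    return hits / len(query_vecs)


# Toy 2-d embeddings; a real run would use each encoder's output.
passage_vecs = {"p1": [1.0, 0.0], "p2": [0.0, 1.0], "p3": [0.7, 0.7]}
query_vecs = {"q1": [0.9, 0.1], "q2": [0.1, 0.9]}
qrels = {"q1": {"p1"}, "q2": {"p3"}}

print(top_k_accuracy(query_vecs, passage_vecs, qrels, k=1))  # 0.5
print(top_k_accuracy(query_vecs, passage_vecs, qrels, k=2))  # 1.0
```

For retjepa the query vector would be the predicted embedding z_pred, compared against passage embeddings from the frozen target encoder.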