# JEPA Text World Model
A next-word prediction world model based on Joint Embedding Predictive Architecture (JEPA). Instead of predicting tokens directly, this model predicts next concepts as latent embeddings, enabling internal deliberation and multi-branch reasoning before committing to output tokens.
## Key Innovation
Traditional language models commit to a single token at each step via a softmax over the vocabulary. The JEPA Text World Model instead:
- Predicts the next concept in a continuous latent space (multiple token hypotheses)
- Can unroll the world model for multiple steps in latent space ("thinking internally")
- Uses a separate Talker decoder to convert latent concepts back to tokens
- Can try multiple reasoning branches and select the best one
## Architecture
```
Input Tokens: [The][cat][sat][on]...
        │
┌──────────────────────────┐
│     Context Encoder      │ ←─ trainable θ
│     (Transformer)        │
└──────────────────────────┘
        │  latent vectors h_ctx[B,T,D]
┌──────────────────────────┐
│   World Model Predictor  │ ←─ trainable φ
│   (Narrow Transformer)   │    Predicts next latent embeddings
└──────────────────────────┘
        │
┌──────────────────────────┐
│   Scaled Cosine Dist     │ ←─ Loss: compare to EMA target
│   + SIGReg anti-collapse │
└──────────────────────────┘
        │
┌──────────────────────────┐
│     Talker Decoder       │ ←─ Latent → Token
│     (Cross-attention)    │
└──────────────────────────┘
        │
Output Tokens: [the][mat]...
```
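To make the data flow concrete, here is a minimal sketch of the three modules and how tensors move between them. The class and attribute names (`JEPATextWorldModelSketch`, `context_encoder`, `world_model`, `talker`, `lm_head`) are illustrative stand-ins rather than the actual API in `src/model.py`, and the sketch omits causal masking, QK-Norm, the normalization layers, and the EMA target encoder.

```python
import torch
import torch.nn as nn

class JEPATextWorldModelSketch(nn.Module):
    """Minimal sketch of the diagram above; see src/model.py for the real modules."""

    def __init__(self, vocab_size=50257, dim=768, enc_layers=12, pred_layers=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        # Context Encoder (trainable theta): tokens -> latent vectors h_ctx[B,T,D]
        self.context_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, nhead=12, batch_first=True), enc_layers)
        # World Model Predictor (trainable phi): predicts the next latent embeddings
        self.world_model = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, nhead=12, batch_first=True), pred_layers)
        # Talker Decoder: cross-attends from token queries to predicted latents
        self.talker = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(dim, nhead=12, batch_first=True), 2)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        self.lm_head.weight = self.embed.weight   # tied embeddings (used in Phase 1)

    def forward(self, tokens):                    # tokens: [B, T] token ids
        h_ctx = self.context_encoder(self.embed(tokens))   # latent vectors [B, T, D]
        h_pred = self.world_model(h_ctx)                    # predicted next latents
        decoded = self.talker(self.embed(tokens), memory=h_pred)
        return h_ctx, h_pred, self.lm_head(decoded)         # latents + token logits
```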
## Quick Start
### Installation
```bash
pip install torch transformers datasets trackio
```
### Training (3 phases)
```bash
# Phase 1: Pretrain (next-token)
# Phase 2: SST (JEPA latent prediction)
# Phase 3: Talker (latent → token)
python train.py \
  --encoder_dim 768 \
  --encoder_layers 12 \
  --predictor_layers 4 \
  --context_length 256 \
  --num_pretrain_epochs 3 \
  --num_sst_epochs 3 \
  --batch_size 32 \
  --output_dir ./jepa_output \
  --push_to_hub \
  --hub_model_id your-username/jepa-text-wm
```
### Inference
```bash
# Basic generation
python inference.py --model_path ./jepa_output --prompt "The cat " --mode basic

# Multi-branch thinking (4 branches)
python inference.py --model_path ./jepa_output --prompt "The cat " --mode multi_branch --num_branches 4

# Deep thinking (8 branches, 4 latent steps)
python inference.py --model_path ./jepa_output --prompt "The cat " --mode multi_branch --num_branches 8 --latent_steps 4

# Analyze latent chain
python inference.py --model_path ./jepa_output --prompt "The cat " --mode latent_analysis

# Interactive demo
python inference.py --model_path ./jepa_output --interactive
```
## Training Pipeline
### Phase 1: Pretraining
Standard next-token prediction with tied input/output embeddings. This warms the encoder up in token space so the switch to latent-space prediction in Phase 2 is smooth.
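A short sketch of the Phase 1 objective, reusing the hypothetical module names from the architecture sketch above:

```python
import torch.nn.functional as F

def pretrain_loss(model, tokens):
    """Phase 1 sketch: plain next-token cross-entropy through a tied output head."""
    h_ctx = model.context_encoder(model.embed(tokens[:, :-1]))  # [B, T-1, D]
    logits = h_ctx @ model.embed.weight.T                       # tied embeddings
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           tokens[:, 1:].reshape(-1))
```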
### Phase 2: Self-Supervised Training (SST)
Switches to the JEPA objective (a minimal sketch of the loss follows the list):
- Scaled Cosine Distance (k=4) between predicted and EMA-target latents
- SIGReg (λ=0.1) anti-collapse regularizer from LeWorldModel
- EMA target encoder (momentum follows a 0.98→1.0 cosine schedule)
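The exact form of the scaled cosine distance and of SIGReg live in `train.py` and `src/sigreg.py`; in the sketch below the distance is assumed to be k·(1 − cos) and the SIGReg term is left as an opaque callable.

```python
import torch
import torch.nn.functional as F

def scaled_cosine_distance(pred, target, k=4.0):
    # Assumed form: k * (1 - cosine similarity), averaged over all positions.
    return (k * (1.0 - F.cosine_similarity(pred, target, dim=-1))).mean()

def sst_loss(pred_latents, target_latents, sigreg_fn, lam=0.1):
    # JEPA objective: latent prediction loss + lambda * SIGReg anti-collapse term.
    return scaled_cosine_distance(pred_latents, target_latents) + lam * sigreg_fn(pred_latents)

@torch.no_grad()
def ema_update(target_encoder, online_encoder, momentum):
    # EMA target encoder; `momentum` is annealed 0.98 -> 1.0 on a cosine schedule.
    for p_t, p_o in zip(target_encoder.parameters(), online_encoder.parameters()):
        p_t.mul_(momentum).add_(p_o, alpha=1.0 - momentum)
```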
### Phase 3: Talker Training
Trains the Talker decoder to reconstruct tokens from the Reasoner's latent chains; the Reasoner is frozen during this phase (sketched below).
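A sketch of one Phase 3 step, again with the hypothetical module names from the architecture sketch; the Reasoner is taken here to be the context encoder plus the world-model predictor, which is an assumption about the split:

```python
import torch
import torch.nn.functional as F

def talker_step(model, tokens):
    # Freeze the Reasoner (context encoder + world-model predictor, assumed split)...
    for module in (model.context_encoder, model.world_model):
        module.requires_grad_(False)
    # ...and build its latent chain without gradients.
    with torch.no_grad():
        h_pred = model.world_model(model.context_encoder(model.embed(tokens[:, :-1])))
    # Train the Talker to map the latent chain back to tokens.
    decoded = model.talker(model.embed(tokens[:, :-1]), memory=h_pred)
    logits = model.lm_head(decoded)
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)),
                           tokens[:, 1:].reshape(-1))
```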
## Multi-Branch Reasoning
The model's key capability is trying multiple thought trajectories internally:
- Latent Chain Unrolling: h_{t+1} = P(h_{≤t}) → h_{t+2} = P(h_{≤t+1}) → ...
- Branch Exploration: Generate N different latent trajectories from the same context
- Branch Evaluation: Each branch produces candidate tokens via Talker
- Branch Selection: Pick the branch with lowest reconstruction error
This is like having the model "think" through multiple continuations before choosing one.
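A minimal sketch of the branch loop, reusing the hypothetical modules from the architecture sketch. The Gaussian perturbation used for branch diversity and the prompt-reconstruction score are stand-ins for whatever `inference.py` actually does:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def multi_branch_generate(model, prompt_tokens, num_branches=4, latent_steps=2):
    """Unroll several latent trajectories, decode each with the Talker,
    and keep the branch with the lowest reconstruction error."""
    h_ctx = model.context_encoder(model.embed(prompt_tokens))
    best_tokens, best_err = None, float("inf")
    for _ in range(num_branches):
        # 1. Latent chain unrolling: h_{t+1} = P(h_{<=t}), repeated latent_steps times.
        h = h_ctx
        for _ in range(latent_steps):
            h_next = model.world_model(h)[:, -1:, :]
            h_next = h_next + 0.01 * torch.randn_like(h_next)  # assumed branch noise
            h = torch.cat([h, h_next], dim=1)
        # 2. Decode the latent chain into candidate tokens with the Talker.
        logits = model.lm_head(model.talker(model.embed(prompt_tokens), memory=h))
        candidate = logits.argmax(dim=-1)
        # 3. Score the branch by how well it reconstructs the known prompt tokens.
        err = F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                              prompt_tokens[:, 1:].reshape(-1)).item()
        # 4. Keep the lowest-error branch.
        if err < best_err:
            best_tokens, best_err = candidate, err
    return best_tokens
```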
## Configuration
| Parameter | Default | Description |
|---|---|---|
| `encoder_dim` | 768 | Encoder hidden dimension |
| `encoder_layers` | 12 | Encoder transformer layers |
| `predictor_layers` | 4 | Predictor (narrow transformer) layers |
| `pred_dim` | 768 | Predictor hidden dimension |
| `context_length` | 256 | Max sequence length |
| `context_ratio` | 0.7 | Portion used as context (rest is target) |
| `ema_momentum` | 0.98 | Target encoder momentum |
| `cosine_scale_k` | 4.0 | Scaled cosine distance factor |
| `sigreg_lambda` | 0.1 | Anti-collapse weight |
| `sigreg_M` | 1024 | Random projection directions |
## Papers & References
- JEPA-Reasoner (arXiv:2512.19171) – Decoupling latent reasoning from token generation. The core text-JEPA architecture.
- LeWorldModel (arXiv:2603.19312) – Stable end-to-end JEPA with SIGReg. The anti-collapse mechanism.
- COCONUT (arXiv:2412.06769) – Chain of Continuous Thought. Latent reasoning in standard transformers.
- V-JEPA (arXiv:2404.08471) – Original JEPA for video. The prediction-in-latent-space paradigm.
## Files
```
jepa-text-world-model/
├── src/
│   ├── model.py          # Core JEPA architecture
│   ├── attention.py      # QK-Norm + causal masking
│   ├── normalization.py  # RMSNorm, L2Norm, HybridNorm
│   ├── sigreg.py         # SIGReg anti-collapse regularizer
│   └── __init__.py
├── train.py              # 3-phase training script
├── inference.py          # Inference, multi-branch, demo
├── architecture.md       # Detailed architecture docs
└── README.md
```