JEPA Text World Model

A next-word prediction world model based on Joint Embedding Predictive Architecture (JEPA). Instead of predicting tokens directly, this model predicts next concepts as latent embeddings, enabling internal deliberation and multi-branch reasoning before committing to output tokens.

Key Innovation

Traditional language models commit to a single token at each step via softmax. The JEPA Text World Model instead:

  1. Predicts the next concept in a continuous latent space (multiple token hypotheses)
  2. Can unroll the world model for multiple steps in latent space ("thinking internally")
  3. Uses a separate Talker decoder to convert latent concepts back to tokens
  4. Can try multiple reasoning branches and select the best one
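
A minimal sketch of the difference in a single generation step, using stand-in names (encoder, predictor, talker) for the classes in src/model.py; the call signatures below are illustrative assumptions, not the actual API:

import torch

# Stand-in signatures (real classes live in src/model.py):
#   encoder(tokens)  -> latents [B, T, D]
#   predictor(h)     -> predicted next latents [B, T, D]
#   talker(h)        -> token logits [B, T, V]

def step_softmax_lm(lm, tokens):
    # Traditional LM: commits to one discrete token per step.
    logits = lm(tokens)[:, -1, :]              # [B, V]
    return logits.argmax(dim=-1)

def step_jepa(encoder, predictor, talker, tokens, latent_steps=2):
    # JEPA: predicts next concepts and can keep "thinking" in latent
    # space for several steps before any token is emitted.
    h = encoder(tokens)                        # [B, T, D]
    for _ in range(latent_steps):
        h_next = predictor(h)[:, -1:, :]       # next concept [B, 1, D]
        h = torch.cat([h, h_next], dim=1)      # unroll internally
    return talker(h)[:, -1, :].argmax(dim=-1)  # commit to a token at the end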

Architecture

Input Tokens:  [The][cat][sat][on]...
     ↓
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Context Encoder     │← trainable ΞΈ
β”‚  (Transformer)       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
       ↓ latent vectors h_ctx[B,T,D]
     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
     β”‚  World Model Predictor   │← trainable Ο†
     β”‚  (Narrow Transformer)    β”‚     Predicts next latent embeddings
     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
              ↓
     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
     β”‚  Scaled Cosine Dist      │← Loss: compare to EMA target
     β”‚  + SIGReg anti-collapse  β”‚
     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
              ↓
     β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
     β”‚  Talker Decoder      │← Latent β†’ Token
     β”‚  (Cross-attention)   β”‚
     β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
              ↓
Output Tokens: [the][mat]...
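
A skeletal PyTorch wiring of the diagram, with layer counts taken from the Configuration table below. The class name, head count, and 2-layer Talker are illustrative assumptions, not the actual API in src/model.py; causal masking and QK-Norm are omitted for brevity:

import torch.nn as nn

class JEPATextWorldModelSketch(nn.Module):
    def __init__(self, vocab_size, dim=768, enc_layers=12, pred_layers=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dim)
        self.context_encoder = nn.TransformerEncoder(        # trainable ΞΈ
            nn.TransformerEncoderLayer(dim, nhead=12, batch_first=True),
            enc_layers)
        self.predictor = nn.TransformerEncoder(              # trainable Ο†
            nn.TransformerEncoderLayer(dim, nhead=12, batch_first=True),
            pred_layers)
        self.talker = nn.TransformerDecoder(                 # latent -> token
            nn.TransformerDecoderLayer(dim, nhead=12, batch_first=True), 2)
        self.lm_head = nn.Linear(dim, vocab_size)

    def forward(self, tokens):
        h_ctx = self.context_encoder(self.embed(tokens))     # [B, T, D]
        h_pred = self.predictor(h_ctx)                       # next-latent predictions
        logits = self.lm_head(self.talker(h_pred, h_ctx))    # cross-attends to context
        return h_pred, logits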

Quick Start

Installation

pip install torch transformers datasets trackio

Training (3 phases)

# Phase 1: Pretrain (next-token)
# Phase 2: SST (JEPA latent prediction)
# Phase 3: Talker (latent β†’ token)
python train.py \
  --encoder_dim 768 \
  --encoder_layers 12 \
  --predictor_layers 4 \
  --context_length 256 \
  --num_pretrain_epochs 3 \
  --num_sst_epochs 3 \
  --batch_size 32 \
  --output_dir ./jepa_output \
  --push_to_hub \
  --hub_model_id your-username/jepa-text-wm

Inference

# Basic generation
python inference.py --model_path ./jepa_output --prompt "The cat "

# Multi-branch thinking (4 branches)
python inference.py --model_path ./jepa_output --prompt "The cat " --mode multi_branch --num_branches 4

# Deep thinking (8 branches, 4 latent steps)
python inference.py --model_path ./jepa_output --prompt "The cat " --mode multi_branch --num_branches 8 --latent_steps 4

# Analyze latent chain
python inference.py --model_path ./jepa_output --prompt "The cat " --mode latent_analysis

# Interactive demo
python inference.py --model_path ./jepa_output --interactive

Training Pipeline

Phase 1: Pretraining

Standard next-token prediction with tied embeddings, giving the encoder a smooth transition from token space to latent space.
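
A minimal sketch of the phase-1 objective, reusing the illustrative attribute names from the wiring sketch above. "Tied embeddings" means the output projection reuses the input embedding matrix:

import torch.nn.functional as F

def pretrain_loss(model, tokens):
    h = model.context_encoder(model.embed(tokens))   # [B, T, D]
    logits = h @ model.embed.weight.T                # tied output projection
    return F.cross_entropy(                          # predict token t+1 from t
        logits[:, :-1].reshape(-1, logits.size(-1)),
        tokens[:, 1:].reshape(-1))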

Phase 2: Self-Supervised Training (SST)

Switches to JEPA objective:

  • Scaled Cosine Distance (k=4) between predicted and EMA target latents
  • SIGReg (Ξ»=0.1) anti-collapse regularizer from LeWorldModel
  • EMA target encoder (momentum 0.98β†’1.0 cosine schedule)

Phase 3: Talker Training

Trains the decoder to reconstruct tokens from the latent chains produced by the Reasoner (the context encoder and predictor). The Reasoner is frozen during this phase.
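
A sketch of the phase-3 setup, again with illustrative attribute names:

import torch
import torch.nn.functional as F

def freeze_reasoner(model):
    # Only the Talker (and its output head) stays trainable in phase 3.
    for module in (model.context_encoder, model.predictor):
        for p in module.parameters():
            p.requires_grad_(False)

def talker_loss(model, tokens):
    with torch.no_grad():                                 # frozen latent chain
        h_ctx = model.context_encoder(model.embed(tokens))
        h_pred = model.predictor(h_ctx)
    logits = model.lm_head(model.talker(h_pred, h_ctx))   # trainable decoder
    return F.cross_entropy(logits[:, :-1].reshape(-1, logits.size(-1)),
                           tokens[:, 1:].reshape(-1))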

Multi-Branch Reasoning

The model's key capability is to try multiple thought trajectories internally:

  1. Latent Chain Unrolling: h_{t+1} = P(h_{≀t}) β†’ h_{t+2} = P(h_{≀t+1}) β†’ ...
  2. Branch Exploration: Generate N different latent trajectories from the same context
  3. Branch Evaluation: Each branch produces candidate tokens via Talker
  4. Branch Selection: Pick the branch with lowest reconstruction error

This is like having the model "think" through multiple continuations before choosing one.
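
A sketch of steps 1-4. How inference.py actually diversifies branches and scores reconstruction error is not specified here; this version perturbs latents with Gaussian noise and scores a branch by how closely re-encoding its candidate token matches the branch's final latent (both assumptions):

import torch

def multi_branch_step(encoder, predictor, talker, tokens,
                      num_branches=4, latent_steps=2, noise=0.02):
    h_ctx = encoder(tokens)                                     # shared context
    best_tok, best_err = None, float("inf")
    for _ in range(num_branches):
        h = h_ctx
        for _ in range(latent_steps):                           # 1. unroll chain
            h_next = predictor(h)[:, -1:, :]
            h_next = h_next + noise * torch.randn_like(h_next)  # 2. explore
            h = torch.cat([h, h_next], dim=1)
        cand = talker(h)[:, -1, :].argmax(dim=-1, keepdim=True) # 3. decode
        h_re = encoder(torch.cat([tokens, cand], dim=1))[:, -1:, :]
        err = (h_re - h[:, -1:, :]).pow(2).mean().item()        # 4. select lowest
        if err < best_err:
            best_tok, best_err = cand, err
    return best_tok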

Configuration

Parameter          Default  Description
encoder_dim        768      Encoder hidden dimension
encoder_layers     12       Encoder transformer layers
predictor_layers   4        Predictor (narrow transformer) layers
pred_dim           768      Predictor hidden dimension
context_length     256      Max sequence length
context_ratio      0.7      Portion used as context (rest is target)
ema_momentum       0.98     Target encoder momentum
cosine_scale_k     4.0      Scaled cosine distance factor
sigreg_lambda      0.1      Anti-collapse weight
sigreg_M           1024     Random projection directions
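
For reference, a dataclass mirroring these defaults; train.py exposes them as CLI flags, and this class is illustrative rather than part of the repo:

from dataclasses import dataclass

@dataclass
class JEPAConfig:
    encoder_dim: int = 768
    encoder_layers: int = 12
    predictor_layers: int = 4
    pred_dim: int = 768
    context_length: int = 256
    context_ratio: float = 0.7   # context vs. target split
    ema_momentum: float = 0.98   # start of the 0.98 -> 1.0 cosine schedule
    cosine_scale_k: float = 4.0
    sigreg_lambda: float = 0.1
    sigreg_M: int = 1024         # random projection directions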

Papers & References

  • JEPA-Reasoner (arxiv:2512.19171) β€” Decoupling latent reasoning from token generation. The core text-JEPA architecture.
  • LeWorldModel (arxiv:2603.19312) β€” Stable end-to-end JEPA with SIGReg. The anti-collapse mechanism.
  • COCONUT (arxiv:2412.06769) β€” Chain of Continuous Thought. Latent reasoning in standard transformers.
  • V-JEPA (arxiv:2404.08471) β€” Original JEPA for video. The prediction-in-latent-space paradigm.

Files

jepa-text-world-model/
β”œβ”€β”€ src/
β”‚   β”œβ”€β”€ model.py          # Core JEPA architecture
β”‚   β”œβ”€β”€ attention.py       # QK-Norm + causal masking
β”‚   β”œβ”€β”€ normalization.py   # RMSNorm, L2Norm, HybridNorm
β”‚   β”œβ”€β”€ sigreg.py          # SIGReg anti-collapse regularizer
β”‚   └── __init__.py
β”œβ”€β”€ train.py               # 3-phase training script
β”œβ”€β”€ inference.py           # Inference, multi-branch, demo
β”œβ”€β”€ architecture.md        # Detailed architecture docs
└── README.md