---
license: mit
---

# JEPA-Style LLM Prototypes

Making decoder-only transformers predict state consequences instead of tokens.

## What's This?

Three approaches to converting a standard LLM into a world model that predicts "what happens next" given a state and an action, like JEPA but for language models.

## Files

| File | Description | GPU Time |
|------|-------------|----------|
| `jepa_llm_prototypes.ipynb` | All three options in one notebook; best for comparing | ~30 min |
| `jepa_option1_sentence_encoder.ipynb` | Simplest approach, using pre-trained sentence embeddings | ~10 min |
| `jepa_option2_llm_hidden_states.ipynb` | Uses GPT-2 hidden states as the state space | ~15 min |

## Quick Start

1. Open any notebook in Google Colab
2. Set the runtime to GPU (Runtime → Change runtime type → H100)
3. Run all cells
4. Watch the model learn to predict state transitions

## The Core Idea

```
Normal LLM:   tokens → transformer → next token
JEPA-style:   (state, action) → transformer → next-state embedding
```

Instead of predicting words, the model predicts what the world looks like after an action.
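The mapping above can be sketched as a small predictor head over fixed-size embeddings. Everything here (the 384-dim embedding size, the layer widths, the cosine objective) is illustrative, not the notebooks' exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
EMB_DIM = 384  # illustrative embedding size

class StatePredictor(nn.Module):
    """Maps (state embedding, action embedding) -> predicted next-state embedding."""
    def __init__(self, dim=EMB_DIM, hidden=512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, dim),
        )

    def forward(self, state_emb, action_emb):
        return self.net(torch.cat([state_emb, action_emb], dim=-1))

predictor = StatePredictor()
state = torch.randn(8, EMB_DIM)    # batch of 8 state embeddings
action = torch.randn(8, EMB_DIM)   # batch of 8 action embeddings
pred_next = predictor(state, action)

# Train against the true next-state embedding with a cosine-similarity loss.
target = torch.randn(8, EMB_DIM)
loss = 1 - F.cosine_similarity(pred_next, target, dim=-1).mean()
```

The key design choice is that the loss lives in embedding space: the model is never asked to reproduce tokens, only to land near the right region of state space.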

## Three Approaches

### Option 1: Sentence Encoder (Simplest)

- Uses `all-MiniLM-L6-v2` for embeddings
- Trains only a small predictor network
- Best for: quick testing, limited GPU
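A minimal sketch of this recipe: the pre-trained encoder stays frozen (its 384-dim outputs are simulated here with random vectors, matching `all-MiniLM-L6-v2`'s output size), and only the small predictor is trained:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
DIM = 384  # all-MiniLM-L6-v2 output dimension

# Only this predictor has trainable parameters; the encoder is frozen.
predictor = nn.Sequential(
    nn.Linear(2 * DIM, 512), nn.GELU(), nn.Linear(512, DIM)
)
opt = torch.optim.AdamW(predictor.parameters(), lr=1e-3)

# Random vectors stand in for frozen-encoder embeddings of
# (state, action, next_state) text triples.
state = torch.randn(64, DIM)
action = torch.randn(64, DIM)
next_state = torch.randn(64, DIM)

losses = []
for step in range(50):
    pred = predictor(torch.cat([state, action], dim=-1))
    loss = 1 - F.cosine_similarity(pred, next_state, dim=-1).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

Because nothing upstream of the predictor needs gradients, this fits comfortably on a small GPU or even CPU.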

### Option 2: LLM Hidden States (Medium)

- Uses GPT-2's internal representations
- Trains projection + predictor heads
- Best for: better accuracy, still fast
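One way the projection + predictor heads could look, sketched with simulated GPT-2 hidden states (768-dim per token); the masked mean-pooling and the 256-dim state space are assumptions for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
HIDDEN, STATE_DIM = 768, 256  # GPT-2 hidden size; illustrative state size

project = nn.Linear(HIDDEN, STATE_DIM)       # projection head
predict = nn.Sequential(                      # predictor head
    nn.Linear(2 * STATE_DIM, 512), nn.GELU(), nn.Linear(512, STATE_DIM)
)

def encode(hidden_states, attention_mask):
    """Masked mean-pool token hidden states, then project into state space."""
    mask = attention_mask.unsqueeze(-1).float()
    pooled = (hidden_states * mask).sum(1) / mask.sum(1).clamp(min=1)
    return project(pooled)

# Simulated GPT-2 last-layer outputs: batch of 4, sequence length 16.
h_state = torch.randn(4, 16, HIDDEN)
h_action = torch.randn(4, 16, HIDDEN)
mask = torch.ones(4, 16)

z_state = encode(h_state, mask)
z_action = encode(h_action, mask)
z_next_pred = predict(torch.cat([z_state, z_action], dim=-1))
```

In the real notebook the hidden states would come from a frozen GPT-2 forward pass with `output_hidden_states=True`; only the two heads train.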

### Option 3: Autoencoder (Most Powerful)

- Learns domain-specific state embeddings
- Trains encoder + decoder + predictor
- Best for: production, domain adaptation
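A rough sketch of the joint objective: reconstruction keeps the learned state space faithful to the input, while prediction makes it dynamics-aware. Input text features are simulated as random vectors, and the dimensions and MSE losses are assumptions, not the notebook's exact setup:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
FEAT_DIM, STATE_DIM = 512, 128  # illustrative sizes

encoder = nn.Sequential(nn.Linear(FEAT_DIM, 256), nn.GELU(), nn.Linear(256, STATE_DIM))
decoder = nn.Sequential(nn.Linear(STATE_DIM, 256), nn.GELU(), nn.Linear(256, FEAT_DIM))
predictor = nn.Sequential(nn.Linear(2 * STATE_DIM, 256), nn.GELU(), nn.Linear(256, STATE_DIM))

x_state = torch.randn(8, FEAT_DIM)   # stand-ins for featurized state text
x_action = torch.randn(8, FEAT_DIM)
x_next = torch.randn(8, FEAT_DIM)

z_state, z_action = encoder(x_state), encoder(x_action)
z_next = encoder(x_next)

# Autoencoding term: the state embedding must be decodable back to its input.
recon_loss = F.mse_loss(decoder(z_state), x_state)
# Prediction term: (state, action) must land near the encoded next state.
# detach() stops the target embedding from collapsing toward the prediction.
pred_loss = F.mse_loss(predictor(torch.cat([z_state, z_action], dim=-1)), z_next.detach())
loss = recon_loss + pred_loss
```

The `detach()` on the target is one common guard against representation collapse; JEPA-style setups often use a stop-gradient or momentum encoder for the same reason.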

## Example

```python
# Input
state = "Document is in draft status with 2 sections"
action = "User submits for review"

# Model predicts
next_state = "Document is pending review"  # via embedding similarity
```
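The "via embedding similarity" step can be sketched as a nearest-candidate lookup: embed a set of candidate next-state descriptions and pick the one closest to the predicted embedding. The embeddings below are random stand-ins for real encoder outputs:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

candidates = [
    "Document is pending review",
    "Document is in draft status",
    "Document is archived",
]
# Stand-in encoder outputs; a prediction lying near candidate 0 is simulated
# by adding small noise to its embedding.
cand_emb = torch.randn(len(candidates), 384)
pred_emb = cand_emb[0] + 0.1 * torch.randn(384)

sims = F.cosine_similarity(pred_emb.unsqueeze(0), cand_emb, dim=-1)
next_state = candidates[sims.argmax().item()]
```

This keeps decoding cheap: no generation, just a similarity search over however many candidate states the application defines.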

## Requirements

- Python 3.8+
- PyTorch
- Transformers
- Sentence-Transformers (Option 1)
- GPU recommended (runs on CPU, but slowly)

All dependencies install automatically in the notebooks.

## Next Steps

- Swap the synthetic data for real enterprise workflow logs
- Scale up the base model (Llama, Qwen, Palmyra)
- Add multi-step trajectory prediction
- Integrate with planning/search algorithms
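Multi-step trajectory prediction could start as a naive rollout that feeds each predicted state back in as the next input; this is an assumption about one possible design, not code that exists in the notebooks:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
DIM = 384  # illustrative embedding size

# Hypothetical single-step predictor, as in the earlier sketches.
predictor = nn.Sequential(
    nn.Linear(2 * DIM, 512), nn.GELU(), nn.Linear(512, DIM)
)

state = torch.randn(DIM)
actions = [torch.randn(DIM) for _ in range(3)]  # a 3-action plan

# Roll the predictor forward, reusing each predicted state as input.
trajectory = []
for action in actions:
    state = predictor(torch.cat([state, action], dim=-1))
    trajectory.append(state)
```

Errors compound across steps, so a serious version would train on multi-step trajectories (or score rollouts inside a planner) rather than chaining a one-step model.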

Experimental code: have fun breaking it.

Coauthors: Writer Agent & OpenCode