---
license: mit
---

# JEPA-Style LLM Prototypes

Making decoder-only transformers predict state consequences instead of tokens.

## What's This?

Three approaches to convert a standard LLM into a world model that predicts "what happens next" given a state and action β€” like JEPA but for language models.

## Files

| File | Description | GPU Time |
|------|-------------|----------|
| `jepa_llm_prototypes.ipynb` | **All three options in one notebook** β€” best for comparing | ~30 min |
| `jepa_option1_sentence_encoder.ipynb` | Simplest approach using pre-trained sentence embeddings | ~10 min |
| `jepa_option2_llm_hidden_states.ipynb` | Uses GPT-2 hidden states as state space | ~15 min |

## Quick Start

1. Open any notebook in [Google Colab](https://colab.research.google.com/)
2. Set the runtime to **GPU** (Runtime β†’ Change runtime type β†’ select a GPU, e.g. H100)
3. Run all cells
4. Watch the model learn to predict state transitions

## The Core Idea

```
Normal LLM:     tokens β†’ transformer β†’ next token
JEPA-style:     (state, action) β†’ transformer β†’ next state embedding
```

Instead of predicting words, the model predicts what the world looks like after an action.
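As a concrete illustration, here is a minimal sketch of such a predictor in PyTorch. All names and dimensions (`StatePredictor`, `dim=384`, the MLP shape) are illustrative assumptions, not the notebooks' actual architecture:

```python
import torch
import torch.nn as nn

class StatePredictor(nn.Module):
    """Maps (state embedding, action embedding) -> predicted next-state embedding."""

    def __init__(self, dim: int = 384, hidden: int = 512):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * dim, hidden),  # concatenated state + action
            nn.GELU(),
            nn.Linear(hidden, dim),      # project back into the embedding space
        )

    def forward(self, state_emb: torch.Tensor, action_emb: torch.Tensor) -> torch.Tensor:
        return self.net(torch.cat([state_emb, action_emb], dim=-1))

predictor = StatePredictor()
state = torch.randn(8, 384)   # batch of state embeddings
action = torch.randn(8, 384)  # batch of action embeddings
pred_next = predictor(state, action)
print(pred_next.shape)  # torch.Size([8, 384])
```

The key design point: the loss lives in embedding space, so the model never has to reconstruct tokens, only the representation of the next state.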

## Three Approaches

**Option 1: Sentence Encoder** (Simplest)
- Uses `all-MiniLM-L6-v2` for embeddings
- Trains only a small predictor network
- Best for: quick testing, limited GPU

**Option 2: LLM Hidden States** (Medium)
- Uses GPT-2's internal representations
- Trains projection + predictor heads
- Best for: better accuracy, still fast

**Option 3: Autoencoder** (Most Powerful)
- Learns domain-specific state embeddings
- Trains encoder + decoder + predictor
- Best for: production, domain adaptation
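All three options share the same training objective: push the predicted embedding toward the encoder's embedding of the actual next state. A hedged sketch of one training step, assuming a frozen encoder supplies the embeddings (the random tensors below stand in for them) and using cosine distance as the loss, which is one common choice, not necessarily the notebooks' exact objective:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

dim = 384  # illustrative embedding size (matches all-MiniLM-L6-v2)
predictor = nn.Sequential(nn.Linear(2 * dim, 512), nn.GELU(), nn.Linear(512, dim))
opt = torch.optim.AdamW(predictor.parameters(), lr=1e-3)

def train_step(state_emb, action_emb, next_state_emb):
    pred = predictor(torch.cat([state_emb, action_emb], dim=-1))
    # Cosine distance between prediction and the true next-state embedding
    loss = 1.0 - F.cosine_similarity(pred, next_state_emb, dim=-1).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.item()

for _ in range(5):
    s, a, ns = (torch.randn(16, dim) for _ in range(3))
    loss = train_step(s, a, ns)
```

Options 1β†’3 differ mainly in what else is trained alongside this predictor: nothing (1), projection heads (2), or a full encoder/decoder (3).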

## Example

```python
# Input
state = "Document is in draft status with 2 sections"
action = "User submits for review"

# Model predicts
next_state = "Document is pending review"  # via embedding similarity
```
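The "via embedding similarity" step above can be sketched as a nearest-neighbor lookup over candidate next-state descriptions. The candidate pool and the random stand-in embeddings below are purely illustrative; in the notebooks the embeddings would come from the chosen encoder:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

candidates = [
    "Document is pending review",
    "Document is published",
    "Document is archived",
]
# Random stand-ins for encoder embeddings, normalized to unit length
cand_embs = F.normalize(torch.randn(len(candidates), 384), dim=-1)

# Simulate a prediction that lands near the first candidate
pred_emb = F.normalize(cand_embs[0] + 0.1 * torch.randn(384), dim=-1)

scores = cand_embs @ pred_emb  # cosine similarity (all vectors are unit-norm)
best = candidates[scores.argmax().item()]
print(best)  # "Document is pending review" (nearest neighbor by construction)
```

This keeps decoding cheap: no generation is needed, only a dot product against a fixed set of candidate states.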

## Requirements

- Python 3.8+
- PyTorch
- Transformers
- Sentence-Transformers (Option 1)
- GPU recommended (runs on CPU, but slowly)

All dependencies install automatically in the notebooks.

## Next Steps

- Swap synthetic data for real enterprise workflow logs
- Scale up base model (Llama, Qwen, Palmyra)
- Add multi-step trajectory prediction
- Integrate with planning/search algorithms

---

*Experimental code β€” have fun breaking it.*

**Coauthors:** Writer Agent & OpenCode