wassemgtk committed
Commit bbba84f (verified) · 1 Parent(s): c3dc12a
jepa_llm_prototypes.ipynb ADDED
@@ -0,0 +1,1258 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🧠 JEPA-Style LLM Prototypes\n",
8
+ "\n",
9
+ "## Making Decoder-Only LLMs Predict State Consequences Instead of Tokens\n",
10
+ "\n",
11
+ "This notebook implements three approaches to convert a standard LLM into a JEPA-style world model:\n",
12
+ "\n",
13
+ "1. **Option 1:** Sentence Encoder Approach (Simplest)\n",
14
+ "2. **Option 2:** LLM Hidden States Approach (Medium)\n",
15
+ "3. **Option 3:** Full Autoencoder Approach (Most Powerful)\n",
16
+ "\n",
17
+ "---"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "markdown",
22
+ "metadata": {},
23
+ "source": [
24
+ "## 📦 Setup & Installation"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": null,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "# Install required packages\n",
34
+ "!pip install -q transformers accelerate bitsandbytes sentence-transformers datasets torch matplotlib tqdm"
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": null,
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "import torch\n",
44
+ "import torch.nn as nn\n",
45
+ "import torch.nn.functional as F\n",
46
+ "from torch.utils.data import Dataset, DataLoader\n",
47
+ "from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM\n",
48
+ "from sentence_transformers import SentenceTransformer\n",
49
+ "import numpy as np\n",
50
+ "import matplotlib.pyplot as plt\n",
51
+ "from tqdm.auto import tqdm\n",
52
+ "import random\n",
53
+ "\n",
54
+ "# Check GPU\n",
55
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
56
+ "print(f\"Using device: {device}\")\n",
57
+ "if torch.cuda.is_available():\n",
58
+ " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
59
+ " print(f\"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "markdown",
64
+ "metadata": {},
65
+ "source": [
66
+ "## 📊 Create Synthetic Dataset\n",
67
+ "\n",
68
+ "We'll create a simple \"enterprise workflow\" dataset with:\n",
69
+ "- **States:** Document/workflow status descriptions\n",
70
+ "- **Actions:** User actions\n",
71
+ "- **Next States:** Resulting state after action\n",
72
+ "\n",
73
+ "This simulates learning the \"physics\" of your enterprise domain."
74
+ ]
75
+ },
76
+ {
77
+ "cell_type": "code",
78
+ "execution_count": null,
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "class EnterpriseWorkflowDataset(Dataset):\n",
83
+ " \"\"\"\n",
84
+ " Synthetic dataset simulating enterprise workflow state transitions.\n",
85
+ " \n",
86
+ " Each sample is a (state, action, next_state) triplet.\n",
87
+ " The model learns to predict next_state given state + action.\n",
88
+ " \"\"\"\n",
89
+ " \n",
90
+ " def __init__(self, num_samples=1000, seed=42):\n",
91
+ " random.seed(seed)\n",
92
+ " self.samples = self._generate_samples(num_samples)\n",
93
+ " \n",
94
+ " def _generate_samples(self, num_samples):\n",
95
+ " samples = []\n",
96
+ " \n",
97
+ " # Document workflow transitions\n",
98
+ " doc_transitions = [\n",
99
+ " # (current_state, action, next_state)\n",
100
+ " (\"Document is in draft status with 0 sections\", \"User creates new section\", \"Document is in draft status with 1 section\"),\n",
101
+ " (\"Document is in draft status with 1 section\", \"User creates new section\", \"Document is in draft status with 2 sections\"),\n",
102
+ " (\"Document is in draft status with 2 sections\", \"User creates new section\", \"Document is in draft status with 3 sections\"),\n",
103
+ " (\"Document is in draft status with 3 sections\", \"User submits for review\", \"Document is pending review with 3 sections\"),\n",
104
+ " (\"Document is pending review with 3 sections\", \"Reviewer approves document\", \"Document is approved and published\"),\n",
105
+ " (\"Document is pending review with 3 sections\", \"Reviewer requests changes\", \"Document is in revision with 3 sections\"),\n",
106
+ " (\"Document is in revision with 3 sections\", \"User makes requested changes\", \"Document is pending review with 3 sections\"),\n",
107
+ " (\"Document is approved and published\", \"User archives document\", \"Document is archived\"),\n",
108
+ " (\"Document is in draft status with 1 section\", \"User deletes section\", \"Document is in draft status with 0 sections\"),\n",
109
+ " (\"Document is in draft status with 2 sections\", \"User deletes section\", \"Document is in draft status with 1 section\"),\n",
110
+ " ]\n",
111
+ " \n",
112
+ " # Project workflow transitions\n",
113
+ " project_transitions = [\n",
114
+ " (\"Project is in planning phase with 0 tasks\", \"Manager adds task\", \"Project is in planning phase with 1 task\"),\n",
115
+ " (\"Project is in planning phase with 1 task\", \"Manager adds task\", \"Project is in planning phase with 2 tasks\"),\n",
116
+ " (\"Project is in planning phase with 2 tasks\", \"Manager starts project\", \"Project is active with 2 tasks and 0 completed\"),\n",
117
+ " (\"Project is active with 2 tasks and 0 completed\", \"Team completes task\", \"Project is active with 2 tasks and 1 completed\"),\n",
118
+ " (\"Project is active with 2 tasks and 1 completed\", \"Team completes task\", \"Project is completed with all tasks done\"),\n",
119
+ " (\"Project is active with 2 tasks and 0 completed\", \"Manager pauses project\", \"Project is on hold with 2 tasks\"),\n",
120
+ " (\"Project is on hold with 2 tasks\", \"Manager resumes project\", \"Project is active with 2 tasks and 0 completed\"),\n",
121
+ " (\"Project is completed with all tasks done\", \"Manager closes project\", \"Project is archived\"),\n",
122
+ " ]\n",
123
+ " \n",
124
+ " # User account transitions\n",
125
+ " account_transitions = [\n",
126
+ " (\"User account is new with basic permissions\", \"Admin grants editor role\", \"User account is active with editor permissions\"),\n",
127
+ " (\"User account is active with editor permissions\", \"Admin grants admin role\", \"User account is active with admin permissions\"),\n",
128
+ " (\"User account is active with admin permissions\", \"User requests deactivation\", \"User account is pending deactivation\"),\n",
129
+ " (\"User account is pending deactivation\", \"Admin confirms deactivation\", \"User account is deactivated\"),\n",
130
+ " (\"User account is active with editor permissions\", \"Security flags suspicious activity\", \"User account is locked pending review\"),\n",
131
+ " (\"User account is locked pending review\", \"Security clears account\", \"User account is active with editor permissions\"),\n",
132
+ " ]\n",
133
+ " \n",
134
+ " # Inventory transitions\n",
135
+ " inventory_transitions = [\n",
136
+ " (\"Inventory has 100 units in stock\", \"Customer orders 10 units\", \"Inventory has 90 units in stock\"),\n",
137
+ " (\"Inventory has 90 units in stock\", \"Customer orders 20 units\", \"Inventory has 70 units in stock\"),\n",
138
+ " (\"Inventory has 70 units in stock\", \"Supplier delivers 50 units\", \"Inventory has 120 units in stock\"),\n",
139
+ " (\"Inventory has 20 units in stock\", \"System triggers low stock alert\", \"Inventory has 20 units with reorder pending\"),\n",
140
+ " (\"Inventory has 20 units with reorder pending\", \"Supplier delivers 100 units\", \"Inventory has 120 units in stock\"),\n",
141
+ " (\"Inventory has 0 units in stock\", \"Customer attempts order\", \"Inventory has 0 units with backorder created\"),\n",
142
+ " ]\n",
143
+ " \n",
144
+ " all_transitions = doc_transitions + project_transitions + account_transitions + inventory_transitions\n",
145
+ " \n",
146
+ " # Generate samples by randomly selecting transitions\n",
147
+ " for _ in range(num_samples):\n",
148
+ " state, action, next_state = random.choice(all_transitions)\n",
149
+ " samples.append({\n",
150
+ " 'state': state,\n",
151
+ " 'action': action,\n",
152
+ " 'next_state': next_state\n",
153
+ " })\n",
154
+ " \n",
155
+ " return samples\n",
156
+ " \n",
157
+ " def __len__(self):\n",
158
+ " return len(self.samples)\n",
159
+ " \n",
160
+ " def __getitem__(self, idx):\n",
161
+ " return self.samples[idx]\n",
162
+ "\n",
163
+ "\n",
164
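+ "# Create datasets (both sample from the same 30 fixed transitions, so validation measures fit rather than generalization)\n",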
165
+ "train_dataset = EnterpriseWorkflowDataset(num_samples=2000, seed=42)\n",
166
+ "val_dataset = EnterpriseWorkflowDataset(num_samples=200, seed=123)\n",
167
+ "\n",
168
+ "print(f\"Training samples: {len(train_dataset)}\")\n",
169
+ "print(f\"Validation samples: {len(val_dataset)}\")\n",
170
+ "print(\"\\nExample sample:\")\n",
171
+ "print(f\" State: {train_dataset[0]['state']}\")\n",
172
+ "print(f\" Action: {train_dataset[0]['action']}\")\n",
173
+ "print(f\" Next State: {train_dataset[0]['next_state']}\")"
174
+ ]
175
+ },
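Because the training and validation sets above are both drawn from the same fixed list of transitions, every validation triplet also appears in training. A minimal sketch of one alternative, assuming you want validation to contain transitions the model has never seen: split the unique triplets first, then sample from each side. The helper name `split_transitions` and the `val_fraction` parameter are hypothetical and not part of the notebook.

```python
import random


def split_transitions(transitions, val_fraction=0.2, seed=0):
    """Split unique (state, action, next_state) triplets into disjoint train/val lists."""
    unique = sorted(set(transitions))   # de-duplicate and fix a deterministic order
    rng = random.Random(seed)           # local RNG, so the global random state is untouched
    rng.shuffle(unique)
    n_val = max(1, int(len(unique) * val_fraction))
    return unique[n_val:], unique[:n_val]


# Three triplets copied from the notebook's transition tables, plus one duplicate.
transitions = [
    ("Document is approved and published", "User archives document", "Document is archived"),
    ("Project is completed with all tasks done", "Manager closes project", "Project is archived"),
    ("User account is pending deactivation", "Admin confirms deactivation", "User account is deactivated"),
    ("Document is approved and published", "User archives document", "Document is archived"),
]

train_triplets, val_triplets = split_transitions(transitions, val_fraction=0.34)
print(len(train_triplets), len(val_triplets))
```

With only 30 hand-written transitions a held-out split is very small, so this is mainly useful once the transition tables are extended or generated programmatically.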
176
+ {
177
+ "cell_type": "markdown",
178
+ "metadata": {},
179
+ "source": [
180
+ "---\n",
181
+ "\n",
182
+ "# 🔵 Option 1: Sentence Encoder Approach (Simplest)\n",
183
+ "\n",
184
+ "Uses a pre-trained sentence encoder (like `all-MiniLM-L6-v2`) for state embeddings.\n",
185
+ "\n",
186
+ "**Pros:**\n",
187
+ "- Fastest to train\n",
188
+ "- No need to train encoder\n",
189
+ "- Small memory footprint\n",
190
+ "\n",
191
+ "**Cons:**\n",
192
+ "- Limited by pre-trained encoder's representation\n",
193
+ "- May not capture domain-specific nuances"
194
+ ]
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "execution_count": null,
199
+ "metadata": {},
200
+ "outputs": [],
201
+ "source": [
202
+ "class Option1_SentenceEncoderJEPA(nn.Module):\n",
203
+ " \"\"\"\n",
204
+ " JEPA-style world model using pre-trained sentence encoder.\n",
205
+ " \n",
206
+ " Architecture:\n",
207
+ " - State Encoder: Pre-trained SentenceTransformer (frozen)\n",
208
+ " - Predictor: Small transformer that predicts next state embedding\n",
209
+ " \"\"\"\n",
210
+ " \n",
211
+ " def __init__(\n",
212
+ " self,\n",
213
+ " sentence_model_name='all-MiniLM-L6-v2',\n",
214
+ " hidden_dim=256,\n",
215
+ " num_layers=2,\n",
216
+ " num_heads=4,\n",
217
+ " dropout=0.1\n",
218
+ " ):\n",
219
+ " super().__init__()\n",
220
+ " \n",
221
+ " # Pre-trained sentence encoder (frozen)\n",
222
+ " self.sentence_encoder = SentenceTransformer(sentence_model_name)\n",
223
+ " self.sentence_encoder.requires_grad_(False) # Freeze\n",
224
+ " \n",
225
+ " # Get embedding dimension from sentence encoder\n",
226
+ " self.embed_dim = self.sentence_encoder.get_sentence_embedding_dimension()\n",
227
+ " \n",
228
+ " # Project state and action to hidden dim\n",
229
+ " self.state_proj = nn.Linear(self.embed_dim, hidden_dim)\n",
230
+ " self.action_proj = nn.Linear(self.embed_dim, hidden_dim)\n",
231
+ " \n",
232
+ " # Combine state + action\n",
233
+ " self.combine = nn.Sequential(\n",
234
+ " nn.Linear(hidden_dim * 2, hidden_dim),\n",
235
+ " nn.LayerNorm(hidden_dim),\n",
236
+ " nn.GELU()\n",
237
+ " )\n",
238
+ " \n",
239
+ " # Transformer predictor\n",
240
+ " encoder_layer = nn.TransformerEncoderLayer(\n",
241
+ " d_model=hidden_dim,\n",
242
+ " nhead=num_heads,\n",
243
+ " dim_feedforward=hidden_dim * 4,\n",
244
+ " dropout=dropout,\n",
245
+ " activation='gelu',\n",
246
+ " batch_first=True\n",
247
+ " )\n",
248
+ " self.predictor = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n",
249
+ " \n",
250
+ " # Output projection to state embedding space\n",
251
+ " self.output_proj = nn.Sequential(\n",
252
+ " nn.Linear(hidden_dim, hidden_dim),\n",
253
+ " nn.GELU(),\n",
254
+ " nn.Linear(hidden_dim, self.embed_dim)\n",
255
+ " )\n",
256
+ " \n",
257
+ " def encode_text(self, texts):\n",
258
+ " \"\"\"Encode text to embeddings using sentence encoder.\"\"\"\n",
259
+ " with torch.no_grad():\n",
260
+ " embeddings = self.sentence_encoder.encode(\n",
261
+ " texts, \n",
262
+ " convert_to_tensor=True,\n",
263
+ " show_progress_bar=False\n",
264
+ " )\n",
265
+ " return embeddings\n",
266
+ " \n",
267
+ " def forward(self, state_texts, action_texts):\n",
268
+ " \"\"\"\n",
269
+ " Predict next state embedding given current state and action.\n",
270
+ " \n",
271
+ " Args:\n",
272
+ " state_texts: List of state descriptions\n",
273
+ " action_texts: List of action descriptions\n",
274
+ " \n",
275
+ " Returns:\n",
276
+ " predicted_next_state: Predicted next state embeddings [B, embed_dim]\n",
277
+ " \"\"\"\n",
278
+ " # Encode state and action\n",
279
+ " state_emb = self.encode_text(state_texts) # [B, embed_dim]\n",
280
+ " action_emb = self.encode_text(action_texts) # [B, embed_dim]\n",
281
+ " \n",
282
+ " # Project to hidden dim\n",
283
+ " state_h = self.state_proj(state_emb) # [B, hidden_dim]\n",
284
+ " action_h = self.action_proj(action_emb) # [B, hidden_dim]\n",
285
+ " \n",
286
+ " # Combine\n",
287
+ " combined = self.combine(torch.cat([state_h, action_h], dim=-1)) # [B, hidden_dim]\n",
288
+ " \n",
289
+ " # Add sequence dimension for transformer (length 1, so self-attention only sees this one token)\n",
290
+ " combined = combined.unsqueeze(1) # [B, 1, hidden_dim]\n",
291
+ " \n",
292
+ " # Predict through transformer\n",
293
+ " predicted = self.predictor(combined) # [B, 1, hidden_dim]\n",
294
+ " \n",
295
+ " # Project to state embedding space\n",
296
+ " predicted_next_state = self.output_proj(predicted.squeeze(1)) # [B, embed_dim]\n",
297
+ " \n",
298
+ " return predicted_next_state\n",
299
+ " \n",
300
+ " def get_target_embedding(self, next_state_texts):\n",
301
+ " \"\"\"Get target embedding for loss computation.\"\"\"\n",
302
+ " return self.encode_text(next_state_texts)\n",
303
+ "\n",
304
+ "\n",
305
+ "# Create model\n",
306
+ "print(\"Creating Option 1 model...\")\n",
307
+ "model_opt1 = Option1_SentenceEncoderJEPA(\n",
308
+ " sentence_model_name='all-MiniLM-L6-v2',\n",
309
+ " hidden_dim=256,\n",
310
+ " num_layers=2,\n",
311
+ " num_heads=4\n",
312
+ ").to(device)\n",
313
+ "\n",
314
+ "# Count parameters\n",
315
+ "trainable_params = sum(p.numel() for p in model_opt1.parameters() if p.requires_grad)\n",
316
+ "total_params = sum(p.numel() for p in model_opt1.parameters())\n",
317
+ "print(f\"Trainable parameters: {trainable_params:,}\")\n",
318
+ "print(f\"Total parameters: {total_params:,}\")"
319
+ ]
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "execution_count": null,
324
+ "metadata": {},
325
+ "outputs": [],
326
+ "source": [
327
+ "def train_option1(model, train_dataset, val_dataset, epochs=10, batch_size=32, lr=1e-3):\n",
328
+ " \"\"\"\n",
329
+ " Training loop for Option 1 model.\n",
330
+ " \n",
331
+ " Loss: Cosine similarity between predicted and target state embeddings.\n",
332
+ " \"\"\"\n",
333
+ " model.train()\n",
334
+ " optimizer = torch.optim.AdamW(model.parameters(), lr=lr)\n",
335
+ " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)\n",
336
+ " \n",
337
+ " train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
338
+ " val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
339
+ " \n",
340
+ " history = {'train_loss': [], 'val_loss': [], 'val_similarity': []}\n",
341
+ " \n",
342
+ " for epoch in range(epochs):\n",
343
+ " # Training\n",
344
+ " model.train()\n",
345
+ " train_losses = []\n",
346
+ " \n",
347
+ " pbar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{epochs}\")\n",
348
+ " for batch in pbar:\n",
349
+ " states = batch['state']\n",
350
+ " actions = batch['action']\n",
351
+ " next_states = batch['next_state']\n",
352
+ " \n",
353
+ " # Forward pass\n",
354
+ " predicted = model(states, actions)\n",
355
+ " target = model.get_target_embedding(next_states)\n",
356
+ " \n",
357
+ " # Cosine similarity loss (1 - similarity to minimize)\n",
358
+ " similarity = F.cosine_similarity(predicted, target, dim=-1)\n",
359
+ " loss = (1 - similarity).mean()\n",
360
+ " \n",
361
+ " # Backward pass\n",
362
+ " optimizer.zero_grad()\n",
363
+ " loss.backward()\n",
364
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
365
+ " optimizer.step()\n",
366
+ " \n",
367
+ " train_losses.append(loss.item())\n",
368
+ " pbar.set_postfix({'loss': f'{loss.item():.4f}'})\n",
369
+ " \n",
370
+ " scheduler.step()\n",
371
+ " \n",
372
+ " # Validation\n",
373
+ " model.eval()\n",
374
+ " val_losses = []\n",
375
+ " val_similarities = []\n",
376
+ " \n",
377
+ " with torch.no_grad():\n",
378
+ " for batch in val_loader:\n",
379
+ " states = batch['state']\n",
380
+ " actions = batch['action']\n",
381
+ " next_states = batch['next_state']\n",
382
+ " \n",
383
+ " predicted = model(states, actions)\n",
384
+ " target = model.get_target_embedding(next_states)\n",
385
+ " \n",
386
+ " similarity = F.cosine_similarity(predicted, target, dim=-1)\n",
387
+ " loss = (1 - similarity).mean()\n",
388
+ " \n",
389
+ " val_losses.append(loss.item())\n",
390
+ " val_similarities.append(similarity.mean().item())\n",
391
+ " \n",
392
+ " # Record history\n",
393
+ " history['train_loss'].append(np.mean(train_losses))\n",
394
+ " history['val_loss'].append(np.mean(val_losses))\n",
395
+ " history['val_similarity'].append(np.mean(val_similarities))\n",
396
+ " \n",
397
+ " print(f\"Epoch {epoch+1}: Train Loss={np.mean(train_losses):.4f}, \"\n",
398
+ " f\"Val Loss={np.mean(val_losses):.4f}, Val Similarity={np.mean(val_similarities):.4f}\")\n",
399
+ " \n",
400
+ " return history\n",
401
+ "\n",
402
+ "\n",
403
+ "# Train the model\n",
404
+ "print(\"\\n\" + \"=\"*50)\n",
405
+ "print(\"Training Option 1: Sentence Encoder JEPA\")\n",
406
+ "print(\"=\"*50)\n",
407
+ "history_opt1 = train_option1(model_opt1, train_dataset, val_dataset, epochs=10, batch_size=32)"
408
+ ]
409
+ },
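The training loop above minimizes `1 - cosine_similarity(predicted, target)`. The sketch below works through that loss on three hand-picked 2-d vectors in plain Python (no PyTorch), to show its range and the fact that it ignores vector length. The helper names `cosine_similarity` and `cosine_loss` are illustrative only and are not the `torch.nn.functional` versions.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def cosine_loss(pred, target):
    """Loss used in the Option 1 loop: 0 when aligned, 2 when opposite."""
    return 1.0 - cosine_similarity(pred, target)


target = [1.0, 0.0]
aligned = cosine_loss([3.0, 0.0], target)      # same direction, three times longer
orthogonal = cosine_loss([0.0, 2.0], target)   # 90 degrees apart
opposite = cosine_loss([-1.0, 0.0], target)    # pointing the other way

print(aligned, orthogonal, opposite)
```

The first case has zero loss even though the prediction is three times longer than the target. Since the loss constrains direction only, adding an MSE term is one way to also constrain magnitude.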
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": null,
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": [
416
+ "def test_model(model, test_samples):\n",
417
+ " \"\"\"Test the model on specific examples.\"\"\"\n",
418
+ " model.eval()\n",
419
+ " \n",
420
+ " print(\"\\n\" + \"=\"*60)\n",
421
+ " print(\"Model Predictions\")\n",
422
+ " print(\"=\"*60)\n",
423
+ " \n",
424
+ " for sample in test_samples:\n",
425
+ " state = sample['state']\n",
426
+ " action = sample['action']\n",
427
+ " actual_next = sample['next_state']\n",
428
+ " \n",
429
+ " with torch.no_grad():\n",
430
+ " # Get prediction\n",
431
+ " predicted_emb = model([state], [action])\n",
432
+ " actual_emb = model.get_target_embedding([actual_next])\n",
433
+ " \n",
434
+ " # Compute similarity\n",
435
+ " similarity = F.cosine_similarity(predicted_emb, actual_emb, dim=-1).item()\n",
436
+ " \n",
437
+ " print(f\"\\nState: {state}\")\n",
438
+ " print(f\"Action: {action}\")\n",
439
+ " print(f\"Actual Next State: {actual_next}\")\n",
440
+ " print(f\"Prediction Similarity: {similarity:.4f} {'✓' if similarity > 0.8 else '✗'}\")\n",
441
+ "\n",
442
+ "\n",
443
+ "# Test on a few examples\n",
444
+ "test_samples = [\n",
445
+ " {'state': 'Document is in draft status with 2 sections', 'action': 'User creates new section', 'next_state': 'Document is in draft status with 3 sections'},\n",
446
+ " {'state': 'Project is active with 2 tasks and 0 completed', 'action': 'Team completes task', 'next_state': 'Project is active with 2 tasks and 1 completed'},\n",
447
+ " {'state': 'Inventory has 100 units in stock', 'action': 'Customer orders 10 units', 'next_state': 'Inventory has 90 units in stock'},\n",
448
+ "]\n",
449
+ "\n",
450
+ "test_model(model_opt1, test_samples)"
451
+ ]
452
+ },
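`test_model` marks a prediction as correct when its cosine similarity to the target exceeds 0.8. Sentences that differ in one number can sit very close together in embedding space, so a fixed threshold may pass even when a different state is the nearest match. The following is a minimal sketch of a retrieval-style check under that assumption. The 3-d vectors are made-up stand-ins for encoder outputs, and `retrieve_next_state` is a hypothetical helper, not part of the notebook.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def retrieve_next_state(predicted, candidates):
    """Return the label of the candidate embedding most similar to `predicted`."""
    return max(candidates, key=lambda label: cosine(predicted, candidates[label]))


# Toy embeddings (invented numbers); two near-duplicate states and one distant state.
candidates = {
    "draft with 2 sections": [0.9, 0.1, 0.0],
    "draft with 3 sections": [0.8, 0.3, 0.0],
    "pending review": [0.1, 0.2, 0.9],
}
predicted = [0.85, 0.12, 0.0]

best = retrieve_next_state(predicted, candidates)
target_sim = cosine(predicted, candidates["draft with 3 sections"])
print(best, round(target_sim, 3))
```

In this toy case the similarity to the intended target ("draft with 3 sections") clears 0.8, yet the nearest candidate is "draft with 2 sections". Reporting top-1 retrieval accuracy over all known states alongside the raw similarity gives a stricter picture.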
453
+ {
454
+ "cell_type": "markdown",
455
+ "metadata": {},
456
+ "source": [
457
+ "---\n",
458
+ "\n",
459
+ "# 🟢 Option 2: LLM Hidden States Approach (Medium)\n",
460
+ "\n",
461
+ "Uses a small LLM's hidden states for state representations.\n",
462
+ "\n",
463
+ "**Pros:**\n",
464
+ "- Better language understanding\n",
465
+ "- Can fine-tune encoder\n",
466
+ "- More expressive representations\n",
467
+ "\n",
468
+ "**Cons:**\n",
469
+ "- Slower than Option 1\n",
470
+ "- Requires more memory"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "metadata": {},
477
+ "outputs": [],
478
+ "source": [
479
+ "class Option2_LLMHiddenStateJEPA(nn.Module):\n",
480
+ " \"\"\"\n",
481
+ " JEPA-style world model using LLM hidden states.\n",
482
+ " \n",
483
+ " Architecture:\n",
484
+ " - State Encoder: Small LLM (GPT-2 or similar) + pooling\n",
485
+ " - Predictor: MLP that predicts next state embedding\n",
486
+ " \"\"\"\n",
487
+ " \n",
488
+ " def __init__(\n",
489
+ " self,\n",
490
+ " model_name='gpt2', # Small model for Colab\n",
491
+ " state_dim=512,\n",
492
+ " freeze_encoder=True\n",
493
+ " ):\n",
494
+ " super().__init__()\n",
495
+ " \n",
496
+ " # Load tokenizer and model\n",
497
+ " self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
498
+ " self.tokenizer.pad_token = self.tokenizer.eos_token\n",
499
+ " \n",
500
+ " self.encoder = AutoModel.from_pretrained(model_name)\n",
501
+ " self.hidden_size = self.encoder.config.hidden_size\n",
502
+ " \n",
503
+ " if freeze_encoder:\n",
504
+ " for param in self.encoder.parameters():\n",
505
+ " param.requires_grad = False\n",
506
+ " \n",
507
+ " self.state_dim = state_dim\n",
508
+ " \n",
509
+ " # State projection (from LLM hidden to state space)\n",
510
+ " self.state_proj = nn.Sequential(\n",
511
+ " nn.Linear(self.hidden_size, state_dim),\n",
512
+ " nn.LayerNorm(state_dim),\n",
513
+ " nn.GELU()\n",
514
+ " )\n",
515
+ " \n",
516
+ " # Action projection\n",
517
+ " self.action_proj = nn.Sequential(\n",
518
+ " nn.Linear(self.hidden_size, state_dim),\n",
519
+ " nn.LayerNorm(state_dim),\n",
520
+ " nn.GELU()\n",
521
+ " )\n",
522
+ " \n",
523
+ " # Predictor: takes state + action, outputs next state\n",
524
+ " self.predictor = nn.Sequential(\n",
525
+ " nn.Linear(state_dim * 2, state_dim * 2),\n",
526
+ " nn.LayerNorm(state_dim * 2),\n",
527
+ " nn.GELU(),\n",
528
+ " nn.Dropout(0.1),\n",
529
+ " nn.Linear(state_dim * 2, state_dim * 2),\n",
530
+ " nn.LayerNorm(state_dim * 2),\n",
531
+ " nn.GELU(),\n",
532
+ " nn.Dropout(0.1),\n",
533
+ " nn.Linear(state_dim * 2, state_dim)\n",
534
+ " )\n",
535
+ " \n",
536
+ " def encode_text(self, texts):\n",
537
+ " \"\"\"\n",
538
+ " Encode text to embeddings using LLM.\n",
539
+ " Uses mean pooling over hidden states.\n",
540
+ " \"\"\"\n",
541
+ " # Tokenize\n",
542
+ " inputs = self.tokenizer(\n",
543
+ " texts,\n",
544
+ " return_tensors='pt',\n",
545
+ " padding=True,\n",
546
+ " truncation=True,\n",
547
+ " max_length=128\n",
548
+ " ).to(self.encoder.device)\n",
549
+ " \n",
550
+ " # Get hidden states\n",
551
+ " with torch.no_grad() if not self.encoder.training else torch.enable_grad():\n",
552
+ " outputs = self.encoder(**inputs)\n",
553
+ " \n",
554
+ " # Mean pooling (exclude padding)\n",
555
+ " attention_mask = inputs['attention_mask'].unsqueeze(-1)\n",
556
+ " hidden_states = outputs.last_hidden_state\n",
557
+ " pooled = (hidden_states * attention_mask).sum(dim=1) / attention_mask.sum(dim=1)\n",
558
+ " \n",
559
+ " return pooled\n",
560
+ " \n",
561
+ " def forward(self, state_texts, action_texts):\n",
562
+ " \"\"\"\n",
563
+ " Predict next state embedding given current state and action.\n",
564
+ " \"\"\"\n",
565
+ " # Encode state and action\n",
566
+ " state_hidden = self.encode_text(state_texts) # [B, hidden_size]\n",
567
+ " action_hidden = self.encode_text(action_texts) # [B, hidden_size]\n",
568
+ " \n",
569
+ " # Project to state space\n",
570
+ " state_emb = self.state_proj(state_hidden) # [B, state_dim]\n",
571
+ " action_emb = self.action_proj(action_hidden) # [B, state_dim]\n",
572
+ " \n",
573
+ " # Combine and predict\n",
574
+ " combined = torch.cat([state_emb, action_emb], dim=-1) # [B, state_dim * 2]\n",
575
+ " predicted_next_state = self.predictor(combined) # [B, state_dim]\n",
576
+ " \n",
577
+ " return predicted_next_state\n",
578
+ " \n",
579
+ " def get_target_embedding(self, next_state_texts):\n",
580
+ " \"\"\"Get target embedding for loss computation.\"\"\"\n",
581
+ " hidden = self.encode_text(next_state_texts)\n",
582
+ " return self.state_proj(hidden)\n",
583
+ "\n",
584
+ "\n",
585
+ "# Create model\n",
586
+ "print(\"Creating Option 2 model...\")\n",
587
+ "model_opt2 = Option2_LLMHiddenStateJEPA(\n",
588
+ " model_name='gpt2',\n",
589
+ " state_dim=512,\n",
590
+ " freeze_encoder=True\n",
591
+ ").to(device)\n",
592
+ "\n",
593
+ "# Count parameters\n",
594
+ "trainable_params = sum(p.numel() for p in model_opt2.parameters() if p.requires_grad)\n",
595
+ "total_params = sum(p.numel() for p in model_opt2.parameters())\n",
596
+ "print(f\"Trainable parameters: {trainable_params:,}\")\n",
597
+ "print(f\"Total parameters: {total_params:,}\")"
598
+ ]
599
+ },
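`encode_text` in Option 2 averages token hidden states while excluding padding positions. The sketch below reproduces that masked mean on plain Python lists for a single sequence, to make the arithmetic explicit. `masked_mean_pool` is an illustrative helper and the numbers are invented.

```python
def masked_mean_pool(hidden_states, attention_mask):
    """Average the token vectors whose mask entry is 1.

    hidden_states: list of T vectors, each of length D
    attention_mask: list of T ints (1 = real token, 0 = padding)
    """
    dim = len(hidden_states[0])
    total = [0.0] * dim
    count = 0
    for vec, keep in zip(hidden_states, attention_mask):
        if keep:
            count += 1
            for i, x in enumerate(vec):
                total[i] += x
    return [t / count for t in total]


hidden = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]  # last row is a padding position
mask = [1, 1, 0]

pooled = masked_mean_pool(hidden, mask)                       # uses the first two rows only
naive = [sum(col) / len(hidden) for col in zip(*hidden)]      # unmasked mean, for contrast
print(pooled, naive)
```

The unmasked mean is dominated by the padding row, which is why the attention mask is applied before dividing.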
600
+ {
601
+ "cell_type": "code",
602
+ "execution_count": null,
603
+ "metadata": {},
604
+ "outputs": [],
605
+ "source": [
606
+ "def train_option2(model, train_dataset, val_dataset, epochs=10, batch_size=16, lr=1e-3):\n",
607
+ " \"\"\"\n",
608
+ " Training loop for Option 2 model.\n",
609
+ " Loss: MSE + (1 - cosine similarity) between predicted and target embeddings.\n",
610
+ " \"\"\"\n",
611
+ " model.train()\n",
612
+ " optimizer = torch.optim.AdamW(\n",
613
+ " filter(lambda p: p.requires_grad, model.parameters()),\n",
614
+ " lr=lr,\n",
615
+ " weight_decay=0.01\n",
616
+ " )\n",
617
+ " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)\n",
618
+ " \n",
619
+ " train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
620
+ " val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
621
+ " \n",
622
+ " history = {'train_loss': [], 'val_loss': [], 'val_similarity': []}\n",
623
+ " \n",
624
+ " for epoch in range(epochs):\n",
625
+ " # Training\n",
626
+ " model.train()\n",
627
+ " train_losses = []\n",
628
+ " \n",
629
+ " pbar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{epochs}\")\n",
630
+ " for batch in pbar:\n",
631
+ " states = batch['state']\n",
632
+ " actions = batch['action']\n",
633
+ " next_states = batch['next_state']\n",
634
+ " \n",
635
+ " # Forward pass\n",
636
+ " predicted = model(states, actions)\n",
637
+ " target = model.get_target_embedding(next_states)\n",
638
+ " \n",
639
+ " # Combined loss: MSE + (1 - cosine similarity)\n",
640
+ " mse_loss = F.mse_loss(predicted, target)\n",
641
+ " cos_loss = (1 - F.cosine_similarity(predicted, target, dim=-1)).mean()\n",
642
+ " loss = mse_loss + cos_loss\n",
643
+ " \n",
644
+ " # Backward pass\n",
645
+ " optimizer.zero_grad()\n",
646
+ " loss.backward()\n",
647
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
648
+ " optimizer.step()\n",
649
+ " \n",
650
+ " train_losses.append(loss.item())\n",
651
+ " pbar.set_postfix({'loss': f'{loss.item():.4f}'})\n",
652
+ " \n",
653
+ " scheduler.step()\n",
654
+ " \n",
655
+ " # Validation\n",
656
+ " model.eval()\n",
657
+ " val_losses = []\n",
658
+ " val_similarities = []\n",
659
+ " \n",
660
+ " with torch.no_grad():\n",
661
+ " for batch in val_loader:\n",
662
+ " states = batch['state']\n",
663
+ " actions = batch['action']\n",
664
+ " next_states = batch['next_state']\n",
665
+ " \n",
666
+ " predicted = model(states, actions)\n",
667
+ " target = model.get_target_embedding(next_states)\n",
668
+ " \n",
669
+ " mse_loss = F.mse_loss(predicted, target)\n",
670
+ " cos_loss = (1 - F.cosine_similarity(predicted, target, dim=-1)).mean()\n",
671
+ " loss = mse_loss + cos_loss\n",
672
+ " \n",
673
+ " val_losses.append(loss.item())\n",
674
+ " val_similarities.append(F.cosine_similarity(predicted, target, dim=-1).mean().item())\n",
675
+ " \n",
676
+ " # Record history\n",
677
+ " history['train_loss'].append(np.mean(train_losses))\n",
678
+ " history['val_loss'].append(np.mean(val_losses))\n",
679
+ " history['val_similarity'].append(np.mean(val_similarities))\n",
680
+ " \n",
681
+ " print(f\"Epoch {epoch+1}: Train Loss={np.mean(train_losses):.4f}, \"\n",
682
+ " f\"Val Loss={np.mean(val_losses):.4f}, Val Similarity={np.mean(val_similarities):.4f}\")\n",
683
+ " \n",
684
+ " return history\n",
685
+ "\n",
686
+ "\n",
687
+ "# Train the model\n",
688
+ "print(\"\\n\" + \"=\"*50)\n",
689
+ "print(\"Training Option 2: LLM Hidden State JEPA\")\n",
690
+ "print(\"=\"*50)\n",
691
+ "history_opt2 = train_option2(model_opt2, train_dataset, val_dataset, epochs=10, batch_size=16)"
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": null,
697
+ "metadata": {},
698
+ "outputs": [],
699
+ "source": [
700
+ "# Test Option 2\n",
701
+ "test_model(model_opt2, test_samples)"
702
+ ]
703
+ },
704
+ {
705
+ "cell_type": "markdown",
706
+ "metadata": {},
707
+ "source": [
708
+ "---\n",
709
+ "\n",
710
+ "# 🔴 Option 3: Full Autoencoder Approach (Most Powerful)\n",
711
+ "\n",
712
+ "Trains a full state autoencoder for domain-specific representations.\n",
713
+ "\n",
714
+ "**Pros:**\n",
715
+ "- Best domain adaptation\n",
716
+ "- Learnable encoder captures task-specific features\n",
717
+ "- Highest potential accuracy\n",
718
+ "\n",
719
+ "**Cons:**\n",
720
+ "- Requires more training data\n",
721
+ "- Longer training time\n",
722
+ "- More complex to tune"
723
+ ]
724
+ },
725
+ {
726
+ "cell_type": "code",
727
+ "execution_count": null,
728
+ "metadata": {},
729
+ "outputs": [],
730
+ "source": [
731
+ "class Option3_AutoencoderJEPA(nn.Module):\n",
732
+ " \"\"\"\n",
733
+ " JEPA-style world model with learned state autoencoder.\n",
734
+ " \n",
735
+ " Architecture:\n",
736
+ " - State Encoder: Trainable encoder that learns domain-specific embeddings\n",
737
+ " - State Decoder: Maps embeddings back to the LLM hidden space (used only for the reconstruction loss)\n",
738
+ " - Predictor: Transformer that predicts next state in latent space\n",
739
+ " \"\"\"\n",
740
+ " \n",
741
+ " def __init__(\n",
742
+ " self,\n",
743
+ " model_name='gpt2',\n",
744
+ " state_dim=256,\n",
745
+ " predictor_layers=3,\n",
746
+ " predictor_heads=4\n",
747
+ " ):\n",
748
+ " super().__init__()\n",
749
+ " \n",
750
+ " # Tokenizer\n",
751
+ " self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
752
+ " self.tokenizer.pad_token = self.tokenizer.eos_token\n",
753
+ " \n",
754
+ " # Base LLM for encoding (will be fine-tuned)\n",
755
+ " self.base_llm = AutoModel.from_pretrained(model_name)\n",
756
+ " self.hidden_size = self.base_llm.config.hidden_size\n",
757
+ " self.vocab_size = self.base_llm.config.vocab_size\n",
758
+ " \n",
759
+ " self.state_dim = state_dim\n",
760
+ " \n",
761
+ " # State Encoder: LLM hidden → compressed state\n",
762
+ " self.state_encoder = nn.Sequential(\n",
763
+ " nn.Linear(self.hidden_size, self.hidden_size // 2),\n",
764
+ " nn.LayerNorm(self.hidden_size // 2),\n",
765
+ " nn.GELU(),\n",
766
+ " nn.Linear(self.hidden_size // 2, state_dim),\n",
767
+ " nn.LayerNorm(state_dim)\n",
768
+ " )\n",
769
+ " \n",
770
+ " # State Decoder: compressed state → reconstruction\n",
771
+ " self.state_decoder = nn.Sequential(\n",
772
+ " nn.Linear(state_dim, self.hidden_size // 2),\n",
773
+ " nn.LayerNorm(self.hidden_size // 2),\n",
774
+ " nn.GELU(),\n",
775
+ " nn.Linear(self.hidden_size // 2, self.hidden_size),\n",
776
+ " nn.LayerNorm(self.hidden_size)\n",
777
+ " )\n",
778
+ " \n",
779
+ " # Action Encoder\n",
780
+ " self.action_encoder = nn.Sequential(\n",
781
+ " nn.Linear(self.hidden_size, state_dim),\n",
782
+ " nn.LayerNorm(state_dim),\n",
783
+ " nn.GELU()\n",
784
+ " )\n",
785
+ " \n",
786
+ " # Transformer Predictor: (state, action) → next_state\n",
787
+ " self.input_proj = nn.Linear(state_dim * 2, state_dim)\n",
788
+ " \n",
789
+ " encoder_layer = nn.TransformerEncoderLayer(\n",
790
+ " d_model=state_dim,\n",
791
+ " nhead=predictor_heads,\n",
792
+ " dim_feedforward=state_dim * 4,\n",
793
+ " dropout=0.1,\n",
794
+ " activation='gelu',\n",
795
+ " batch_first=True\n",
796
+ " )\n",
797
+ " self.predictor = nn.TransformerEncoder(encoder_layer, num_layers=predictor_layers)\n",
798
+ " \n",
799
+ " # Output projection\n",
800
+ " self.output_proj = nn.Linear(state_dim, state_dim)\n",
801
+ " \n",
802
+ " def get_llm_hidden(self, texts):\n",
803
+ " \"\"\"Get LLM hidden states for texts.\"\"\"\n",
804
+ " inputs = self.tokenizer(\n",
805
+ " texts,\n",
806
+ " return_tensors='pt',\n",
807
+ " padding=True,\n",
808
+ " truncation=True,\n",
809
+ " max_length=128\n",
810
+ " ).to(self.base_llm.device)\n",
811
+ " \n",
812
+ " outputs = self.base_llm(**inputs)\n",
813
+ " \n",
814
+ " # Mean pooling\n",
815
+ " attention_mask = inputs['attention_mask'].unsqueeze(-1)\n",
816
+ " hidden_states = outputs.last_hidden_state\n",
817
+ " pooled = (hidden_states * attention_mask).sum(dim=1) / attention_mask.sum(dim=1).clamp(min=1)\n",
818
+ " \n",
819
+ " return pooled\n",
820
+ " \n",
821
+ " def encode_state(self, texts):\n",
822
+ " \"\"\"Encode text to state embedding.\"\"\"\n",
823
+ " hidden = self.get_llm_hidden(texts)\n",
824
+ " return self.state_encoder(hidden)\n",
825
+ " \n",
826
+ " def encode_action(self, texts):\n",
827
+ " \"\"\"Encode action text to action embedding.\"\"\"\n",
828
+ " hidden = self.get_llm_hidden(texts)\n",
829
+ " return self.action_encoder(hidden)\n",
830
+ " \n",
831
+ " def decode_state(self, state_emb):\n",
832
+ " \"\"\"Decode state embedding back to hidden space (for reconstruction loss).\"\"\"\n",
833
+ " return self.state_decoder(state_emb)\n",
834
+ " \n",
835
+ " def forward(self, state_texts, action_texts):\n",
836
+ " \"\"\"\n",
837
+ " Predict next state embedding given current state and action.\n",
838
+ " \"\"\"\n",
839
+ " # Encode\n",
840
+ " state_emb = self.encode_state(state_texts) # [B, state_dim]\n",
841
+ " action_emb = self.encode_action(action_texts) # [B, state_dim]\n",
842
+ " \n",
843
+ " # Combine\n",
844
+ " combined = torch.cat([state_emb, action_emb], dim=-1) # [B, state_dim * 2]\n",
845
+ " combined = self.input_proj(combined) # [B, state_dim]\n",
846
+ " \n",
847
+ " # Add sequence dimension\n",
848
+ " combined = combined.unsqueeze(1) # [B, 1, state_dim]\n",
849
+ " \n",
850
+ " # Predict through transformer\n",
851
+ " predicted = self.predictor(combined) # [B, 1, state_dim]\n",
852
+ " predicted = self.output_proj(predicted.squeeze(1)) # [B, state_dim]\n",
853
+ " \n",
854
+ " return predicted\n",
855
+ " \n",
856
+ " def forward_with_reconstruction(self, state_texts, action_texts, next_state_texts):\n",
857
+ " \"\"\"\n",
858
+ " Forward pass with reconstruction for training.\n",
859
+ " Returns: predicted_next_state, target_next_state, reconstruction of current state\n",
860
+ " \"\"\"\n",
861
+ " # Get all embeddings (the state's pooled LLM hidden vector is computed once and reused as the reconstruction target)\n",
862
+ " state_hidden_original = self.get_llm_hidden(state_texts)\n",
863
+ " state_emb = self.state_encoder(state_hidden_original)\n",
864
+ " action_emb = self.encode_action(action_texts)\n",
865
+ " target_emb = self.encode_state(next_state_texts)\n",
866
+ " \n",
867
+ " # Predict next state\n",
868
+ " combined = torch.cat([state_emb, action_emb], dim=-1)\n",
869
+ " combined = self.input_proj(combined)\n",
870
+ " combined = combined.unsqueeze(1)\n",
871
+ " predicted = self.predictor(combined)\n",
872
+ " predicted_next = self.output_proj(predicted.squeeze(1))\n",
873
+ " \n",
874
+ " # Reconstruction of current state (for autoencoder regularization)\n",
875
+ " state_reconstructed = self.decode_state(state_emb)\n",
876
+ " \n",
877
+ " return predicted_next, target_emb, state_reconstructed, state_hidden_original\n",
878
+ "\n",
879
+ "\n",
880
+ "# Create model\n",
881
+ "print(\"Creating Option 3 model...\")\n",
882
+ "model_opt3 = Option3_AutoencoderJEPA(\n",
883
+ " model_name='gpt2',\n",
884
+ " state_dim=256,\n",
885
+ " predictor_layers=3,\n",
886
+ " predictor_heads=4\n",
887
+ ").to(device)\n",
888
+ "\n",
889
+ "# Count parameters\n",
890
+ "trainable_params = sum(p.numel() for p in model_opt3.parameters() if p.requires_grad)\n",
891
+ "total_params = sum(p.numel() for p in model_opt3.parameters())\n",
892
+ "print(f\"Trainable parameters: {trainable_params:,}\")\n",
893
+ "print(f\"Total parameters: {total_params:,}\")"
894
+ ]
895
+ },
896
+ {
897
+ "cell_type": "code",
898
+ "execution_count": null,
899
+ "metadata": {},
900
+ "outputs": [],
901
+ "source": [
902
+ "def train_option3(model, train_dataset, val_dataset, epochs=15, batch_size=16, lr=5e-4):\n",
903
+ " \"\"\"\n",
904
+ " Training loop for Option 3 model.\n",
905
+ " Loss = prediction MSE + (1 - cosine similarity) + 0.1 * reconstruction MSE.\n",
906
+ " \"\"\"\n",
907
+ " model.train()\n",
908
+ " optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)\n",
909
+ " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)\n",
910
+ " \n",
911
+ " train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
912
+ " val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)\n",
913
+ " \n",
914
+ " history = {'train_loss': [], 'val_loss': [], 'val_similarity': []}\n",
915
+ " \n",
916
+ " for epoch in range(epochs):\n",
917
+ " # Training\n",
918
+ " model.train()\n",
919
+ " train_losses = []\n",
920
+ " \n",
921
+ " pbar = tqdm(train_loader, desc=f\"Epoch {epoch+1}/{epochs}\")\n",
922
+ " for batch in pbar:\n",
923
+ " states = batch['state']\n",
924
+ " actions = batch['action']\n",
925
+ " next_states = batch['next_state']\n",
926
+ " \n",
927
+ " # Forward pass with reconstruction\n",
928
+ " predicted_next, target_next, state_recon, state_orig = model.forward_with_reconstruction(\n",
929
+ " states, actions, next_states\n",
930
+ " )\n",
931
+ " \n",
932
+ " # Prediction loss (main objective)\n",
933
+ " pred_mse = F.mse_loss(predicted_next, target_next)\n",
934
+ " pred_cos = (1 - F.cosine_similarity(predicted_next, target_next, dim=-1)).mean()\n",
935
+ " \n",
936
+ " # Reconstruction loss (regularization)\n",
937
+ " recon_loss = F.mse_loss(state_recon, state_orig.detach())\n",
938
+ " \n",
939
+ " # Combined loss\n",
940
+ " loss = pred_mse + pred_cos + 0.1 * recon_loss\n",
941
+ " \n",
942
+ " # Backward pass\n",
943
+ " optimizer.zero_grad()\n",
944
+ " loss.backward()\n",
945
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
946
+ " optimizer.step()\n",
947
+ " \n",
948
+ " train_losses.append(loss.item())\n",
949
+ " pbar.set_postfix({'loss': f'{loss.item():.4f}'})\n",
950
+ " \n",
951
+ " scheduler.step()\n",
952
+ " \n",
953
+ " # Validation\n",
954
+ " model.eval()\n",
955
+ " val_losses = []\n",
956
+ " val_similarities = []\n",
957
+ " \n",
958
+ " with torch.no_grad():\n",
959
+ " for batch in val_loader:\n",
960
+ " states = batch['state']\n",
961
+ " actions = batch['action']\n",
962
+ " next_states = batch['next_state']\n",
963
+ " \n",
964
+ " predicted_next, target_next, state_recon, state_orig = model.forward_with_reconstruction(\n",
965
+ " states, actions, next_states\n",
966
+ " )\n",
967
+ " \n",
968
+ " pred_mse = F.mse_loss(predicted_next, target_next)\n",
969
+ " pred_cos = (1 - F.cosine_similarity(predicted_next, target_next, dim=-1)).mean()\n",
970
+ " recon_loss = F.mse_loss(state_recon, state_orig)\n",
971
+ " loss = pred_mse + pred_cos + 0.1 * recon_loss\n",
972
+ " \n",
973
+ " val_losses.append(loss.item())\n",
974
+ " val_similarities.append(F.cosine_similarity(predicted_next, target_next, dim=-1).mean().item())\n",
975
+ " \n",
976
+ " # Record history\n",
977
+ " history['train_loss'].append(np.mean(train_losses))\n",
978
+ " history['val_loss'].append(np.mean(val_losses))\n",
979
+ " history['val_similarity'].append(np.mean(val_similarities))\n",
980
+ " \n",
981
+ " print(f\"Epoch {epoch+1}: Train Loss={np.mean(train_losses):.4f}, \"\n",
982
+ " f\"Val Loss={np.mean(val_losses):.4f}, Val Similarity={np.mean(val_similarities):.4f}\")\n",
983
+ " \n",
984
+ " return history\n",
985
+ "\n",
986
+ "\n",
987
+ "# Train the model\n",
988
+ "print(\"\\n\" + \"=\"*50)\n",
989
+ "print(\"Training Option 3: Autoencoder JEPA\")\n",
990
+ "print(\"=\"*50)\n",
991
+ "history_opt3 = train_option3(model_opt3, train_dataset, val_dataset, epochs=15, batch_size=16)"
992
+ ]
993
+ },
994
+ {
995
+ "cell_type": "code",
996
+ "execution_count": null,
997
+ "metadata": {},
998
+ "outputs": [],
999
+ "source": [
1000
+ "def test_model_opt3(model, test_samples):\n",
1001
+ " \"\"\"Test Option 3 model.\"\"\"\n",
1002
+ " model.eval()\n",
1003
+ " \n",
1004
+ " print(\"\\n\" + \"=\"*60)\n",
1005
+ " print(\"Option 3 Model Predictions\")\n",
1006
+ " print(\"=\"*60)\n",
1007
+ " \n",
1008
+ " for sample in test_samples:\n",
1009
+ " state = sample['state']\n",
1010
+ " action = sample['action']\n",
1011
+ " actual_next = sample['next_state']\n",
1012
+ " \n",
1013
+ " with torch.no_grad():\n",
1014
+ " # Get prediction\n",
1015
+ " predicted_emb = model([state], [action])\n",
1016
+ " actual_emb = model.encode_state([actual_next])\n",
1017
+ " \n",
1018
+ " # Compute similarity\n",
1019
+ " similarity = F.cosine_similarity(predicted_emb, actual_emb, dim=-1).item()\n",
1020
+ " \n",
1021
+ " print(f\"\\nState: {state}\")\n",
1022
+ " print(f\"Action: {action}\")\n",
1023
+ " print(f\"Actual Next State: {actual_next}\")\n",
1024
+ " print(f\"Prediction Similarity: {similarity:.4f} {'✓' if similarity > 0.8 else '✗'}\")\n",
1025
+ "\n",
1026
+ "\n",
1027
+ "# Test Option 3\n",
1028
+ "test_model_opt3(model_opt3, test_samples)"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "cell_type": "markdown",
1033
+ "metadata": {},
1034
+ "source": [
1035
+ "---\n",
1036
+ "\n",
1037
+ "# 📊 Compare All Three Approaches"
1038
+ ]
1039
+ },
1040
+ {
1041
+ "cell_type": "code",
1042
+ "execution_count": null,
1043
+ "metadata": {},
1044
+ "outputs": [],
1045
+ "source": [
1046
+ "# Plot comparison\n",
1047
+ "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
1048
+ "\n",
1049
+ "# Training Loss\n",
1050
+ "axes[0].plot(history_opt1['train_loss'], label='Option 1: Sentence Encoder', marker='o')\n",
1051
+ "axes[0].plot(history_opt2['train_loss'], label='Option 2: LLM Hidden States', marker='s')\n",
1052
+ "axes[0].plot(history_opt3['train_loss'], label='Option 3: Autoencoder', marker='^')\n",
1053
+ "axes[0].set_xlabel('Epoch')\n",
1054
+ "axes[0].set_ylabel('Training Loss')\n",
1055
+ "axes[0].set_title('Training Loss Comparison')\n",
1056
+ "axes[0].legend()\n",
1057
+ "axes[0].grid(True, alpha=0.3)\n",
1058
+ "\n",
1059
+ "# Validation Loss\n",
1060
+ "axes[1].plot(history_opt1['val_loss'], label='Option 1', marker='o')\n",
1061
+ "axes[1].plot(history_opt2['val_loss'], label='Option 2', marker='s')\n",
1062
+ "axes[1].plot(history_opt3['val_loss'], label='Option 3', marker='^')\n",
1063
+ "axes[1].set_xlabel('Epoch')\n",
1064
+ "axes[1].set_ylabel('Validation Loss')\n",
1065
+ "axes[1].set_title('Validation Loss Comparison')\n",
1066
+ "axes[1].legend()\n",
1067
+ "axes[1].grid(True, alpha=0.3)\n",
1068
+ "\n",
1069
+ "# Validation Similarity\n",
1070
+ "axes[2].plot(history_opt1['val_similarity'], label='Option 1', marker='o')\n",
1071
+ "axes[2].plot(history_opt2['val_similarity'], label='Option 2', marker='s')\n",
1072
+ "axes[2].plot(history_opt3['val_similarity'], label='Option 3', marker='^')\n",
1073
+ "axes[2].set_xlabel('Epoch')\n",
1074
+ "axes[2].set_ylabel('Cosine Similarity')\n",
1075
+ "axes[2].set_title('Validation Similarity (Higher = Better)')\n",
1076
+ "axes[2].legend()\n",
1077
+ "axes[2].grid(True, alpha=0.3)\n",
1078
+ "axes[2].set_ylim([0, 1])\n",
1079
+ "\n",
1080
+ "plt.tight_layout()\n",
1081
+ "plt.savefig('jepa_comparison.png', dpi=150, bbox_inches='tight')\n",
1082
+ "plt.show()\n",
1083
+ "\n",
1084
+ "# Print final metrics\n",
1085
+ "print(\"\\n\" + \"=\"*60)\n",
1086
+ "print(\"Final Metrics Comparison\")\n",
1087
+ "print(\"=\"*60)\n",
1088
+ "print(f\"{'Model':<30} {'Val Loss':<15} {'Val Similarity':<15}\")\n",
1089
+ "print(\"-\"*60)\n",
1090
+ "print(f\"{'Option 1: Sentence Encoder':<30} {history_opt1['val_loss'][-1]:<15.4f} {history_opt1['val_similarity'][-1]:<15.4f}\")\n",
1091
+ "print(f\"{'Option 2: LLM Hidden States':<30} {history_opt2['val_loss'][-1]:<15.4f} {history_opt2['val_similarity'][-1]:<15.4f}\")\n",
1092
+ "print(f\"{'Option 3: Autoencoder':<30} {history_opt3['val_loss'][-1]:<15.4f} {history_opt3['val_similarity'][-1]:<15.4f}\")"
1093
+ ]
1094
+ },
1095
+ {
1096
+ "cell_type": "markdown",
1097
+ "metadata": {},
1098
+ "source": [
1099
+ "---\n",
1100
+ "\n",
1101
+ "# 🚀 Interactive Demo: Try Your Own State Transitions"
1102
+ ]
1103
+ },
1104
+ {
1105
+ "cell_type": "code",
1106
+ "execution_count": null,
1107
+ "metadata": {},
1108
+ "outputs": [],
1109
+ "source": [
1110
+ "def interactive_demo(model, model_name=\"Model\"):\n",
1111
+ " \"\"\"\n",
1112
+ " Interactive demo to test state transitions.\n",
1113
+ " \"\"\"\n",
1114
+ " print(f\"\\n{'='*60}\")\n",
1115
+ " print(f\"Interactive Demo: {model_name}\")\n",
1116
+ " print(\"=\"*60)\n",
1117
+ " print(\"\\nEnter a state and action to predict the next state.\")\n",
1118
+ " print(\"Type 'quit' to exit.\\n\")\n",
1119
+ " \n",
1120
+ " # Pre-compute embeddings for all known states for nearest neighbor search\n",
1121
+ " known_states = list(set(\n",
1122
+ " [s['state'] for s in train_dataset.samples] + \n",
1123
+ " [s['next_state'] for s in train_dataset.samples]\n",
1124
+ " ))\n",
1125
+ " \n",
1126
+ " model.eval()\n",
1127
+ " with torch.no_grad():\n",
1128
+ " if hasattr(model, 'encode_state'):\n",
1129
+ " known_embeddings = model.encode_state(known_states)\n",
1130
+ " else:\n",
1131
+ " known_embeddings = model.get_target_embedding(known_states)\n",
1132
+ " \n",
1133
+ " while True:\n",
1134
+ " state = input(\"\\nState: \").strip()\n",
1135
+ " if state.lower() == 'quit':\n",
1136
+ " break\n",
1137
+ " \n",
1138
+ " action = input(\"Action: \").strip()\n",
1139
+ " if action.lower() == 'quit':\n",
1140
+ " break\n",
1141
+ " \n",
1142
+ " with torch.no_grad():\n",
1143
+ " # Predict next state embedding\n",
1144
+ " predicted_emb = model([state], [action])\n",
1145
+ " \n",
1146
+ " # Find nearest known state\n",
1147
+ " similarities = F.cosine_similarity(\n",
1148
+ " predicted_emb.unsqueeze(1),\n",
1149
+ " known_embeddings.unsqueeze(0),\n",
1150
+ " dim=-1\n",
1151
+ " )\n",
1152
+ " \n",
1153
+ " top_k = 3\n",
1154
+ " top_indices = similarities[0].topk(top_k).indices\n",
1155
+ " top_sims = similarities[0].topk(top_k).values\n",
1156
+ " \n",
1157
+ " print(\"\\nPredicted Next States (by similarity):\")\n",
1158
+ " for i, (idx, sim) in enumerate(zip(top_indices, top_sims)):\n",
1159
+ " print(f\" {i+1}. [{sim:.4f}] {known_states[idx]}\")\n",
1160
+ "\n",
1161
+ "\n",
1162
+ "# Run demo with Option 1 (pass a different model to compare)\n",
1163
+ "print(\"\\nRunning demo with Option 1 (Sentence Encoder)...\")\n",
1164
+ "print(\"\\nExample inputs to try:\")\n",
1165
+ "print(\" State: 'Document is in draft status with 1 section'\")\n",
1166
+ "print(\" Action: 'User creates new section'\")\n",
1167
+ "\n",
1168
+ "# Uncomment to run interactive demo:\n",
1169
+ "# interactive_demo(model_opt1, \"Option 1: Sentence Encoder\")"
1170
+ ]
1171
+ },
1172
+ {
1173
+ "cell_type": "markdown",
1174
+ "metadata": {},
1175
+ "source": [
1176
+ "---\n",
1177
+ "\n",
1178
+ "# 💾 Save Models"
1179
+ ]
1180
+ },
1181
+ {
1182
+ "cell_type": "code",
1183
+ "execution_count": null,
1184
+ "metadata": {},
1185
+ "outputs": [],
1186
+ "source": [
1187
+ "# Save all three models\n",
1188
+ "torch.save({\n",
1189
+ " 'model_state_dict': model_opt1.state_dict(),\n",
1190
+ " 'history': history_opt1,\n",
1191
+ "}, 'jepa_option1_sentence_encoder.pt')\n",
1192
+ "\n",
1193
+ "torch.save({\n",
1194
+ " 'model_state_dict': model_opt2.state_dict(),\n",
1195
+ " 'history': history_opt2,\n",
1196
+ "}, 'jepa_option2_llm_hidden.pt')\n",
1197
+ "\n",
1198
+ "torch.save({\n",
1199
+ " 'model_state_dict': model_opt3.state_dict(),\n",
1200
+ " 'history': history_opt3,\n",
1201
+ "}, 'jepa_option3_autoencoder.pt')\n",
1202
+ "\n",
1203
+ "print(\"Models saved!\")\n",
1204
+ "print(\" - jepa_option1_sentence_encoder.pt\")\n",
1205
+ "print(\" - jepa_option2_llm_hidden.pt\")\n",
1206
+ "print(\" - jepa_option3_autoencoder.pt\")"
1207
+ ]
1208
+ },
1209
+ {
1210
+ "cell_type": "markdown",
1211
+ "metadata": {},
1212
+ "source": [
1213
+ "---\n",
1214
+ "\n",
1215
+ "# 📝 Summary & Next Steps\n",
1216
+ "\n",
1217
+ "## What We Built\n",
1218
+ "\n",
1219
+ "Three JEPA-style world models that predict the next-state embedding resulting from an action:\n",
1220
+ "\n",
1221
+ "| Option | Encoder | Complexity | Best For |\n",
1222
+ "|--------|---------|------------|----------|\n",
1223
+ "| 1 | Pre-trained SentenceTransformer | Simplest | Quick prototyping |\n",
1224
+ "| 2 | Frozen LLM + trainable head | Medium | General domains |\n",
1225
+ "| 3 | Trainable autoencoder | Complex | Domain-specific |\n",
1226
+ "\n",
1227
+ "## Key Differences from Normal LLMs\n",
1228
+ "\n",
1229
+ "- **Input:** State + Action embeddings (not tokens)\n",
1230
+ "- **Output:** State embeddings (not vocabulary logits)\n",
1231
+ "- **Loss:** MSE + cosine distance, i.e. 1 - cosine similarity (not cross-entropy)\n",
1232
+ "- **Generation:** Single-shot prediction (not autoregressive)\n",
1233
+ "\n",
1234
+ "## Next Steps\n",
1235
+ "\n",
1236
+ "1. **Scale up:** Use larger base models (Llama, Mistral)\n",
1237
+ "2. **Real data:** Replace synthetic data with actual enterprise logs\n",
1238
+ "3. **Multi-step:** Chain predictions for trajectory forecasting\n",
1239
+ "4. **Planning:** Use predicted states for action selection\n",
1240
+ "5. **Continual learning:** Add test-time training (TTT)"
1241
+ ]
1242
+ }
1243
+ ],
1244
+ "metadata": {
1245
+ "kernelspec": {
1246
+ "display_name": "Python 3",
1247
+ "language": "python",
1248
+ "name": "python3"
1249
+ },
1250
+ "language_info": {
1251
+ "name": "python",
1252
+ "version": "3.10.0"
1253
+ },
1254
+ "accelerator": "GPU"
1255
+ },
1256
+ "nbformat": 4,
1257
+ "nbformat_minor": 4
1258
+ }
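The embedding-space loss summarized in the notebook above (MSE plus cosine distance, with Option 3 adding a reconstruction term weighted by 0.1) can be illustrated with a small dependency-free sketch. This is not the notebook's code: `mse`, `cosine_similarity`, and `jepa_loss` are hypothetical helper names, the vectors are plain Python lists standing in for a single embedding, and the epsilon handling is a simplification that may differ slightly from `torch.nn.functional.cosine_similarity`.

```python
import math


def mse(pred, target):
    """Mean squared error over the elements of two equal-length vectors."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity with a small epsilon guard against zero-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def jepa_loss(pred, target, recon=None, recon_target=None, recon_weight=0.1):
    """Prediction MSE + cosine distance, plus an optional weighted reconstruction MSE.

    Hypothetical helper: mirrors `pred_mse + pred_cos + 0.1 * recon_loss`
    from the Option 3 training loop, for one example instead of a batch.
    """
    loss = mse(pred, target) + (1.0 - cosine_similarity(pred, target))
    if recon is not None and recon_target is not None:
        loss += recon_weight * mse(recon, recon_target)
    return loss


if __name__ == "__main__":
    # A perfect prediction gives zero loss; an orthogonal one is penalized by both terms.
    print(jepa_loss([1.0, 0.0], [1.0, 0.0]))
    print(jepa_loss([1.0, 0.0], [0.0, 1.0]))
```

The two terms are complementary: MSE penalizes differences in magnitude, while the cosine term penalizes differences in direction only, so a prediction that points the right way but has the wrong norm is still pushed toward the target by the MSE part.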
jepa_option1_sentence_encoder.ipynb ADDED
@@ -0,0 +1,690 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🧠 JEPA-Style LLM - Option 1: Sentence Encoder Approach\n",
8
+ "\n",
9
+ "**The Simplest Path: Use pre-trained sentence embeddings as your state space**\n",
10
+ "\n",
11
+ "This notebook demonstrates how to make a decoder-only transformer act like a JEPA world model:\n",
12
+ "- Input: State embedding + Action embedding\n",
13
+ "- Output: Predicted next state embedding (NOT tokens)\n",
14
+ "- Loss: MSE in embedding space\n",
15
+ "\n",
16
+ "**Key Insight:** We're predicting *consequences of actions* in a continuous space, not generating text."
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "# Install dependencies\n",
26
+ "!pip install -q transformers accelerate sentence-transformers torch datasets wandb"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "import torch\n",
36
+ "import torch.nn as nn\n",
37
+ "import torch.nn.functional as F\n",
38
+ "from torch.utils.data import Dataset, DataLoader\n",
39
+ "from sentence_transformers import SentenceTransformer\n",
40
+ "from transformers import AutoModel, AutoTokenizer, AutoConfig\n",
41
+ "import numpy as np\n",
42
+ "from tqdm.auto import tqdm\n",
43
+ "import matplotlib.pyplot as plt\n",
44
+ "\n",
45
+ "# Check GPU\n",
46
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
47
+ "print(f\"Using device: {device}\")"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "markdown",
52
+ "metadata": {},
53
+ "source": [
54
+ "## 1. Create Synthetic Training Data\n",
55
+ "\n",
56
+ "We'll simulate an enterprise workflow where:\n",
57
+ "- **State** = description of current document/workflow status\n",
58
+ "- **Action** = what the user does\n",
59
+ "- **Next State** = resulting status after action\n",
60
+ "\n",
61
+ "In production, you'd collect this from real user interactions."
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "# Synthetic enterprise workflow data\n",
71
+ "# Format: (current_state, action, next_state)\n",
72
+ "\n",
73
+ "WORKFLOW_DATA = [\n",
74
+ " # Document editing workflows\n",
75
+ " (\"Document is empty with no content\", \"User creates new section titled Introduction\", \"Document has one section: Introduction with no content\"),\n",
76
+ " (\"Document has one section: Introduction with no content\", \"User writes 500 words in Introduction\", \"Document has Introduction section with 500 words of content\"),\n",
77
+ " (\"Document has Introduction section with 500 words of content\", \"User adds new section titled Methods\", \"Document has two sections: Introduction (500 words) and Methods (empty)\"),\n",
78
+ " (\"Document has two sections: Introduction (500 words) and Methods (empty)\", \"User writes 300 words in Methods\", \"Document has Introduction (500 words) and Methods (300 words)\"),\n",
79
+ " (\"Document has Introduction (500 words) and Methods (300 words)\", \"User submits document for review\", \"Document is pending review with total 800 words\"),\n",
80
+ " (\"Document is pending review with total 800 words\", \"Reviewer approves document\", \"Document is approved and ready for publication\"),\n",
81
+ " (\"Document is pending review with total 800 words\", \"Reviewer requests changes\", \"Document returned to author with revision requests\"),\n",
82
+ " (\"Document returned to author with revision requests\", \"User makes requested edits\", \"Document revised and ready for re-review\"),\n",
83
+ " \n",
84
+ " # Project management workflows\n",
85
+ " (\"Project has no tasks assigned\", \"Manager creates 5 new tasks\", \"Project has 5 tasks all in pending status\"),\n",
86
+ " (\"Project has 5 tasks all in pending status\", \"Developer starts working on task 1\", \"Project has 1 in-progress task and 4 pending tasks\"),\n",
87
+ " (\"Project has 1 in-progress task and 4 pending tasks\", \"Developer completes task 1\", \"Project has 1 completed task and 4 pending tasks\"),\n",
88
+ " (\"Project has 1 completed task and 4 pending tasks\", \"Developer starts tasks 2 and 3\", \"Project has 1 completed, 2 in-progress, and 2 pending tasks\"),\n",
89
+ " (\"Project has 1 completed, 2 in-progress, and 2 pending tasks\", \"Developer completes all in-progress tasks\", \"Project has 3 completed tasks and 2 pending tasks\"),\n",
90
+ " \n",
91
+ " # Email/Communication workflows\n",
92
+ " (\"Inbox has 10 unread emails\", \"User reads 3 emails\", \"Inbox has 7 unread emails and 3 read emails\"),\n",
93
+ " (\"Inbox has 7 unread emails and 3 read emails\", \"User archives 2 read emails\", \"Inbox has 7 unread, 1 read, and 2 archived emails\"),\n",
94
+ " (\"Inbox has 7 unread, 1 read, and 2 archived emails\", \"User receives new email\", \"Inbox has 8 unread, 1 read, and 2 archived emails\"),\n",
95
+ " (\"Inbox has 8 unread, 1 read, and 2 archived emails\", \"User marks all as read\", \"Inbox has 0 unread, 9 read, and 2 archived emails\"),\n",
96
+ " \n",
97
+ " # Database/Data workflows\n",
98
+ " (\"Database table has 100 records\", \"User inserts 50 new records\", \"Database table has 150 records\"),\n",
99
+ " (\"Database table has 150 records\", \"User deletes 30 records matching filter\", \"Database table has 120 records\"),\n",
100
+ " (\"Database table has 120 records\", \"User updates 20 records with new values\", \"Database table has 120 records with 20 modified\"),\n",
101
+ " (\"Database table has 120 records with 20 modified\", \"User exports table to CSV\", \"Database table unchanged, CSV file created with 120 rows\"),\n",
102
+ " \n",
103
+ " # File system workflows \n",
104
+ " (\"Folder contains 5 files totaling 10MB\", \"User uploads 3 new files of 5MB each\", \"Folder contains 8 files totaling 25MB\"),\n",
105
+ " (\"Folder contains 8 files totaling 25MB\", \"User deletes 2 files of 3MB each\", \"Folder contains 6 files totaling 19MB\"),\n",
106
+ " (\"Folder contains 6 files totaling 19MB\", \"User creates new subfolder\", \"Folder contains 6 files totaling 19MB and 1 empty subfolder\"),\n",
107
+ " (\"Folder contains 6 files totaling 19MB and 1 empty subfolder\", \"User moves 2 files to subfolder\", \"Folder contains 4 files and subfolder with 2 files\"),\n",
108
+ " \n",
109
+ " # Shopping cart workflows\n",
110
+ " (\"Cart is empty with 0 items\", \"User adds product A priced at $50\", \"Cart has 1 item with total $50\"),\n",
111
+ " (\"Cart has 1 item with total $50\", \"User adds product B priced at $30\", \"Cart has 2 items with total $80\"),\n",
112
+ " (\"Cart has 2 items with total $80\", \"User applies 10% discount code\", \"Cart has 2 items with total $72 after discount\"),\n",
113
+ " (\"Cart has 2 items with total $72 after discount\", \"User removes product A\", \"Cart has 1 item with total $27 after discount\"),\n",
114
+ " (\"Cart has 1 item with total $27 after discount\", \"User proceeds to checkout\", \"Order created for 1 item totaling $27\"),\n",
115
+ "]\n",
116
+ "\n",
117
+ "# Augment data with variations\n",
118
+ "def augment_data(data, multiplier=5):\n",
119
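+ " \"\"\"Return the original triples plus (multiplier - 1) copies of each, where state and action are each randomly prefixed and lowercased with probability 0.5; next_state is left unchanged.\"\"\"\n",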
120
+ " augmented = list(data)\n",
121
+ " \n",
122
+ " # Add slight variations\n",
123
+ " for state, action, next_state in data:\n",
124
+ " for i in range(multiplier - 1):\n",
125
+ " # Add noise phrases\n",
126
+ " prefixes = [\"\", \"Currently, \", \"At this point, \", \"Right now, \"]\n",
127
+ " action_prefixes = [\"\", \"Then \", \"Next, \", \"Subsequently, \"]\n",
128
+ " \n",
129
+ " new_state = np.random.choice(prefixes) + state.lower() if np.random.random() > 0.5 else state\n",
130
+ " new_action = np.random.choice(action_prefixes) + action.lower() if np.random.random() > 0.5 else action\n",
131
+ " \n",
132
+ " augmented.append((new_state, new_action, next_state))\n",
133
+ " \n",
134
+ " return augmented\n",
135
+ "\n",
136
+ "training_data = augment_data(WORKFLOW_DATA, multiplier=10)\n",
137
+ "print(f\"Total training examples: {len(training_data)}\")\n",
138
+ "print(f\"\\nSample:\\n State: {training_data[0][0]}\\n Action: {training_data[0][1]}\\n Next State: {training_data[0][2]}\")"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "metadata": {},
144
+ "source": [
145
+ "## 2. Define the JEPA-Style Model\n",
146
+ "\n",
147
+ "The key architectural change:\n",
148
+ "- **Normal LLM:** `hidden_state → vocab_head → token_logits`\n",
149
+ "- **JEPA LLM:** `hidden_state → state_head → state_embedding`"
150
+ ]
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "execution_count": null,
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "class JEPAWorldModel(nn.Module):\n",
159
+ " \"\"\"\n",
160
+ " A decoder-only transformer modified to act like a JEPA world model.\n",
161
+ " \n",
162
+ " Instead of predicting next tokens, it predicts next STATE EMBEDDINGS\n",
163
+ " given current state + action.\n",
164
+ " \"\"\"\n",
165
+ " \n",
166
+ " def __init__(\n",
167
+ " self,\n",
168
+ " sentence_encoder_name: str = \"all-MiniLM-L6-v2\",\n",
169
+ " backbone_name: str = \"gpt2\", # Small model for testing\n",
170
+ " state_dim: int = 384, # MiniLM output dim\n",
171
+ " hidden_dim: int = 512,\n",
172
+ " freeze_sentence_encoder: bool = True\n",
173
+ " ):\n",
174
+ " super().__init__()\n",
175
+ " \n",
176
+ " # Sentence encoder for state/action embeddings\n",
177
+ " self.sentence_encoder = SentenceTransformer(sentence_encoder_name)\n",
178
+ " if freeze_sentence_encoder:\n",
179
+ " for param in self.sentence_encoder.parameters():\n",
180
+ " param.requires_grad = False\n",
181
+ " \n",
182
+ " self.state_dim = state_dim\n",
183
+ " \n",
184
+ " # Backbone transformer (we use its hidden layers, not its LM head)\n",
185
+ " self.backbone = AutoModel.from_pretrained(backbone_name)\n",
186
+ " backbone_hidden = self.backbone.config.hidden_size # 768 for GPT-2\n",
187
+ " \n",
188
+ " # Project state+action embeddings into backbone space\n",
189
+ " self.input_projection = nn.Sequential(\n",
190
+ " nn.Linear(state_dim * 2, hidden_dim),\n",
191
+ " nn.GELU(),\n",
192
+ " nn.LayerNorm(hidden_dim),\n",
193
+ " nn.Linear(hidden_dim, backbone_hidden)\n",
194
+ " )\n",
195
+ " \n",
196
+ " # State prediction head (replaces vocabulary head)\n",
197
+ " # This is the JEPA key: output embeddings, not tokens\n",
198
+ " self.state_predictor = nn.Sequential(\n",
199
+ " nn.Linear(backbone_hidden, hidden_dim),\n",
200
+ " nn.GELU(),\n",
201
+ " nn.LayerNorm(hidden_dim),\n",
202
+ " nn.Linear(hidden_dim, state_dim)\n",
203
+ " )\n",
204
+ " \n",
205
+ " def encode_text(self, texts: list) -> torch.Tensor:\n",
206
+ " \"\"\"Convert text to embeddings using sentence encoder\"\"\"\n",
207
+ " embeddings = self.sentence_encoder.encode(\n",
208
+ " texts, \n",
209
+ " convert_to_tensor=True,\n",
210
+ " show_progress_bar=False\n",
211
+ " )\n",
212
+ " return embeddings\n",
213
+ " \n",
214
+ " def forward(\n",
215
+ " self, \n",
216
+ " state_texts: list,\n",
217
+ " action_texts: list\n",
218
+ " ) -> torch.Tensor:\n",
219
+ " \"\"\"\n",
220
+ " Predict next state embedding given current state and action.\n",
221
+ " \n",
222
+ " Args:\n",
223
+ " state_texts: List of strings describing current states\n",
224
+ " action_texts: List of strings describing actions\n",
225
+ " \n",
226
+ " Returns:\n",
227
+ " predicted_next_state: [batch_size, state_dim] embedding\n",
228
+ " \"\"\"\n",
229
+ " # Encode state and action to embeddings\n",
230
+ " state_emb = self.encode_text(state_texts) # [B, state_dim]\n",
231
+ " action_emb = self.encode_text(action_texts) # [B, state_dim]\n",
232
+ " \n",
233
+ " # Concatenate state and action\n",
234
+ " combined = torch.cat([state_emb, action_emb], dim=-1) # [B, state_dim*2]\n",
235
+ " \n",
236
+ " # Project to backbone space\n",
237
+ " backbone_input = self.input_projection(combined) # [B, backbone_hidden]\n",
238
+ " backbone_input = backbone_input.unsqueeze(1) # [B, 1, backbone_hidden]\n",
239
+ " \n",
240
+ " # Pass through backbone transformer\n",
241
+ " backbone_output = self.backbone(\n",
242
+ " inputs_embeds=backbone_input\n",
243
+ " ).last_hidden_state[:, -1, :] # [B, backbone_hidden]\n",
244
+ " \n",
245
+ " # Predict next state embedding (NOT tokens!)\n",
246
+ " predicted_next_state = self.state_predictor(backbone_output) # [B, state_dim]\n",
247
+ " \n",
248
+ " return predicted_next_state\n",
249
+ " \n",
250
+ " def get_target_embedding(self, next_state_texts: list) -> torch.Tensor:\n",
251
+ " \"\"\"Get target embeddings for loss computation\"\"\"\n",
252
+ " with torch.no_grad():\n",
253
+ " return self.encode_text(next_state_texts)"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "markdown",
258
+ "metadata": {},
259
+ "source": [
260
+ "## 3. Create Dataset and DataLoader"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "metadata": {},
267
+ "outputs": [],
268
+ "source": [
269
+ "class WorkflowDataset(Dataset):\n",
270
+ " \"\"\"Dataset of (state, action, next_state) triplets\"\"\"\n",
271
+ " \n",
272
+ " def __init__(self, data):\n",
273
+ " self.data = data\n",
274
+ " \n",
275
+ " def __len__(self):\n",
276
+ " return len(self.data)\n",
277
+ " \n",
278
+ " def __getitem__(self, idx):\n",
279
+ " state, action, next_state = self.data[idx]\n",
280
+ " return {\n",
281
+ " 'state': state,\n",
282
+ " 'action': action,\n",
283
+ " 'next_state': next_state\n",
284
+ " }\n",
285
+ "\n",
286
+ "def collate_fn(batch):\n",
287
+ " \"\"\"Collate function that keeps strings as lists\"\"\"\n",
288
+ " return {\n",
289
+ " 'states': [item['state'] for item in batch],\n",
290
+ " 'actions': [item['action'] for item in batch],\n",
291
+ " 'next_states': [item['next_state'] for item in batch]\n",
292
+ " }\n",
293
+ "\n",
294
+ "# Split data\n",
295
+ "np.random.shuffle(training_data)\n",
296
+ "split_idx = int(len(training_data) * 0.9)\n",
297
+ "train_data = training_data[:split_idx]\n",
298
+ "val_data = training_data[split_idx:]\n",
299
+ "\n",
300
+ "train_dataset = WorkflowDataset(train_data)\n",
301
+ "val_dataset = WorkflowDataset(val_data)\n",
302
+ "\n",
303
+ "train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)\n",
304
+ "val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)\n",
305
+ "\n",
306
+ "print(f\"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}\")"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "markdown",
311
+ "metadata": {},
312
+ "source": [
313
+ "## 4. Training Loop\n",
314
+ "\n",
315
+ "**The key difference from LLM training:**\n",
316
+ "- LLM: `loss = CrossEntropy(predicted_logits, target_tokens)`\n",
317
+ "- JEPA: `loss = MSE(predicted_embedding, target_embedding)`"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "def train_epoch(model, dataloader, optimizer, device):\n",
327
+ " model.train()\n",
328
+ " total_loss = 0\n",
329
+ " \n",
330
+ " for batch in tqdm(dataloader, desc=\"Training\"):\n",
331
+ " # Forward pass\n",
332
+ " predicted_next = model(batch['states'], batch['actions'])\n",
333
+ " \n",
334
+ " # Get target embeddings\n",
335
+ " target_next = model.get_target_embedding(batch['next_states'])\n",
336
+ " \n",
337
+ " # JEPA-style loss: MSE in embedding space\n",
338
+ " loss = F.mse_loss(predicted_next, target_next)\n",
339
+ " \n",
340
+ " # Alternative: Cosine similarity loss\n",
341
+ " # loss = 1 - F.cosine_similarity(predicted_next, target_next).mean()\n",
342
+ " \n",
343
+ " # Backward pass\n",
344
+ " optimizer.zero_grad()\n",
345
+ " loss.backward()\n",
346
+ " torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
347
+ " optimizer.step()\n",
348
+ " \n",
349
+ " total_loss += loss.item()\n",
350
+ " \n",
351
+ " return total_loss / len(dataloader)\n",
352
+ "\n",
353
+ "\n",
354
+ "def validate(model, dataloader, device):\n",
355
+ " model.eval()\n",
356
+ " total_loss = 0\n",
357
+ " total_cosine_sim = 0\n",
358
+ " \n",
359
+ " with torch.no_grad():\n",
360
+ " for batch in dataloader:\n",
361
+ " predicted_next = model(batch['states'], batch['actions'])\n",
362
+ " target_next = model.get_target_embedding(batch['next_states'])\n",
363
+ " \n",
364
+ " loss = F.mse_loss(predicted_next, target_next)\n",
365
+ " cosine_sim = F.cosine_similarity(predicted_next, target_next).mean()\n",
366
+ " \n",
367
+ " total_loss += loss.item()\n",
368
+ " total_cosine_sim += cosine_sim.item()\n",
369
+ " \n",
370
+ " return {\n",
371
+ " 'loss': total_loss / len(dataloader),\n",
372
+ " 'cosine_similarity': total_cosine_sim / len(dataloader)\n",
373
+ " }"
374
+ ]
375
+ },
376
+ {
377
+ "cell_type": "code",
378
+ "execution_count": null,
379
+ "metadata": {},
380
+ "outputs": [],
381
+ "source": [
382
+ "# Initialize model\n",
383
+ "model = JEPAWorldModel(\n",
384
+ " sentence_encoder_name=\"all-MiniLM-L6-v2\",\n",
385
+ " backbone_name=\"gpt2\",\n",
386
+ " state_dim=384,\n",
387
+ " hidden_dim=512\n",
388
+ ")\n",
389
+ "model = model.to(device)\n",
390
+ "\n",
391
+ "# Optimizer\n",
392
+ "optimizer = torch.optim.AdamW(\n",
393
+ " filter(lambda p: p.requires_grad, model.parameters()),\n",
394
+ " lr=1e-4,\n",
395
+ " weight_decay=0.01\n",
396
+ ")\n",
397
+ "\n",
398
+ "# Count parameters\n",
399
+ "trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
400
+ "total_params = sum(p.numel() for p in model.parameters())\n",
401
+ "print(f\"Trainable parameters: {trainable_params:,}\")\n",
402
+ "print(f\"Total parameters: {total_params:,}\")"
403
+ ]
404
+ },
405
+ {
406
+ "cell_type": "code",
407
+ "execution_count": null,
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "# Training\n",
412
+ "num_epochs = 20\n",
413
+ "train_losses = []\n",
414
+ "val_losses = []\n",
415
+ "val_cosine_sims = []\n",
416
+ "\n",
417
+ "for epoch in range(num_epochs):\n",
418
+ " train_loss = train_epoch(model, train_loader, optimizer, device)\n",
419
+ " val_metrics = validate(model, val_loader, device)\n",
420
+ " \n",
421
+ " train_losses.append(train_loss)\n",
422
+ " val_losses.append(val_metrics['loss'])\n",
423
+ " val_cosine_sims.append(val_metrics['cosine_similarity'])\n",
424
+ " \n",
425
+ " print(f\"Epoch {epoch+1}/{num_epochs}\")\n",
426
+ " print(f\" Train Loss: {train_loss:.4f}\")\n",
427
+ " print(f\" Val Loss: {val_metrics['loss']:.4f}\")\n",
428
+ " print(f\" Val Cosine Similarity: {val_metrics['cosine_similarity']:.4f}\")"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "code",
433
+ "execution_count": null,
434
+ "metadata": {},
435
+ "outputs": [],
436
+ "source": [
437
+ "# Plot training curves\n",
438
+ "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
439
+ "\n",
440
+ "axes[0].plot(train_losses, label='Train')\n",
441
+ "axes[0].plot(val_losses, label='Validation')\n",
442
+ "axes[0].set_xlabel('Epoch')\n",
443
+ "axes[0].set_ylabel('MSE Loss')\n",
444
+ "axes[0].set_title('Training Progress')\n",
445
+ "axes[0].legend()\n",
446
+ "\n",
447
+ "axes[1].plot(val_cosine_sims, color='green')\n",
448
+ "axes[1].set_xlabel('Epoch')\n",
449
+ "axes[1].set_ylabel('Cosine Similarity')\n",
450
+ "axes[1].set_title('Prediction Quality (higher = better)')\n",
451
+ "\n",
452
+ "plt.tight_layout()\n",
453
+ "plt.show()"
454
+ ]
455
+ },
456
+ {
457
+ "cell_type": "markdown",
458
+ "metadata": {},
459
+ "source": [
460
+ "## 5. Test the World Model: Predict Consequences of Actions"
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": null,
466
+ "metadata": {},
467
+ "outputs": [],
468
+ "source": [
469
+ "def predict_next_state(model, current_state: str, action: str, candidate_states: list) -> dict:\n",
470
+ " \"\"\"\n",
471
+ " Given current state and action, predict which candidate state is most likely.\n",
472
+ " \n",
473
+ " This is how the JEPA-style model is used at inference: predict an embedding, then pick the closest candidate.\n",
474
+ " \"\"\"\n",
475
+ " model.eval()\n",
476
+ " \n",
477
+ " with torch.no_grad():\n",
478
+ " # Predict next state embedding\n",
479
+ " predicted_emb = model([current_state], [action]) # [1, state_dim]\n",
480
+ " \n",
481
+ " # Get embeddings of all candidate states\n",
482
+ " candidate_embs = model.encode_text(candidate_states) # [N, state_dim]\n",
483
+ " \n",
484
+ " # Compute cosine similarity to find best match\n",
485
+ " similarities = F.cosine_similarity(\n",
486
+ " predicted_emb.expand(len(candidate_states), -1),\n",
487
+ " candidate_embs\n",
488
+ " )\n",
489
+ " \n",
490
+ " # Get rankings\n",
491
+ " rankings = similarities.argsort(descending=True)\n",
492
+ " \n",
493
+ " return {\n",
494
+ " 'predicted_state': candidate_states[rankings[0]],\n",
495
+ " 'confidence': similarities[rankings[0]].item(),\n",
496
+ " 'all_scores': {candidate_states[i]: similarities[i].item() for i in range(len(candidate_states))}\n",
497
+ " }\n",
498
+ "\n",
499
+ "\n",
500
+ "# Test on hand-written examples (not drawn from val_data)\n",
501
+ "print(\"=\"*80)\n",
502
+ "print(\"TESTING WORLD MODEL PREDICTIONS\")\n",
503
+ "print(\"=\"*80)\n",
504
+ "\n",
505
+ "test_cases = [\n",
506
+ " {\n",
507
+ " 'state': \"Document is empty with no content\",\n",
508
+ " 'action': \"User creates new section titled Introduction\",\n",
509
+ " 'candidates': [\n",
510
+ " \"Document has one section: Introduction with no content\",\n",
511
+ " \"Document is deleted\",\n",
512
+ " \"Document has 500 words\",\n",
513
+ " \"User logged out\"\n",
514
+ " ]\n",
515
+ " },\n",
516
+ " {\n",
517
+ " 'state': \"Cart has 2 items with total $80\",\n",
518
+ " 'action': \"User applies 10% discount code\",\n",
519
+ " 'candidates': [\n",
520
+ " \"Cart has 2 items with total $72 after discount\",\n",
521
+ " \"Cart is empty\",\n",
522
+ " \"Cart has 3 items with total $100\",\n",
523
+ " \"Order was cancelled\"\n",
524
+ " ]\n",
525
+ " },\n",
526
+ " {\n",
527
+ " 'state': \"Inbox has 10 unread emails\",\n",
528
+ " 'action': \"User reads 3 emails\",\n",
529
+ " 'candidates': [\n",
530
+ " \"Inbox has 7 unread emails and 3 read emails\",\n",
531
+ " \"Inbox has 13 unread emails\",\n",
532
+ " \"Inbox is empty\",\n",
533
+ " \"User sent 3 emails\"\n",
534
+ " ]\n",
535
+ " }\n",
536
+ "]\n",
537
+ "\n",
538
+ "for i, test in enumerate(test_cases):\n",
539
+ " print(f\"\\n--- Test {i+1} ---\")\n",
540
+ " print(f\"Current State: {test['state']}\")\n",
541
+ " print(f\"Action: {test['action']}\")\n",
542
+ " \n",
543
+ " result = predict_next_state(model, test['state'], test['action'], test['candidates'])\n",
544
+ " \n",
545
+ " print(f\"\\nPredicted Next State: {result['predicted_state']}\")\n",
546
+ " print(f\"Confidence: {result['confidence']:.4f}\")\n",
547
+ " print(f\"\\nAll scores:\")\n",
548
+ " for state, score in sorted(result['all_scores'].items(), key=lambda x: -x[1]):\n",
549
+ " print(f\" {score:.4f}: {state}\")"
550
+ ]
551
+ },
552
+ {
553
+ "cell_type": "markdown",
554
+ "metadata": {},
555
+ "source": [
556
+ "## 6. Multi-Step Planning: Chain Predictions\n",
557
+ "\n",
558
+ "The power of world models: simulate multiple steps into the future!"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": null,
564
+ "metadata": {},
565
+ "outputs": [],
566
+ "source": [
567
+ "def simulate_trajectory(model, initial_state: str, actions: list, possible_states: list) -> list:\n",
568
+ " \"\"\"\n",
569
+ " Simulate a trajectory of states given a sequence of actions.\n",
570
+ " This is multi-step world model prediction!\n",
571
+ " \"\"\"\n",
572
+ " trajectory = [initial_state]\n",
573
+ " current_state = initial_state\n",
574
+ " \n",
575
+ " for action in actions:\n",
576
+ " result = predict_next_state(model, current_state, action, possible_states)\n",
577
+ " current_state = result['predicted_state']\n",
578
+ " trajectory.append({\n",
579
+ " 'action': action,\n",
580
+ " 'resulting_state': current_state,\n",
581
+ " 'confidence': result['confidence']\n",
582
+ " })\n",
583
+ " \n",
584
+ " return trajectory\n",
585
+ "\n",
586
+ "\n",
587
+ "# Test multi-step planning\n",
588
+ "print(\"\\n\" + \"=\"*80)\n",
589
+ "print(\"MULTI-STEP TRAJECTORY SIMULATION\")\n",
590
+ "print(\"=\"*80)\n",
591
+ "\n",
592
+ "possible_states = [\n",
593
+ " \"Document is empty with no content\",\n",
594
+ " \"Document has one section: Introduction with no content\",\n",
595
+ " \"Document has Introduction section with 500 words of content\",\n",
596
+ " \"Document has two sections: Introduction (500 words) and Methods (empty)\",\n",
597
+ " \"Document has Introduction (500 words) and Methods (300 words)\",\n",
598
+ " \"Document is pending review with total 800 words\",\n",
599
+ " \"Document is approved and ready for publication\",\n",
600
+ " \"Document returned to author with revision requests\",\n",
601
+ "]\n",
602
+ "\n",
603
+ "actions = [\n",
604
+ " \"User creates new section titled Introduction\",\n",
605
+ " \"User writes 500 words in Introduction\",\n",
606
+ " \"User adds new section titled Methods\",\n",
607
+ " \"User writes 300 words in Methods\",\n",
608
+ " \"User submits document for review\"\n",
609
+ "]\n",
610
+ "\n",
611
+ "trajectory = simulate_trajectory(\n",
612
+ " model,\n",
613
+ " initial_state=\"Document is empty with no content\",\n",
614
+ " actions=actions,\n",
615
+ " possible_states=possible_states\n",
616
+ ")\n",
617
+ "\n",
618
+ "print(f\"\\nInitial State: {trajectory[0]}\")\n",
619
+ "for i, step in enumerate(trajectory[1:], 1):\n",
620
+ " print(f\"\\nStep {i}:\")\n",
621
+ " print(f\" Action: {step['action']}\")\n",
622
+ " print(f\" → {step['resulting_state']}\")\n",
623
+ " print(f\" (confidence: {step['confidence']:.4f})\")"
624
+ ]
625
+ },
626
+ {
627
+ "cell_type": "markdown",
628
+ "metadata": {},
629
+ "source": [
630
+ "## 7. Save the Model"
631
+ ]
632
+ },
633
+ {
634
+ "cell_type": "code",
635
+ "execution_count": null,
636
+ "metadata": {},
637
+ "outputs": [],
638
+ "source": [
639
+ "# Save model\n",
640
+ "torch.save({\n",
641
+ " 'model_state_dict': model.state_dict(),\n",
642
+ " 'config': {\n",
643
+ " 'sentence_encoder_name': 'all-MiniLM-L6-v2',\n",
644
+ " 'backbone_name': 'gpt2',\n",
645
+ " 'state_dim': 384,\n",
646
+ " 'hidden_dim': 512\n",
647
+ " }\n",
648
+ "}, 'jepa_world_model_option1.pt')\n",
649
+ "\n",
650
+ "print(\"Model saved to jepa_world_model_option1.pt\")"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "markdown",
655
+ "metadata": {},
656
+ "source": [
657
+ "## Summary\n",
658
+ "\n",
659
+ "**What we built:**\n",
660
+ "- A decoder-only transformer that predicts STATE EMBEDDINGS instead of tokens\n",
661
+ "- Input: (current_state, action) pair\n",
662
+ "- Output: predicted next_state embedding\n",
663
+ "- Loss: MSE between predicted and actual state embeddings\n",
664
+ "\n",
665
+ "**This is JEPA-like because:**\n",
666
+ "1. We predict in latent/embedding space, not token space\n",
667
+ "2. We learn the \"physics\" of state transitions\n",
668
+ "3. We can do multi-step planning by chaining predictions\n",
669
+ "\n",
670
+ "**Next steps:**\n",
671
+ "- Use your own enterprise data, formatted as (state, action, next_state) triplets\n",
672
+ "- Scale up the backbone model\n",
673
+ "- Add uncertainty estimation for planning"
674
+ ]
675
+ }
676
+ ],
677
+ "metadata": {
678
+ "kernelspec": {
679
+ "display_name": "Python 3",
680
+ "language": "python",
681
+ "name": "python3"
682
+ },
683
+ "language_info": {
684
+ "name": "python",
685
+ "version": "3.10.0"
686
+ }
687
+ },
688
+ "nbformat": 4,
689
+ "nbformat_minor": 4
690
+ }
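A minimal, dependency-free sketch of the retrieval step from section 5 of the notebook above: predict an embedding, then rank candidate states by cosine similarity. The toy 3-dimensional vectors and the helper names `cosine_similarity` and `rank_candidates` are illustrative assumptions, not part of the notebook, which applies `F.cosine_similarity` to real encoder outputs.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def rank_candidates(predicted, candidates):
    """Return (label, score) pairs sorted from most to least similar."""
    scored = [(label, cosine_similarity(predicted, vec)) for label, vec in candidates.items()]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


# Hypothetical stand-ins for the predicted embedding and the candidate embeddings.
predicted = [0.9, 0.1, 0.0]
candidates = {
    "Cart has 2 items totaling $72": [1.0, 0.0, 0.0],
    "Cart is empty": [0.0, 1.0, 0.0],
    "Order was cancelled": [0.0, 0.0, 1.0],
}

ranking = rank_candidates(predicted, candidates)
best_label, best_score = ranking[0]
print(best_label, round(best_score, 4))
```

The top score is a cosine similarity in [-1, 1], not a calibrated probability, so the notebook's `confidence` field is best read as a relative ranking signal.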
jepa_option2_llm_hidden_states.ipynb ADDED
@@ -0,0 +1,699 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🧠 JEPA-Style LLM - Option 2: LLM Hidden States as World Model\n",
8
+ "\n",
9
+ "**Use the LLM's own internal representations as the state space**\n",
10
+ "\n",
11
+ "This approach is more powerful than Option 1 because:\n",
12
+ "- The state encoder and predictor share the same representation space\n",
13
+ "- The same LLM backbone feeds both state encoding AND transition prediction (with the default `freeze_llm=True`, only the small heads on top are trained)\n",
14
+ "- No separate sentence encoder needed\n",
15
+ "\n",
16
+ "**Architecture:**\n",
17
+ "```\n",
18
+ "State text → LLM Encoder → State embedding\n",
19
+ "Action text → LLM Encoder → Action embedding \n",
20
+ "State + Action → Dynamics predictor (MLP on the embeddings) → Next State embedding\n",
21
+ "```"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "code",
26
+ "execution_count": null,
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "# Install dependencies\n",
31
+ "!pip install -q transformers accelerate torch datasets bitsandbytes"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "code",
36
+ "execution_count": null,
37
+ "metadata": {},
38
+ "outputs": [],
39
+ "source": [
40
+ "import torch\n",
41
+ "import torch.nn as nn\n",
42
+ "import torch.nn.functional as F\n",
43
+ "from torch.utils.data import Dataset, DataLoader\n",
44
+ "from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM\n",
45
+ "import numpy as np\n",
46
+ "from tqdm.auto import tqdm\n",
47
+ "import matplotlib.pyplot as plt\n",
48
+ "\n",
49
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
50
+ "print(f\"Using device: {device}\")"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "markdown",
55
+ "metadata": {},
56
+ "source": [
57
+ "## 1. Synthetic Data (Same as Option 1)"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "WORKFLOW_DATA = [\n",
67
+ " # Document workflows\n",
68
+ " (\"Document is empty\", \"create introduction section\", \"Document has introduction section\"),\n",
69
+ " (\"Document has introduction section\", \"write 500 words\", \"Document has introduction with 500 words\"),\n",
70
+ " (\"Document has introduction with 500 words\", \"add methods section\", \"Document has introduction and methods sections\"),\n",
71
+ " (\"Document has introduction and methods sections\", \"submit for review\", \"Document pending review\"),\n",
72
+ " (\"Document pending review\", \"reviewer approves\", \"Document approved\"),\n",
73
+ " (\"Document pending review\", \"reviewer rejects\", \"Document needs revision\"),\n",
74
+ " \n",
75
+ " # Task workflows\n",
76
+ " (\"Project has no tasks\", \"create 5 tasks\", \"Project has 5 pending tasks\"),\n",
77
+ " (\"Project has 5 pending tasks\", \"start task 1\", \"Project has 1 active and 4 pending tasks\"),\n",
78
+ " (\"Project has 1 active and 4 pending tasks\", \"complete task 1\", \"Project has 1 done and 4 pending tasks\"),\n",
79
+ " (\"Project has 1 done and 4 pending tasks\", \"start remaining tasks\", \"Project has 1 done and 4 active tasks\"),\n",
80
+ " (\"Project has 1 done and 4 active tasks\", \"complete all tasks\", \"Project complete with 5 done tasks\"),\n",
81
+ " \n",
82
+ " # Shopping cart\n",
83
+ " (\"Cart empty\", \"add item for $50\", \"Cart has 1 item totaling $50\"),\n",
84
+ " (\"Cart has 1 item totaling $50\", \"add item for $30\", \"Cart has 2 items totaling $80\"),\n",
85
+ " (\"Cart has 2 items totaling $80\", \"apply 10% discount\", \"Cart has 2 items totaling $72\"),\n",
86
+ " (\"Cart has 2 items totaling $72\", \"checkout\", \"Order placed for $72\"),\n",
87
+ " (\"Cart has 2 items totaling $80\", \"remove first item\", \"Cart has 1 item totaling $30\"),\n",
88
+ " \n",
89
+ " # Database operations\n",
90
+ " (\"Table has 100 rows\", \"insert 50 rows\", \"Table has 150 rows\"),\n",
91
+ " (\"Table has 150 rows\", \"delete 30 rows\", \"Table has 120 rows\"),\n",
92
+ " (\"Table has 120 rows\", \"update 20 rows\", \"Table has 120 rows with 20 modified\"),\n",
93
+ " \n",
94
+ " # File operations\n",
95
+ " (\"Folder has 5 files\", \"upload 3 files\", \"Folder has 8 files\"),\n",
96
+ " (\"Folder has 8 files\", \"delete 2 files\", \"Folder has 6 files\"),\n",
97
+ " (\"Folder has 6 files\", \"create subfolder\", \"Folder has 6 files and 1 subfolder\"),\n",
98
+ "]\n",
99
+ "\n",
100
+ "# Augment\n",
101
+ "def augment_data(data, multiplier=15):\n",
102
+ " augmented = []\n",
103
+ " for state, action, next_state in data:\n",
104
+ " for _ in range(multiplier):\n",
105
+ " # Randomly add prefixes/variations\n",
106
+ " s = state if np.random.random() > 0.3 else f\"Currently: {state}\"\n",
107
+ " a = action if np.random.random() > 0.3 else f\"User action: {action}\"\n",
108
+ " augmented.append((s, a, next_state))\n",
109
+ " return augmented\n",
110
+ "\n",
111
+ "training_data = augment_data(WORKFLOW_DATA)\n",
112
+ "np.random.shuffle(training_data)\n",
113
+ "print(f\"Total examples: {len(training_data)}\")"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "markdown",
118
+ "metadata": {},
119
+ "source": [
120
+ "## 2. JEPA World Model Using LLM Hidden States\n",
121
+ "\n",
122
+ "**Key Idea:** The LLM's hidden states ARE the state embeddings.\n",
123
+ "- Encode text through the LLM, use mean-pooled hidden states\n",
124
+ "- Train a predictor on top to forecast next state embeddings"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "class JEPAWorldModelV2(nn.Module):\n",
134
+ " \"\"\"\n",
135
+ " JEPA-style world model using LLM hidden states as state space.\n",
136
+ " \n",
137
+ " The LLM serves dual purpose:\n",
138
+ " 1. State encoder: text → hidden state → state embedding\n",
139
+ " 2. Feature backbone for dynamics: an MLP on the (state, action) embeddings predicts the next state\n",
140
+ " \"\"\"\n",
141
+ " \n",
142
+ " def __init__(\n",
143
+ " self,\n",
144
+ " model_name: str = \"gpt2\",\n",
145
+ " state_dim: int = 256,\n",
146
+ " freeze_llm: bool = True # Freeze LLM, train only heads\n",
147
+ " ):\n",
148
+ " super().__init__()\n",
149
+ " \n",
150
+ " # Load LLM and tokenizer\n",
151
+ " self.tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
152
+ " self.tokenizer.pad_token = self.tokenizer.eos_token\n",
153
+ " \n",
154
+ " self.llm = AutoModel.from_pretrained(model_name)\n",
155
+ " self.hidden_size = self.llm.config.hidden_size\n",
156
+ " \n",
157
+ " if freeze_llm:\n",
158
+ " for param in self.llm.parameters():\n",
159
+ " param.requires_grad = False\n",
160
+ " \n",
161
+ " self.state_dim = state_dim\n",
162
+ " \n",
163
+ " # State encoder: LLM hidden → compact state embedding\n",
164
+ " self.state_encoder = nn.Sequential(\n",
165
+ " nn.Linear(self.hidden_size, self.hidden_size // 2),\n",
166
+ " nn.GELU(),\n",
167
+ " nn.LayerNorm(self.hidden_size // 2),\n",
168
+ " nn.Linear(self.hidden_size // 2, state_dim),\n",
169
+ " nn.LayerNorm(state_dim)\n",
170
+ " )\n",
171
+ " \n",
172
+ " # Action encoder (same structure)\n",
173
+ " self.action_encoder = nn.Sequential(\n",
174
+ " nn.Linear(self.hidden_size, self.hidden_size // 2),\n",
175
+ " nn.GELU(),\n",
176
+ " nn.LayerNorm(self.hidden_size // 2),\n",
177
+ " nn.Linear(self.hidden_size // 2, state_dim),\n",
178
+ " nn.LayerNorm(state_dim)\n",
179
+ " )\n",
180
+ " \n",
181
+ " # State dynamics predictor\n",
182
+ " # Input: state_emb + action_emb\n",
183
+ " # Output: predicted next_state_emb\n",
184
+ " self.dynamics_predictor = nn.Sequential(\n",
185
+ " nn.Linear(state_dim * 2, state_dim * 2),\n",
186
+ " nn.GELU(),\n",
187
+ " nn.LayerNorm(state_dim * 2),\n",
188
+ " nn.Linear(state_dim * 2, state_dim),\n",
189
+ " nn.GELU(),\n",
190
+ " nn.LayerNorm(state_dim),\n",
191
+ " nn.Linear(state_dim, state_dim)\n",
192
+ " )\n",
193
+ " \n",
194
+ " def get_llm_embedding(self, texts: list) -> torch.Tensor:\n",
195
+ " \"\"\"Get mean-pooled LLM hidden states for texts\"\"\"\n",
196
+ " tokens = self.tokenizer(\n",
197
+ " texts,\n",
198
+ " return_tensors='pt',\n",
199
+ " padding=True,\n",
200
+ " truncation=True,\n",
201
+ " max_length=128\n",
202
+ " ).to(next(self.llm.parameters()).device)\n",
203
+ " \n",
204
+ " with torch.set_grad_enabled(torch.is_grad_enabled() and any(p.requires_grad for p in self.llm.parameters())):\n",
205
+ " outputs = self.llm(**tokens)\n",
206
+ " hidden_states = outputs.last_hidden_state # [B, seq_len, hidden]\n",
207
+ " \n",
208
+ " # Mean pooling (ignoring padding)\n",
209
+ " attention_mask = tokens['attention_mask'].unsqueeze(-1)\n",
210
+ " sum_hidden = (hidden_states * attention_mask).sum(dim=1)\n",
211
+ " mean_hidden = sum_hidden / attention_mask.sum(dim=1)\n",
212
+ " \n",
213
+ " return mean_hidden # [B, hidden_size]\n",
214
+ " \n",
215
+ " def encode_state(self, state_texts: list) -> torch.Tensor:\n",
216
+ " \"\"\"Encode state text to state embedding\"\"\"\n",
217
+ " llm_emb = self.get_llm_embedding(state_texts)\n",
218
+ " return self.state_encoder(llm_emb)\n",
219
+ " \n",
220
+ " def encode_action(self, action_texts: list) -> torch.Tensor:\n",
221
+ " \"\"\"Encode action text to action embedding\"\"\"\n",
222
+ " llm_emb = self.get_llm_embedding(action_texts)\n",
223
+ " return self.action_encoder(llm_emb)\n",
224
+ " \n",
225
+ " def forward(\n",
226
+ " self,\n",
227
+ " state_texts: list,\n",
228
+ " action_texts: list\n",
229
+ " ) -> torch.Tensor:\n",
230
+ " \"\"\"\n",
231
+ " Predict next state embedding from current state and action.\n",
232
+ " \n",
233
+ " This is the JEPA forward pass:\n",
234
+ " (state, action) → predicted_next_state_embedding\n",
235
+ " \"\"\"\n",
236
+ " # Encode state and action\n",
237
+ " state_emb = self.encode_state(state_texts) # [B, state_dim]\n",
238
+ " action_emb = self.encode_action(action_texts) # [B, state_dim]\n",
239
+ " \n",
240
+ " # Concatenate for dynamics prediction\n",
241
+ " combined = torch.cat([state_emb, action_emb], dim=-1) # [B, state_dim*2]\n",
242
+ " \n",
243
+ " # Predict next state\n",
244
+ " predicted_next_state = self.dynamics_predictor(combined) # [B, state_dim]\n",
245
+ " \n",
246
+ " return predicted_next_state\n",
247
+ " \n",
248
+ " def get_target_embedding(self, next_state_texts: list) -> torch.Tensor:\n",
249
+ " \"\"\"Get target state embedding for loss computation\"\"\"\n",
250
+ " return self.encode_state(next_state_texts).detach() # stop-gradient on the target branch, as in Option 1"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "markdown",
255
+ "metadata": {},
256
+ "source": [
257
+ "## 3. Dataset and DataLoader"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "metadata": {},
264
+ "outputs": [],
265
+ "source": [
266
+ "class WorkflowDataset(Dataset):\n",
267
+ " def __init__(self, data):\n",
268
+ " self.data = data\n",
269
+ " \n",
270
+ " def __len__(self):\n",
271
+ " return len(self.data)\n",
272
+ " \n",
273
+ " def __getitem__(self, idx):\n",
274
+ " state, action, next_state = self.data[idx]\n",
275
+ " return {'state': state, 'action': action, 'next_state': next_state}\n",
276
+ "\n",
277
+ "def collate_fn(batch):\n",
278
+ " return {\n",
279
+ " 'states': [item['state'] for item in batch],\n",
280
+ " 'actions': [item['action'] for item in batch],\n",
281
+ " 'next_states': [item['next_state'] for item in batch]\n",
282
+ " }\n",
283
+ "\n",
284
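+ "# Split (augmentation emits exact duplicates, so the same example can land in both train and val; treat val metrics as optimistic)\n",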
285
+ "split_idx = int(len(training_data) * 0.9)\n",
286
+ "train_data = training_data[:split_idx]\n",
287
+ "val_data = training_data[split_idx:]\n",
288
+ "\n",
289
+ "train_loader = DataLoader(\n",
290
+ " WorkflowDataset(train_data), \n",
291
+ " batch_size=8, \n",
292
+ " shuffle=True, \n",
293
+ " collate_fn=collate_fn\n",
294
+ ")\n",
295
+ "val_loader = DataLoader(\n",
296
+ " WorkflowDataset(val_data), \n",
297
+ " batch_size=8, \n",
298
+ " shuffle=False, \n",
299
+ " collate_fn=collate_fn\n",
300
+ ")\n",
301
+ "\n",
302
+ "print(f\"Train: {len(train_loader)} batches, Val: {len(val_loader)} batches\")"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "markdown",
307
+ "metadata": {},
308
+ "source": [
309
+ "## 4. Training with JEPA-style Loss"
310
+ ]
311
+ },
312
+ {
313
+ "cell_type": "code",
314
+ "execution_count": null,
315
+ "metadata": {},
316
+ "outputs": [],
317
+ "source": [
318
+ "class JEPALoss(nn.Module):\n",
319
+ " \"\"\"\n",
320
+ " Combined loss for JEPA training:\n",
321
+ " - MSE: Mean squared error in embedding space\n",
322
+ " - Cosine: Similarity loss\n",
323
+ " Note: no contrastive term is implemented; only MSE and cosine are combined.\n",
324
+ " \"\"\"\n",
325
+ " def __init__(self, mse_weight=1.0, cosine_weight=0.5):\n",
326
+ " super().__init__()\n",
327
+ " self.mse_weight = mse_weight\n",
328
+ " self.cosine_weight = cosine_weight\n",
329
+ " \n",
330
+ " def forward(self, predicted: torch.Tensor, target: torch.Tensor) -> dict:\n",
331
+ " # MSE loss\n",
332
+ " mse_loss = F.mse_loss(predicted, target)\n",
333
+ " \n",
334
+ " # Cosine similarity loss (maximize similarity = minimize 1 - sim)\n",
335
+ " cosine_sim = F.cosine_similarity(predicted, target, dim=-1)\n",
336
+ " cosine_loss = (1 - cosine_sim).mean()\n",
337
+ " \n",
338
+ " # Combined loss\n",
339
+ " total_loss = self.mse_weight * mse_loss + self.cosine_weight * cosine_loss\n",
340
+ " \n",
341
+ " return {\n",
342
+ " 'total': total_loss,\n",
343
+ " 'mse': mse_loss,\n",
344
+ " 'cosine_loss': cosine_loss,\n",
345
+ " 'cosine_sim': cosine_sim.mean()\n",
346
+ " }"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train_epoch(model, dataloader, optimizer, loss_fn, device):\n",
+ "    model.train()\n",
+ "    metrics = {'total': 0, 'mse': 0, 'cosine_sim': 0}\n",
+ "\n",
+ "    for batch in tqdm(dataloader, desc=\"Training\"):\n",
+ "        # Forward\n",
+ "        predicted = model(batch['states'], batch['actions'])\n",
+ "        # Stop-gradient on the target; a no-op if get_target_embedding already detaches\n",
+ "        target = model.get_target_embedding(batch['next_states']).detach()\n",
+ "\n",
+ "        # Loss\n",
+ "        losses = loss_fn(predicted, target)\n",
+ "\n",
+ "        # Backward\n",
+ "        optimizer.zero_grad()\n",
+ "        losses['total'].backward()\n",
+ "        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
+ "        optimizer.step()\n",
+ "\n",
+ "        # Track\n",
+ "        for k, v in losses.items():\n",
+ "            if k in metrics:\n",
+ "                metrics[k] += v.item()\n",
+ "\n",
+ "    # Average each tracked metric over the number of batches\n",
+ "    return {k: v / len(dataloader) for k, v in metrics.items()}\n",
+ "\n",
+ "\n",
+ "def validate(model, dataloader, loss_fn, device):\n",
+ "    model.eval()\n",
+ "    metrics = {'total': 0, 'mse': 0, 'cosine_sim': 0}\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        for batch in dataloader:\n",
+ "            predicted = model(batch['states'], batch['actions'])\n",
+ "            target = model.get_target_embedding(batch['next_states'])\n",
+ "            losses = loss_fn(predicted, target)\n",
+ "\n",
+ "            for k, v in losses.items():\n",
+ "                if k in metrics:\n",
+ "                    metrics[k] += v.item()\n",
+ "\n",
+ "    return {k: v / len(dataloader) for k, v in metrics.items()}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize\n",
+ "model = JEPAWorldModelV2(\n",
+ "    model_name=\"gpt2\",\n",
+ "    state_dim=256,\n",
+ "    freeze_llm=True\n",
+ ").to(device)\n",
+ "\n",
+ "loss_fn = JEPALoss(mse_weight=1.0, cosine_weight=0.5)\n",
+ "\n",
+ "optimizer = torch.optim.AdamW(\n",
+ "    filter(lambda p: p.requires_grad, model.parameters()),\n",
+ "    lr=3e-4,\n",
+ "    weight_decay=0.01\n",
+ ")\n",
+ "\n",
+ "# Count params\n",
+ "trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "total = sum(p.numel() for p in model.parameters())\n",
+ "print(f\"Trainable: {trainable:,} / Total: {total:,}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Train\n",
+ "num_epochs = 30\n",
+ "history = {'train_loss': [], 'val_loss': [], 'train_sim': [], 'val_sim': []}\n",
+ "\n",
+ "for epoch in range(num_epochs):\n",
+ "    train_metrics = train_epoch(model, train_loader, optimizer, loss_fn, device)\n",
+ "    val_metrics = validate(model, val_loader, loss_fn, device)\n",
+ "\n",
+ "    history['train_loss'].append(train_metrics['total'])\n",
+ "    history['val_loss'].append(val_metrics['total'])\n",
+ "    history['train_sim'].append(train_metrics['cosine_sim'])\n",
+ "    history['val_sim'].append(val_metrics['cosine_sim'])\n",
+ "\n",
+ "    if (epoch + 1) % 5 == 0:\n",
+ "        print(f\"Epoch {epoch+1}/{num_epochs}\")\n",
+ "        print(f\" Train Loss: {train_metrics['total']:.4f}, Cosine Sim: {train_metrics['cosine_sim']:.4f}\")\n",
+ "        print(f\" Val Loss: {val_metrics['total']:.4f}, Cosine Sim: {val_metrics['cosine_sim']:.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Plot\n",
+ "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
+ "\n",
+ "axes[0].plot(history['train_loss'], label='Train')\n",
+ "axes[0].plot(history['val_loss'], label='Val')\n",
+ "axes[0].set_xlabel('Epoch')\n",
+ "axes[0].set_ylabel('Loss')\n",
+ "axes[0].legend()\n",
+ "axes[0].set_title('JEPA Loss')\n",
+ "\n",
+ "axes[1].plot(history['train_sim'], label='Train')\n",
+ "axes[1].plot(history['val_sim'], label='Val')\n",
+ "axes[1].set_xlabel('Epoch')\n",
+ "axes[1].set_ylabel('Cosine Similarity')\n",
+ "axes[1].legend()\n",
+ "axes[1].set_title('Prediction Quality')\n",
+ "\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 5. Test: Predict Action Consequences"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict_outcome(model, state: str, action: str, candidates: list) -> dict:\n",
+ "    \"\"\"\n",
+ "    JEPA-style inference:\n",
+ "    1. Predict next state embedding\n",
+ "    2. Find closest candidate in embedding space\n",
+ "    \"\"\"\n",
+ "    model.eval()\n",
+ "\n",
+ "    with torch.no_grad():\n",
+ "        # Predict next state embedding\n",
+ "        predicted_emb = model([state], [action])  # [1, state_dim]\n",
+ "\n",
+ "        # Encode all candidates\n",
+ "        candidate_embs = model.encode_state(candidates)  # [N, state_dim]\n",
+ "\n",
+ "        # Compute similarities\n",
+ "        sims = F.cosine_similarity(\n",
+ "            predicted_emb.expand(len(candidates), -1),\n",
+ "            candidate_embs\n",
+ "        )\n",
+ "\n",
+ "        best_idx = sims.argmax().item()\n",
+ "\n",
+ "    # 'confidence' is the raw cosine similarity of the best match, not a probability\n",
+ "    return {\n",
+ "        'prediction': candidates[best_idx],\n",
+ "        'confidence': sims[best_idx].item(),\n",
+ "        'all_scores': {c: sims[i].item() for i, c in enumerate(candidates)}\n",
+ "    }\n",
+ "\n",
+ "\n",
+ "# Test cases\n",
+ "test_cases = [\n",
+ "    {\n",
+ "        'state': \"Document is empty\",\n",
+ "        'action': \"create introduction section\",\n",
+ "        'candidates': [\n",
+ "            \"Document has introduction section\",\n",
+ "            \"Document deleted\",\n",
+ "            \"Document has 500 words\",\n",
+ "            \"Cart has 1 item\"\n",
+ "        ],\n",
+ "        'expected': \"Document has introduction section\"\n",
+ "    },\n",
+ "    {\n",
+ "        'state': \"Cart has 2 items totaling $80\",\n",
+ "        'action': \"apply 10% discount\",\n",
+ "        'candidates': [\n",
+ "            \"Cart has 2 items totaling $72\",\n",
+ "            \"Cart is empty\",\n",
+ "            \"Cart has 3 items totaling $100\",\n",
+ "            \"Order placed\"\n",
+ "        ],\n",
+ "        'expected': \"Cart has 2 items totaling $72\"\n",
+ "    },\n",
+ "    {\n",
+ "        'state': \"Project has 5 pending tasks\",\n",
+ "        'action': \"start task 1\",\n",
+ "        'candidates': [\n",
+ "            \"Project has 1 active and 4 pending tasks\",\n",
+ "            \"Project has 5 done tasks\",\n",
+ "            \"Project has no tasks\",\n",
+ "            \"Document approved\"\n",
+ "        ],\n",
+ "        'expected': \"Project has 1 active and 4 pending tasks\"\n",
+ "    }\n",
+ "]\n",
+ "\n",
+ "print(\"=\"*80)\n",
+ "print(\"WORLD MODEL PREDICTIONS\")\n",
+ "print(\"=\"*80)\n",
+ "\n",
+ "correct = 0\n",
+ "for i, test in enumerate(test_cases):\n",
+ "    result = predict_outcome(model, test['state'], test['action'], test['candidates'])\n",
+ "    is_correct = result['prediction'] == test['expected']\n",
+ "    correct += is_correct\n",
+ "\n",
+ "    print(f\"\\nTest {i+1}: {'✓' if is_correct else '✗'}\")\n",
+ "    print(f\" State: {test['state']}\")\n",
+ "    print(f\" Action: {test['action']}\")\n",
+ "    print(f\" Predicted: {result['prediction']}\")\n",
+ "    print(f\" Expected: {test['expected']}\")\n",
+ "    print(f\" Confidence: {result['confidence']:.4f}\")\n",
+ "\n",
+ "print(f\"\\nAccuracy: {correct}/{len(test_cases)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 6. Visualize State Embedding Space"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.manifold import TSNE\n",
+ "\n",
+ "# Encode a variety of states\n",
+ "states_to_visualize = [\n",
+ "    # Document states\n",
+ "    \"Document is empty\",\n",
+ "    \"Document has introduction section\",\n",
+ "    \"Document has introduction with 500 words\",\n",
+ "    \"Document pending review\",\n",
+ "    \"Document approved\",\n",
+ "    # Cart states\n",
+ "    \"Cart empty\",\n",
+ "    \"Cart has 1 item totaling $50\",\n",
+ "    \"Cart has 2 items totaling $80\",\n",
+ "    \"Order placed for $72\",\n",
+ "    # Project states\n",
+ "    \"Project has no tasks\",\n",
+ "    \"Project has 5 pending tasks\",\n",
+ "    \"Project complete with 5 done tasks\",\n",
+ "]\n",
+ "\n",
+ "categories = ['doc']*5 + ['cart']*4 + ['project']*3\n",
+ "\n",
+ "model.eval()\n",
+ "with torch.no_grad():\n",
+ "    embeddings = model.encode_state(states_to_visualize).cpu().numpy()\n",
+ "\n",
+ "# t-SNE (perplexity must stay below the number of samples, 12 here)\n",
+ "tsne = TSNE(n_components=2, perplexity=5, random_state=42)\n",
+ "emb_2d = tsne.fit_transform(embeddings)\n",
+ "\n",
+ "# Plot\n",
+ "plt.figure(figsize=(10, 8))\n",
+ "colors = {'doc': 'blue', 'cart': 'green', 'project': 'red'}\n",
+ "\n",
+ "for i, (x, y) in enumerate(emb_2d):\n",
+ "    # Truncate long labels only; short ones are shown without an ellipsis\n",
+ "    label = states_to_visualize[i]\n",
+ "    if len(label) > 30:\n",
+ "        label = label[:30] + '...'\n",
+ "    plt.scatter(x, y, c=colors[categories[i]], s=100)\n",
+ "    plt.annotate(label, (x, y), fontsize=8)\n",
+ "\n",
+ "plt.title(\"State Embedding Space (t-SNE)\")\n",
+ "plt.xlabel(\"Dimension 1\")\n",
+ "plt.ylabel(\"Dimension 2\")\n",
+ "plt.tight_layout()\n",
+ "plt.show()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 7. Save Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Save only the trained components (not the frozen LLM)\n",
+ "torch.save({\n",
+ "    'state_encoder': model.state_encoder.state_dict(),\n",
+ "    'action_encoder': model.action_encoder.state_dict(),\n",
+ "    'dynamics_predictor': model.dynamics_predictor.state_dict(),\n",
+ "    'config': {\n",
+ "        'model_name': 'gpt2',\n",
+ "        'state_dim': 256\n",
+ "    }\n",
+ "}, 'jepa_world_model_option2.pt')\n",
+ "\n",
+ "print(\"Model saved!\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Summary\n",
+ "\n",
+ "**Option 2 Advantages over Option 1:**\n",
+ "- A single LLM serves as the backbone for both the encoder and the predictor\n",
+ "- States and predictions share one representation space\n",
+ "- The LLM can optionally be fine-tuned (it is frozen in this notebook), which may improve results\n",
+ "\n",
+ "**Key Implementation Details:**\n",
+ "1. LLM hidden states → mean pooled → state encoder → state embedding\n",
+ "2. State + Action embeddings → dynamics predictor → next state embedding\n",
+ "3. Loss: weighted MSE + (1 - cosine similarity) in embedding space\n",
+ "\n",
+ "**This is JEPA-style because:**\n",
+ "- We predict embeddings, not tokens\n",
+ "- The model learns state dynamics, not text generation\n",
+ "- Planning = finding actions that lead to desired state embeddings"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
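
The notebook's summary describes planning as finding actions that lead to desired state embeddings, but no cell shows that step. The block below is a minimal sketch of one-step planning under stated assumptions: it uses only the standard library, embeddings are plain Python lists, and `plan_one_step`, `toy_predict_next`, and `TOY_EFFECTS` are hypothetical names that are not part of the notebook. In the notebook itself, the role of `toy_predict_next` would presumably be played by calling the trained model on a state and an action, and the goal embedding would come from `model.encode_state`.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)


def plan_one_step(predict_next, state_emb, actions, goal_emb):
    """Score each candidate action by how close its predicted next state is to the goal."""
    scores = {
        action: cosine_similarity(predict_next(state_emb, action), goal_emb)
        for action in actions
    }
    best_action = max(scores, key=scores.get)
    return best_action, scores


# Hypothetical stand-in for a learned dynamics predictor:
# each action shifts the 2-D state embedding by a fixed offset.
TOY_EFFECTS = {
    "add item": [1.0, 0.0],
    "remove item": [-1.0, 0.0],
    "apply discount": [0.0, 1.0],
}


def toy_predict_next(state_emb, action):
    return [s + d for s, d in zip(state_emb, TOY_EFFECTS[action])]


state = [1.0, 0.0]  # toy embedding of the current state
goal = [1.0, 1.0]   # toy embedding of the desired state

best_action, scores = plan_one_step(toy_predict_next, state, list(TOY_EFFECTS), goal)
print(best_action)
for action, score in scores.items():
    print(f"{action}: {score:.4f}")
```

Scoring by cosine similarity mirrors how `predict_outcome` ranks candidate states; a multi-step planner would need to search over action sequences, which this sketch does not attempt.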