ghrua commited on
Commit
66bef3c
·
1 Parent(s): 60e8ccb

update with 1) compatibility with phones; 2) focus view; 3) cache more inputs

Browse files
Files changed (2) hide show
  1. app.py +37 -15
  2. templates/index.html +88 -40
app.py CHANGED
@@ -7,6 +7,7 @@ import queue
7
  import threading
8
  from concurrent.futures import Future
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
 
10
 
11
  app = Flask(__name__)
12
  CKPT_NAME = "SakanaAI/RePo-OLMo2-1B-stage2-L5"
@@ -31,9 +32,8 @@ class RePo:
31
  self.device = torch.device("cpu")
32
  self.model = model
33
  self.start_layer = start_layer
34
- self.prev = None
35
- self.prev_indices = None
36
- self.prev_tok = None
37
 
38
  @torch.no_grad()
39
  def forward(self, prompt, layer, head, max_tokens=512):
@@ -48,9 +48,9 @@ class RePo:
48
  inputs['attention_mask'] = inputs['attention_mask'][:, :max_tokens]
49
  prompt = self.tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=False)
50
 
51
- if self.prev == prompt:
52
- pred_indices = self.prev_indices
53
- toks = self.prev_tok
54
  else:
55
  inputs = self.tokenizer(prompt, return_tensors="pt")
56
  tok_ids = inputs['input_ids']
@@ -61,9 +61,9 @@ class RePo:
61
  outputs = self.model(**inputs, return_dict=True, output_pred_indices=True)
62
  pred_indices = outputs.pred_indices
63
  pred_indices = [it.data.squeeze(0).reshape(-1, n_toks).tolist() for it in pred_indices]
64
- self.prev = prompt
65
- self.prev_indices = pred_indices
66
- self.prev_tok = toks
67
 
68
  data = []
69
  # Safety check for layer bounds
@@ -126,10 +126,7 @@ def process_sentence():
126
  layer = int(req_data.get('layer', 5))
127
  head = int(req_data.get('head', 0))
128
 
129
- # Create a Future object to communicate between threads
130
  future = Future()
131
-
132
- # Push job to queue
133
  execution_queue.put((future, {
134
  'sentence': sentence,
135
  'layer': layer,
@@ -137,14 +134,39 @@ def process_sentence():
137
  'max_tokens': 512
138
  }))
139
 
140
- # Wait for the result (This blocks the HTTP request until the worker finishes)
141
- # We can add a timeout here if desired (e.g., future.result(timeout=60))
142
  results, was_truncated = future.result()
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  return jsonify({
145
  "status": "success",
146
  "data": results,
147
- "truncated": was_truncated
 
148
  })
149
 
150
  except Exception as e:
 
7
  import threading
8
  from concurrent.futures import Future
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
10
+ from collections import OrderedDict
11
 
12
  app = Flask(__name__)
13
  CKPT_NAME = "SakanaAI/RePo-OLMo2-1B-stage2-L5"
 
32
  self.device = torch.device("cpu")
33
  self.model = model
34
  self.start_layer = start_layer
35
+ self.cache = OrderedDict()
36
+ self.cache_size = 8
 
37
 
38
  @torch.no_grad()
39
  def forward(self, prompt, layer, head, max_tokens=512):
 
48
  inputs['attention_mask'] = inputs['attention_mask'][:, :max_tokens]
49
  prompt = self.tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=False)
50
 
51
+ if prompt in self.cache:
52
+ pred_indices, toks = self.cache[prompt]
53
+ self.cache.move_to_end(prompt)
54
  else:
55
  inputs = self.tokenizer(prompt, return_tensors="pt")
56
  tok_ids = inputs['input_ids']
 
61
  outputs = self.model(**inputs, return_dict=True, output_pred_indices=True)
62
  pred_indices = outputs.pred_indices
63
  pred_indices = [it.data.squeeze(0).reshape(-1, n_toks).tolist() for it in pred_indices]
64
+ self.cache[prompt] = (pred_indices, toks)
65
+ if len(self.cache) > self.cache_size:
66
+ self.cache.popitem(last=False)
67
 
68
  data = []
69
  # Safety check for layer bounds
 
126
  layer = int(req_data.get('layer', 5))
127
  head = int(req_data.get('head', 0))
128
 
 
129
  future = Future()
 
 
130
  execution_queue.put((future, {
131
  'sentence': sentence,
132
  'layer': layer,
 
134
  'max_tokens': 512
135
  }))
136
 
 
 
137
  results, was_truncated = future.result()
138
 
139
+ # --- OUTLIER DETECTION LOGIC ---
140
+ suggested_range = None
141
+ y_vals = [d['y'] for d in results]
142
+
143
+ # Only apply logic if we have enough data points
144
+ if len(y_vals) > 5:
145
+ # Calculate Quartiles
146
+ q75, q25 = np.percentile(y_vals, [75 ,25])
147
+ iqr = q75 - q25
148
+
149
+ # Define bounds (1.5 * IQR is standard for outliers)
150
+ lower_bound = q25 - (1.5 * iqr)
151
+ upper_bound = q75 + (1.5 * iqr)
152
+
153
+ # Find the actual data range within these bounds
154
+ inliers = [y for y in y_vals if lower_bound <= y <= upper_bound]
155
+
156
+ if inliers:
157
+ # Add 5% padding for visual comfort
158
+ min_in = min(inliers)
159
+ max_in = max(inliers)
160
+ padding = (max_in - min_in) * 0.05
161
+ if padding == 0: padding = 1.0 # Handle flat lines
162
+
163
+ suggested_range = [min_in - padding, max_in + padding]
164
+
165
  return jsonify({
166
  "status": "success",
167
  "data": results,
168
+ "truncated": was_truncated,
169
+ "suggested_range": suggested_range
170
  })
171
 
172
  except Exception as e:
templates/index.html CHANGED
@@ -7,28 +7,22 @@
7
  <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
8
  <script src="https://cdn.tailwindcss.com"></script>
9
  <style>
10
- /* Float msg style */
11
  #floating-msg {
12
  transition: opacity 0.3s ease;
13
  opacity: 0;
14
  pointer-events: none;
15
  }
16
  #floating-msg.visible { opacity: 1; }
17
-
18
- /* Ensure Plotly fills the container */
19
- .js-plotly-plot, .plot-container {
20
- width: 100% !important;
21
- height: 100% !important;
22
- }
23
  </style>
24
  </head>
25
- <body class="h-screen w-screen overflow-hidden bg-gray-50 text-gray-800 font-sans flex">
26
 
27
- <aside class="w-80 (min-w-[320px]) flex flex-col bg-white border-r border-gray-200 shadow-lg z-20 shrink-0">
28
 
29
- <div class="p-4 border-b border-gray-100">
30
- <div class="flex justify-between items-center mb-2">
31
- <h2 class="text-xl font-bold text-indigo-700">RePo Visualizer</h2>
32
  <div id="queue-indicator" class="hidden bg-blue-50 text-blue-700 text-[10px] font-bold px-2 py-1 rounded-full border border-blue-200 flex items-center gap-1">
33
  <span class="relative flex h-1.5 w-1.5">
34
  <span class="animate-ping absolute inline-flex h-full w-full rounded-full bg-blue-400 opacity-75"></span>
@@ -37,9 +31,8 @@
37
  <span id="queue-text">Wait: 0</span>
38
  </div>
39
  </div>
40
-
41
- <div class="text-xs text-gray-500 bg-gray-50 p-2 rounded border border-gray-100 leading-tight space-y-1">
42
- <p><strong>RePo</strong> aims to restructure context, much like humans do, by assigning non-linear positions in a dense space. It is pre-trained on general data and can effectively capture the internal structure and contextual dependencies of the input.</p>
43
  <div class="flex gap-2 pt-1 border-t border-gray-200 mt-1">
44
  <span>📚</span>
45
  <a href="https://www.arxiv.org/abs/2512.14391" target="_blank" class="text-blue-600 hover:underline">Paper</a>
@@ -48,12 +41,11 @@
48
  </div>
49
  </div>
50
 
51
- <div class="p-4 flex flex-col gap-3 flex-1 overflow-y-auto">
52
-
53
- <div class="flex flex-col flex-1 min-h-[100px]">
54
  <label class="text-xs font-semibold text-gray-700 uppercase mb-1">Input Context</label>
55
  <textarea id="sentenceInput"
56
- class="w-full flex-1 p-2 text-sm border border-gray-300 rounded focus:ring-1 focus:ring-indigo-500 outline-none resize-none bg-gray-50"
57
  placeholder="Enter text...">Below is a log of user scores from a gaming session. Please calculate the arithmetic mean (average) of the 'Score' column.
58
 
59
  | User_ID | Username | Score | Region |
@@ -74,7 +66,7 @@ Answer:
74
  </textarea>
75
  </div>
76
 
77
- <div class="grid grid-cols-2 gap-2">
78
  <div>
79
  <label class="text-xs font-semibold text-gray-700 uppercase mb-1">Layer</label>
80
  <select id="layerSelect" class="w-full p-1.5 text-sm border border-gray-300 rounded bg-white">
@@ -114,28 +106,33 @@ Answer:
114
  </div>
115
  </div>
116
 
117
- <button onclick="fetchData()" id="computeBtn" class="w-full bg-indigo-600 hover:bg-indigo-700 text-white font-bold py-2 px-4 rounded text-sm transition shadow-sm flex justify-center items-center gap-2">
118
  <span>Compute & Visualize</span>
119
  </button>
120
  </div>
121
  </aside>
122
 
123
- <main class="flex-1 flex flex-col relative h-full">
124
 
125
- <div id="truncation-warning" class="hidden w-full bg-yellow-100 border-b border-yellow-200 text-yellow-800 text-xs px-4 py-2 flex items-center justify-between">
126
- <span><strong>Notice:</strong> Input truncated to 512 tokens for performance.</span>
127
  <button onclick="this.parentElement.classList.add('hidden')" class="text-yellow-900 font-bold hover:text-yellow-600">×</button>
128
  </div>
129
 
130
- <div class="flex-1 relative bg-white overflow-hidden p-2">
131
- <button onclick="resetView()" class="absolute top-0 right-4 z-10 bg-white/90 hover:bg-gray-100 text-gray-600 text-xs font-bold py-1 px-3 rounded border border-gray-300 shadow-sm transition">
132
- Reset Selection
133
- </button>
 
 
 
 
 
134
 
135
  <div id="chartDiv" class="w-full h-full"></div>
136
 
137
- <div id="floating-msg" class="absolute top-4 left-1/2 transform -translate-x-1/2 bg-gray-900/90 backdrop-blur-sm text-white px-4 py-3 rounded shadow-xl max-w-lg z-20 text-center pointer-events-none">
138
- <p class="text-[10px] text-gray-400 uppercase tracking-wide mb-1">Local Context (i-k to i+k)</p>
139
  <p id="context-text" class="text-sm font-medium leading-relaxed">Click a point...</p>
140
  </div>
141
  </div>
@@ -143,13 +140,13 @@ Answer:
143
 
144
  <script>
145
  let currentData = [];
 
 
146
  const K = 8;
147
 
148
  document.addEventListener('DOMContentLoaded', () => {
149
  fetchData();
150
  setInterval(updateQueueStatus, 2000);
151
-
152
- // Ensure chart resizes if window resizes
153
  window.onresize = function() {
154
  Plotly.Plots.resize(document.getElementById('chartDiv'));
155
  };
@@ -203,7 +200,22 @@ Answer:
203
 
204
  if (response.ok && result.status === 'success') {
205
  currentData = result.data;
206
- renderChart(currentData);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
  document.getElementById('floating-msg').classList.remove('visible');
208
 
209
  const warningEl = document.getElementById('truncation-warning');
@@ -225,7 +237,7 @@ Answer:
225
  }
226
  }
227
 
228
- function renderChart(data) {
229
  const xVals = data.map(d => d.x);
230
  const yVals = data.map(d => d.y);
231
  const tVals = data.map(d => d.t);
@@ -233,15 +245,17 @@ Answer:
233
  const trace = {
234
  x: xVals, y: yVals, text: tVals,
235
  mode: 'markers',
236
- marker: { size: 6, color: '#4f46e5', opacity: 0.6, line: { width: 0 } }, // Indigio-600
237
  type: 'scatter', hoverinfo: 'text+x+y'
238
  };
239
 
240
- // Calculate tighter margins to maximize screen space
241
  const layout = {
242
  title: { text: 'Position Analysis', font: {size: 14} },
243
  xaxis: { title: 'Token Index' },
244
- yaxis: { title: 'Assigned Position' },
 
 
 
245
  hovermode: 'closest', dragmode: 'zoom',
246
  margin: { t: 30, r: 20, b: 40, l: 50 },
247
  autosize: true
@@ -254,7 +268,7 @@ Answer:
254
  plotDiv.on('plotly_click', (data) => {
255
  if(data.points.length > 0) highlightContext(data.points[0].pointIndex);
256
  });
257
- plotDiv.on('plotly_doubleclick', () => setTimeout(resetView, 100));
258
  }
259
 
260
  function highlightContext(centerIndex) {
@@ -266,7 +280,7 @@ Answer:
266
  let contextTokens = [];
267
 
268
  for (let i = start; i <= end; i++) {
269
- colors[i] = '#ef4444'; // Red Highlight
270
  sizes[i] = 12;
271
  contextTokens.push(currentData[i].t);
272
  }
@@ -276,7 +290,7 @@ Answer:
276
  document.getElementById('floating-msg').classList.add('visible');
277
  }
278
 
279
- function resetView() {
280
  if(currentData.length === 0) return;
281
  const n = currentData.length;
282
  Plotly.restyle('chartDiv', {
@@ -285,6 +299,40 @@ Answer:
285
  });
286
  document.getElementById('floating-msg').classList.remove('visible');
287
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  </script>
289
  </body>
290
  </html>
 
7
  <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
8
  <script src="https://cdn.tailwindcss.com"></script>
9
  <style>
 
10
  #floating-msg {
11
  transition: opacity 0.3s ease;
12
  opacity: 0;
13
  pointer-events: none;
14
  }
15
  #floating-msg.visible { opacity: 1; }
16
+ .js-plotly-plot, .plot-container { width: 100% !important; height: 100% !important; }
 
 
 
 
 
17
  </style>
18
  </head>
19
+ <body class="h-screen w-screen overflow-hidden bg-gray-50 text-gray-800 font-sans flex flex-col md:flex-row">
20
 
21
+ <aside class="w-full md:w-80 md:h-full max-h-[40vh] md:max-h-full flex flex-col bg-white border-b md:border-b-0 md:border-r border-gray-200 shadow-lg z-20 shrink-0 overflow-hidden">
22
 
23
+ <div class="p-3 md:p-4 border-b border-gray-100 shrink-0">
24
+ <div class="flex justify-between items-center mb-1 md:mb-2">
25
+ <h2 class="text-lg md:text-xl font-bold text-indigo-700">RePo Visualizer</h2>
26
  <div id="queue-indicator" class="hidden bg-blue-50 text-blue-700 text-[10px] font-bold px-2 py-1 rounded-full border border-blue-200 flex items-center gap-1">
27
  <span class="relative flex h-1.5 w-1.5">
28
  <span class="animate-ping absolute inline-flex h-full w-full rounded-full bg-blue-400 opacity-75"></span>
 
31
  <span id="queue-text">Wait: 0</span>
32
  </div>
33
  </div>
34
+ <div class="hidden md:block text-xs text-gray-500 bg-gray-50 p-2 rounded border border-gray-100 leading-tight space-y-1">
35
+ <p><strong>RePo</strong> aims to restructure context by assigning non-linear positions in a dense space.</p>
 
36
  <div class="flex gap-2 pt-1 border-t border-gray-200 mt-1">
37
  <span>📚</span>
38
  <a href="https://www.arxiv.org/abs/2512.14391" target="_blank" class="text-blue-600 hover:underline">Paper</a>
 
41
  </div>
42
  </div>
43
 
44
+ <div class="p-3 md:p-4 flex flex-col gap-3 flex-1 overflow-y-auto">
45
+ <div class="flex flex-col flex-1 min-h-0">
 
46
  <label class="text-xs font-semibold text-gray-700 uppercase mb-1">Input Context</label>
47
  <textarea id="sentenceInput"
48
+ class="w-full flex-1 p-2 text-sm border border-gray-300 rounded focus:ring-1 focus:ring-indigo-500 outline-none resize-none bg-gray-50 min-h-[80px]"
49
  placeholder="Enter text...">Below is a log of user scores from a gaming session. Please calculate the arithmetic mean (average) of the 'Score' column.
50
 
51
  | User_ID | Username | Score | Region |
 
66
  </textarea>
67
  </div>
68
 
69
+ <div class="grid grid-cols-2 gap-2 shrink-0">
70
  <div>
71
  <label class="text-xs font-semibold text-gray-700 uppercase mb-1">Layer</label>
72
  <select id="layerSelect" class="w-full p-1.5 text-sm border border-gray-300 rounded bg-white">
 
106
  </div>
107
  </div>
108
 
109
+ <button onclick="fetchData()" id="computeBtn" class="shrink-0 w-full bg-indigo-600 hover:bg-indigo-700 text-white font-bold py-2 px-4 rounded text-sm transition shadow-sm flex justify-center items-center gap-2">
110
  <span>Compute & Visualize</span>
111
  </button>
112
  </div>
113
  </aside>
114
 
115
+ <main class="flex-1 flex flex-col relative min-h-0 overflow-hidden bg-gray-100">
116
 
117
+ <div id="truncation-warning" class="hidden w-full bg-yellow-100 border-b border-yellow-200 text-yellow-800 text-xs px-4 py-2 flex items-center justify-between z-20">
118
+ <span><strong>Notice:</strong> Input truncated.</span>
119
  <button onclick="this.parentElement.classList.add('hidden')" class="text-yellow-900 font-bold hover:text-yellow-600">×</button>
120
  </div>
121
 
122
+ <div class="flex-1 relative bg-white w-full h-full p-2">
123
+ <div class="absolute top-2 right-4 z-10 flex gap-2">
124
+ <button id="viewToggleBtn" onclick="toggleView()" class="hidden bg-white/90 hover:bg-gray-100 text-indigo-700 text-xs font-bold py-1 px-3 rounded border border-indigo-200 shadow-sm transition">
125
+ Global View
126
+ </button>
127
+ <button onclick="resetHighlight()" class="bg-white/90 hover:bg-gray-100 text-gray-600 text-xs font-bold py-1 px-3 rounded border border-gray-300 shadow-sm transition">
128
+ Clear Highlight
129
+ </button>
130
+ </div>
131
 
132
  <div id="chartDiv" class="w-full h-full"></div>
133
 
134
+ <div id="floating-msg" class="absolute top-10 left-1/2 transform -translate-x-1/2 bg-gray-900/90 backdrop-blur-sm text-white px-4 py-3 rounded shadow-xl max-w-[90%] w-auto z-20 text-center pointer-events-none">
135
+ <p class="text-[10px] text-gray-400 uppercase tracking-wide mb-1">Local Context</p>
136
  <p id="context-text" class="text-sm font-medium leading-relaxed">Click a point...</p>
137
  </div>
138
  </div>
 
140
 
141
  <script>
142
  let currentData = [];
143
+ let savedRange = null; // Store the "Focused" range calculated by backend
144
+ let isFocused = true; // Track current state
145
  const K = 8;
146
 
147
  document.addEventListener('DOMContentLoaded', () => {
148
  fetchData();
149
  setInterval(updateQueueStatus, 2000);
 
 
150
  window.onresize = function() {
151
  Plotly.Plots.resize(document.getElementById('chartDiv'));
152
  };
 
200
 
201
  if (response.ok && result.status === 'success') {
202
  currentData = result.data;
203
+ savedRange = result.suggested_range; // Save range
204
+
205
+ // Logic to set initial state
206
+ const toggleBtn = document.getElementById('viewToggleBtn');
207
+
208
+ if (savedRange) {
209
+ isFocused = true;
210
+ toggleBtn.classList.remove('hidden'); // Show button only if outliers exist
211
+ renderChart(currentData, savedRange);
212
+ updateToggleButtonText();
213
+ } else {
214
+ isFocused = false;
215
+ toggleBtn.classList.add('hidden'); // Hide button if no outliers
216
+ renderChart(currentData, null);
217
+ }
218
+
219
  document.getElementById('floating-msg').classList.remove('visible');
220
 
221
  const warningEl = document.getElementById('truncation-warning');
 
237
  }
238
  }
239
 
240
+ function renderChart(data, yRange) {
241
  const xVals = data.map(d => d.x);
242
  const yVals = data.map(d => d.y);
243
  const tVals = data.map(d => d.t);
 
245
  const trace = {
246
  x: xVals, y: yVals, text: tVals,
247
  mode: 'markers',
248
+ marker: { size: 6, color: '#4f46e5', opacity: 0.6, line: { width: 0 } },
249
  type: 'scatter', hoverinfo: 'text+x+y'
250
  };
251
 
 
252
  const layout = {
253
  title: { text: 'Position Analysis', font: {size: 14} },
254
  xaxis: { title: 'Token Index' },
255
+ yaxis: {
256
+ title: 'Assigned Position',
257
+ range: yRange ? yRange : null
258
+ },
259
  hovermode: 'closest', dragmode: 'zoom',
260
  margin: { t: 30, r: 20, b: 40, l: 50 },
261
  autosize: true
 
268
  plotDiv.on('plotly_click', (data) => {
269
  if(data.points.length > 0) highlightContext(data.points[0].pointIndex);
270
  });
271
+ plotDiv.on('plotly_doubleclick', () => setTimeout(resetHighlight, 100));
272
  }
273
 
274
  function highlightContext(centerIndex) {
 
280
  let contextTokens = [];
281
 
282
  for (let i = start; i <= end; i++) {
283
+ colors[i] = '#ef4444';
284
  sizes[i] = 12;
285
  contextTokens.push(currentData[i].t);
286
  }
 
290
  document.getElementById('floating-msg').classList.add('visible');
291
  }
292
 
293
+ function resetHighlight() {
294
  if(currentData.length === 0) return;
295
  const n = currentData.length;
296
  Plotly.restyle('chartDiv', {
 
299
  });
300
  document.getElementById('floating-msg').classList.remove('visible');
301
  }
302
+
303
+ // Toggle View Logic
304
+ function toggleView() {
305
+ if (!savedRange) return; // Should not happen if button is hidden, but safety first
306
+
307
+ if (isFocused) {
308
+ // Currently Focused, User wants Global
309
+ Plotly.relayout('chartDiv', {
310
+ 'yaxis.autorange': true
311
+ });
312
+ isFocused = false;
313
+ } else {
314
+ // Currently Global, User wants Focused
315
+ Plotly.relayout('chartDiv', {
316
+ 'yaxis.range': savedRange,
317
+ 'yaxis.autorange': false
318
+ });
319
+ isFocused = true;
320
+ }
321
+ updateToggleButtonText();
322
+ }
323
+
324
+ function updateToggleButtonText() {
325
+ const btn = document.getElementById('viewToggleBtn');
326
+ if (isFocused) {
327
+ btn.innerText = "Global View";
328
+ btn.classList.add('text-indigo-700', 'border-indigo-200');
329
+ btn.classList.remove('text-gray-600', 'border-gray-300');
330
+ } else {
331
+ btn.innerText = "Focused View";
332
+ btn.classList.add('text-gray-600', 'border-gray-300');
333
+ btn.classList.remove('text-indigo-700', 'border-indigo-200');
334
+ }
335
+ }
336
  </script>
337
  </body>
338
  </html>