SixOpen committed on
Commit
55202d4
·
verified ·
1 Parent(s): dc76df4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +923 -0
app.py ADDED
@@ -0,0 +1,923 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import random
4
+ import time
5
+ import urllib.request
6
+
7
+ import gradio as gr
8
+ import spaces
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import triton
12
+ import triton.language as tl
13
+ from transformers import AutoModel, AutoTokenizer
14
+
15
+
16
# Hugging Face Hub repository that supplies both the encoder and its tokenizer.
MODEL_ID = "SixOpen/HARE"

# Loaded once at import time and shared by every request.
# NOTE: ``trust_remote_code=True`` executes Python code shipped inside the model
# repository, so MODEL_ID must only ever point at a trusted repo.
# ``.eval()`` puts the module in inference mode.
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
20
+
21
+
22
@triton.jit
def _wkv7_fwd_kernel(
    R, K, V, DECAY, A, O,
    STATE_OUT, STATE_IN,
    sab_scale, T,
    stride_b, stride_t, stride_h,
    H: tl.constexpr, D: tl.constexpr, BLOCK_D: tl.constexpr,
    RETURN_STATE: tl.constexpr, HAS_INIT_STATE: tl.constexpr,
):
    # Sequential WKV scan for one (batch, head) pair per program.
    #
    # Each program keeps a BLOCK_D x BLOCK_D float32 state matrix and walks the
    # T time steps in order, writing one output vector of length D per step.
    # R/K/V/DECAY/A/O are addressed as
    #   batch * stride_b + t * stride_t + head * stride_h + channel
    # (channel stride is 1). STATE_IN / STATE_OUT are addressed as a dense
    # row-major (batch, head, D, D) array.
    program = tl.program_id(0)
    batch = program // H
    head = program % H
    seq_base = batch * stride_b + head * stride_h

    # Row index (output/value channel) and column index (key channel); lanes
    # at or beyond D are padding introduced by the power-of-two BLOCK_D.
    rows = tl.arange(0, BLOCK_D)
    cols = tl.arange(0, BLOCK_D)
    row_ok = rows < D
    col_ok = cols < D

    if HAS_INIT_STATE:
        # Resume from a state produced by an earlier call.
        in_base = batch * (H * D * D) + head * (D * D)
        in_ptrs = STATE_IN + in_base + rows[:, None] * D + cols[None, :]
        in_ok = row_ok[:, None] & col_ok[None, :]
        state = tl.load(in_ptrs, mask=in_ok, other=0.0).to(tl.float32)
    else:
        state = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)

    for step in range(T):
        tok = seq_base + step * stride_t
        k_vec = tl.load(K + tok + cols, mask=col_ok, other=0.0).to(tl.float32)
        v_vec = tl.load(V + tok + rows, mask=row_ok, other=0.0).to(tl.float32)
        r_vec = tl.load(R + tok + cols, mask=col_ok, other=0.0).to(tl.float32)
        # Padding lanes get decay 1.0 so they leave the (zero) state untouched.
        w_vec = tl.load(DECAY + tok + cols, mask=col_ok, other=1.0).to(tl.float32)
        a_vec = tl.load(A + tok + cols, mask=col_ok, other=0.0).to(tl.float32)

        # Rank-1 correction term: (state @ -k) outer (k * a), scaled by sab_scale.
        neg_proj = tl.sum(state * (-k_vec)[None, :], axis=1)
        gated_k = k_vec * a_vec
        correction = neg_proj[:, None] * gated_k[None, :]
        state = state * w_vec[None, :] + sab_scale * correction + v_vec[:, None] * k_vec[None, :]
        # Clamp every state entry to [-10, 10].
        state = tl.minimum(tl.maximum(state, -10.0), 10.0)

        # Read-out: state @ r.
        y_vec = tl.sum(state * r_vec[None, :], axis=1)
        tl.store(O + tok + rows, y_vec, mask=row_ok)

    if RETURN_STATE:
        # Persist the final state so a later call can continue the scan.
        out_base = batch * (H * D * D) + head * (D * D)
        out_ptrs = STATE_OUT + out_base + rows[:, None] * D + cols[None, :]
        out_ok = row_ok[:, None] & col_ok[None, :]
        tl.store(out_ptrs, state, mask=out_ok)
71
+
72
+
73
def wkv7_scan_triton(r, decay, k, v, a, sab_scale, return_state=False, init_state=None):
    """Launch the Triton WKV scan over a whole batch.

    Args:
        r, decay, k, v, a: tensors of shape ``(B, T, H, D)``; each is made
            contiguous before launch.
        sab_scale: scalar (converted with ``float``) weighting the state
            correction term inside the kernel.
        return_state: when true, also return the final per-head state.
        init_state: optional initial state; it is made contiguous and cast to
            float32. The kernel indexes it as ``(B, H, D, D)``.

    Returns:
        ``o`` (same shape/dtype as ``r``), or ``(o, state_out)`` when
        ``return_state`` is true, where ``state_out`` is ``(B, H, D, D)`` float32.
    """
    B, T, H, D = r.shape
    r = r.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    decay = decay.contiguous()
    a = a.contiguous()

    o = torch.empty_like(r)
    state_out = (
        torch.empty(B, H, D, D, dtype=torch.float32, device=r.device)
        if return_state else None
    )

    has_init = init_state is not None
    if has_init:
        init_state = init_state.contiguous().float()

    # One program per (batch, head); strides describe the contiguous
    # (B, T, H, D) layout in elements.
    grid = (B * H,)
    _wkv7_fwd_kernel[grid](
        r, k, v, decay, a, o,
        state_out, init_state,
        float(sab_scale), T,
        T * H * D, H * D, D,
        H=H, D=D, BLOCK_D=triton.next_power_of_2(D),
        RETURN_STATE=return_state,
        HAS_INIT_STATE=has_init,
    )
    return (o, state_out) if return_state else o
99
+
100
+
101
def find_birwkv_layers(model):
    """Collect every ``BiRWKV7Layer`` submodule of ``model``.

    Layers are matched by class *name* (the class itself is not imported here)
    and returned in ``model.modules()`` traversal order.

    Returns:
        ``(layers, ids)`` where ``layers`` is the list of matching modules and
        ``ids`` maps ``id(layer)`` to that layer's position in ``layers``.
    """
    layers = [
        module for module in model.modules()
        if type(module).__name__ == 'BiRWKV7Layer'
    ]
    ids = {id(module): position for position, module in enumerate(layers)}
    return layers, ids
109
+
110
+
111
class SpanEncoder:
    """Incrementally embeds growing text spans, carrying recurrent state forward.

    While encoding, every ``BiRWKV7Layer`` of the model has its ``forward``
    temporarily replaced by a version that (a) starts its forward scan from a
    previously saved state and (b) records the state reached at the end of the
    sequence. This lets a span be extended to the right by encoding only the
    newly arrived text.

    ``span_data`` maps a ``(start_piece, end_piece)`` key to a dict with:
        ``layer_states``    per-layer ``{'wkv_state', 'last_x'}`` dicts (or None)
        ``chunk_embs``      list of ``(1, hidden)`` L2-normalised chunk embeddings
        ``n_tokens``        number of content tokens encoded for the span
        ``residual_hidden`` trailing hidden states not yet folded into a chunk
    """

    def __init__(self, model, tokenizer, chunk_size=512):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        # UI sliders may deliver a float; ``range()`` in _chunk_hidden needs an int.
        self.chunk_size = int(chunk_size)

        self.birwkv_layers, self.birwkv_ids = find_birwkv_layers(model)
        self._originals = {}
        self._hooked = False
        self._active_states = [None] * len(self.birwkv_layers)
        self.span_data = {}

    def _hook(self):
        """Swap in the state-carrying forward on every BiRWKV layer (idempotent)."""
        if self._hooked:
            return
        for layer in self.birwkv_layers:
            self._originals[id(layer)] = layer.forward
            layer.forward = self._make_fwd(layer)
        self._hooked = True

    def _unhook(self):
        """Restore the original forward of every BiRWKV layer (idempotent)."""
        if not self._hooked:
            return
        for layer in self.birwkv_layers:
            layer.forward = self._originals[id(layer)]
        self._originals.clear()
        self._hooked = False

    def _make_fwd(self, layer):
        """Build the replacement ``forward`` for ``layer``."""
        enc = self
        idx = self.birwkv_ids[id(layer)]

        def fwd(x, attention_mask=None, **kwargs):
            # NOTE: attention_mask and extra kwargs are accepted but ignored.
            B, T, C_ = x.shape
            H, D = layer.num_heads, layer.head_size
            prev = enc._active_states[idx]
            # Token shift: each position sees its predecessor. The first
            # position uses the last input of the previous call when resuming,
            # otherwise zeros.
            if prev is not None:
                x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
            else:
                x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

            def mix(mu):
                return x + (x_prev - x) * torch.sigmoid(mu)

            r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
            w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
            k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
            v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
            a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
            g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
            sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev else None

            r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
            a_f = torch.sigmoid(a.float())
            # The constant is approximately exp(-0.5); presumably it mirrors the
            # decay parameterisation the model was trained with - TODO confirm.
            decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w.float()))
            # Forward scan resumes from the carried state and returns the new one.
            out_fwd, wkv_state = wkv7_scan_triton(
                r_f, decay, k_f, v_f, a_f, sab_scale,
                return_state=True, init_state=init_st)
            # Backward scan runs over the time-reversed sequence from a zero
            # state; its result is flipped back into forward order.
            out_bwd = wkv7_scan_triton(
                r_f.flip(1), decay.flip(1), k_f.flip(1),
                v_f.flip(1), a_f.flip(1), sab_scale,
                return_state=False).flip(1)

            enc._active_states[idx] = {
                'wkv_state': wkv_state,
                'last_x': x[:, -1:].detach().clone(),
            }
            out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
            out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
            out = layer.W_o(out * g)
            # Second tuple element mirrors the original layer's return shape
            # (assumed; the original forward is not visible here).
            return out, None
        return fwd

    @staticmethod
    def _clone_state(state):
        """Return a tensor-wise copy of one layer state dict (None stays None)."""
        return {k: v.clone() for k, v in state.items()} if state else None

    @torch.no_grad()
    def _forward_encode_raw(self, text, init_states=None, max_length=8192):
        """Encode ``text`` with hooked layers, optionally resuming from ``init_states``.

        Returns:
            ``(content, n_content, final_states)``: the hidden states on CPU with
            the first and last token positions dropped (presumably the
            tokenizer's special tokens - TODO confirm), their count, and a copy
            of every layer's state after the pass.
        """
        self._hook()
        try:
            if init_states is not None:
                self._active_states = [self._clone_state(s) for s in init_states]
            else:
                self._active_states = [None] * len(self.birwkv_layers)

            enc = self.tokenizer(text, return_tensors='pt', truncation=True,
                                 max_length=max_length)
            ids = enc['input_ids'].to(self.device)
            mask = enc['attention_mask'].to(self.device)

            h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
            content = h[0, 1:-1, :].cpu()
            n_content = content.shape[0]

            final_states = [self._clone_state(s) for s in self._active_states]
        finally:
            # Always restore the layers, even if tokenisation or the forward
            # pass raised; otherwise the model stays patched and encode_query
            # would refuse to run.
            self._unhook()
        return content, n_content, final_states

    def _chunk_hidden(self, content, return_residual=False):
        """Mean-pool ``content`` (tokens x hidden) into normalised chunk embeddings.

        Chunks span ``chunk_size`` tokens. A trailing chunk shorter than 32
        tokens is not embedded and becomes the residual, unless no chunk was
        produced at all, in which case all of ``content`` forms one chunk.

        Returns:
            ``chunks``, or ``(chunks, residual)`` when ``return_residual`` is
            true; ``residual`` is None when nothing is left over.
        """
        T = content.shape[0]
        chunks = []
        last_end = 0
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            if end - start < 32:
                break
            emb = F.normalize(content[start:end].mean(0, keepdim=True),
                              p=2, dim=-1)
            chunks.append(emb)
            last_end = end
        if not chunks and T > 0:
            chunks.append(F.normalize(content.mean(0, keepdim=True),
                                      p=2, dim=-1))
            last_end = T
        if return_residual:
            residual = content[last_end:] if last_end < T else None
            return chunks, residual
        return chunks

    @torch.no_grad()
    def encode_query(self, query):
        """Embed ``query`` with the unmodified model (mask-weighted mean, L2-normalised)."""
        assert not self._hooked
        enc = self.tokenizer(query, return_tensors='pt', truncation=True,
                             max_length=512)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        m = mask.unsqueeze(-1).float()
        emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
        return F.normalize(emb, p=2, dim=-1).cpu()

    def encode_span(self, text, key):
        """Encode ``text`` from scratch and store it under ``key``; return its token count."""
        content, n_tok, states = self._forward_encode_raw(text)
        chunks, residual = self._chunk_hidden(content, return_residual=True)
        self.span_data[key] = {
            'layer_states': states,
            'chunk_embs': chunks,
            'n_tokens': n_tok,
            'residual_hidden': residual,
        }
        return n_tok

    def _extend_from(self, text, base, new_key):
        """Encode ``text`` as a continuation of span record ``base``; store under ``new_key``."""
        content, n_new, states = self._forward_encode_raw(
            text, init_states=base['layer_states'])
        # Hidden states left over from the base span are chunked together with
        # the new ones.
        if base.get('residual_hidden') is not None:
            content = torch.cat([base['residual_hidden'], content], dim=0)
        new_chunks, residual = self._chunk_hidden(
            content, return_residual=True)
        self.span_data[new_key] = {
            'layer_states': states,
            'chunk_embs': base['chunk_embs'] + new_chunks,
            'n_tokens': base['n_tokens'] + n_new,
            'residual_hidden': residual,
        }
        return n_new

    def extend_right(self, piece_text, old_key, new_key):
        """Grow span ``old_key`` by ``piece_text``; return the number of new tokens.

        The old entry is removed and the result stored under ``new_key``.
        Raises ``KeyError`` if ``old_key`` is unknown.
        """
        old = self.span_data.pop(old_key)
        return self._extend_from(piece_text, old, new_key)

    def smart_merge(self, new_text, left_key, new_key):
        """Grow span ``left_key`` by ``new_text``, absorbing spans inside ``new_key``.

        Every stored span fully contained in ``new_key`` is discarded before the
        merged entry is stored. Returns the number of newly encoded tokens.
        Raises ``KeyError`` if ``left_key`` is unknown.
        """
        left = self.span_data.pop(left_key)
        self.remove_old(new_key)
        return self._extend_from(new_text, left, new_key)

    def remove_old(self, new_key):
        """Delete every stored span whose range lies within ``new_key``."""
        s, e = new_key
        for old in list(self.span_data.keys()):
            if old[0] >= s and old[1] <= e:
                del self.span_data[old]

    def search(self, q_emb, spans, top_k=5):
        """Rank ``spans`` by the best chunk similarity to ``q_emb``.

        Args:
            q_emb: ``(1, hidden)`` query embedding.
            spans: iterable of ``(start, end, text)``; spans without stored
                chunk embeddings are skipped.
            top_k: maximum number of results.

        Returns:
            Up to ``top_k`` tuples ``(start, end, score, preview, n_tokens,
            n_chunks)`` sorted by descending score.
        """
        results = []
        for s, e, text in spans:
            key = (s, e)
            data = self.span_data.get(key)
            if not data or not data['chunk_embs']:
                continue
            chunk_mat = torch.cat(data['chunk_embs'], dim=0)
            sims = (q_emb @ chunk_mat.T).squeeze(0)
            if sims.dim() == 0:
                sims = sims.unsqueeze(0)
            max_sim = sims.max().item()
            best_idx = sims.argmax().item()
            n_chunks = len(data['chunk_embs'])
            # Approximate the best chunk's character position by assuming
            # chunks cover the text evenly, then back up to a word boundary.
            chars_per_chunk = len(text) // max(n_chunks, 1)
            offset = min(best_idx * chars_per_chunk, len(text) - 1)
            while offset > 0 and text[offset - 1] not in ' \n\t':
                offset -= 1
            preview = text[offset:offset + 300].replace('\n', ' ').strip()
            results.append((s, e, max_sim, preview, data['n_tokens'], n_chunks))
        results.sort(key=lambda x: x[2], reverse=True)
        return results[:top_k]
319
+
320
+
321
class TextProvider:
    """Simulates a text arriving in fixed-size pieces in a shuffled order.

    The text is cut into ``piece_size``-character pieces; ``arrival_order`` is a
    seeded permutation of the piece indices, and each ``poll_pieces`` call
    delivers the next piece of that permutation.
    """

    def __init__(self, text, piece_size=4096, seed=42):
        if piece_size < 1:
            raise ValueError(f"piece_size must be >= 1, got {piece_size!r}")
        self.text = text
        self.piece_size = piece_size
        # Ceiling division: a final short piece still counts.
        self.n_pieces = (len(text) + piece_size - 1) // piece_size
        self.received = [False] * self.n_pieces
        rng = random.Random(seed)
        self.arrival_order = list(range(self.n_pieces))
        rng.shuffle(self.arrival_order)
        self.next_idx = 0

    def poll_pieces(self):
        """Deliver the next piece: mark it received and return ``[index]``.

        Returns an empty list once every piece has been delivered.
        """
        if self.next_idx >= self.n_pieces:
            return []
        idx = self.arrival_order[self.next_idx]
        self.received[idx] = True
        self.next_idx += 1
        return [idx]

    def get_spans(self):
        """Return maximal runs of received pieces as ``(start, end, text)``.

        ``end`` is exclusive; tuples are ordered by ``start``.
        """
        spans = []
        i = 0
        while i < self.n_pieces:
            if not self.received[i]:
                i += 1
                continue
            j = i
            while j < self.n_pieces and self.received[j]:
                j += 1
            spans.append((i, j, self.span_text(i, j)))
            i = j
        return spans

    def piece_text(self, idx):
        """Return the text of piece ``idx`` (the last piece may be shorter)."""
        s = idx * self.piece_size
        return self.text[s:min(s + self.piece_size, len(self.text))]

    def span_text(self, start_piece, end_piece):
        """Return the text covered by pieces ``[start_piece, end_piece)``."""
        s = start_piece * self.piece_size
        e = min(end_piece * self.piece_size, len(self.text))
        return self.text[s:e]

    def progress(self):
        """Return the delivered fraction in ``[0, 1]``; empty text counts as done."""
        if self.n_pieces == 0:
            # Avoid dividing by zero: with nothing to deliver we are complete.
            return 1.0
        return self.next_idx / self.n_pieces

    def is_complete(self):
        """Return True once every piece has been delivered."""
        return self.next_idx >= self.n_pieces
371
+
372
+
373
# Built-in sample text for the "Quick Demo" source mode. Trailing backslashes
# join the wrapped source lines, so each paragraph is a single line of text and
# paragraphs are separated by one blank line.
FRANKENSTEIN_EXCERPT = """\
I am by birth a Genevese; and my family is one of the most distinguished \
of that republic. My ancestors had been for many years counsellors and \
syndics; and my father had filled several public situations with honour \
and reputation.

When I was thirteen years of age, we all went on a party of pleasure to \
the baths near Thonon: the inclemency of the weather obliged us to remain \
a day confined to the inn. In this house I found a volume of the works of \
Cornelius Agrippa. I opened it with apathy; the theory which he attempts \
to demonstrate, and the wonderful facts which he relates, soon changed \
this feeling into enthusiasm. A new light seemed to dawn upon my mind.

When I returned home, my first care was to procure the whole works of \
this author. My father was not scientific, and I was left to struggle \
with a child's blindness, added to a student's thirst for knowledge. \
Under the guidance of my new preceptors, I entered with the greatest \
diligence into the search of the philosopher's stone and the elixir \
of life. What glory would attend the discovery, if I could banish \
disease from the human frame, and render man invulnerable to any but \
a violent death!

It was on a dreary night of November that I beheld the accomplishment \
of my toils. With an anxiety that almost amounted to agony, I collected \
the instruments of life around me, that I might infuse a spark of being \
into the lifeless thing that lay at my feet. It was already one in the \
morning; the rain pattered dismally against the panes, and my candle was \
nearly burnt out, when, by the glimmer of the half-extinguished light, \
I saw the dull yellow eye of the creature open; it breathed hard, and \
a convulsive motion agitated its limbs.

How can I describe my emotions at this catastrophe, or how delineate the \
wretch whom with such infinite pains and care I had endeavoured to form? \
I had selected his features as beautiful. Beautiful!--Great God! His \
yellow skin scarcely covered the work of muscles and arteries beneath; \
his hair was of a lustrous black, and flowing; his teeth of a pearly \
whiteness; but these luxuriances only formed a more horrid contrast with \
his watery eyes, that seemed almost of the same colour as the dun white \
sockets in which they were set, his shrivelled complexion, and straight \
black lips.

I had worked hard for nearly two years, for the sole purpose of infusing \
life into an inanimate body. For this I had deprived myself of rest and \
health. I had desired it with an ardour that far exceeded moderation; but \
now that I had finished, the beauty of the dream vanished, and breathless \
horror and disgust filled my heart.

I did not dare return to the apartment which I inhabited, but felt \
impelled to hurry on, although drenched by the rain which poured from a \
black and comfortless sky. I passed the night wretchedly. Morning, \
dismal and wet, at length dawned, and discovered to my sleepless and \
aching eyes the church of Ingolstadt, its white steeple and clock, \
which indicated the sixth hour.

"I shall satiate my ardour for destruction," the creature said, "and \
make you so wretched that the light of day will be hateful to you. I \
will be with you on your wedding-night." I started forward, and \
exclaimed, "Villain! before you sign my death-warrant, be sure that \
you are yourself safe." My rage was without bounds; I would have seized \
him; but he eluded me, and quitted the house with precipitation.

Great God! why did I not then expire! But I am a wretch, and none ever \
conceived of the horrors of my secret toil, whilst I dabbled among the \
unhallowed damps of the grave, or tortured the living animal to animate \
the lifeless clay.

I was soon borne away by the waves, and lost in darkness and distance. \
Immense and rugged mountains of ice often barred up my passage, and I \
heard the thunder of the ground sea beneath. The cold is excessive, and \
many of my unfortunate comrades have already found a grave amidst this \
scene of desolation. Frankenstein! he is not here: I will not rest; I \
pursue him still over the untrodden snow and frozen ocean.
"""
446
+
447
# Registry of built-in demos, keyed by the label shown in the "Demo Content"
# dropdown. Each entry provides:
#   text       - the content to stream
#   queries    - search queries evaluated after every piece
#   piece_size - characters per simulated piece (passed to TextProvider)
#   sleep      - seconds to pause after each piece (passed to streaming_loop)
QUICK_DEMOS = {
    "Frankenstein (excerpt)": {
        "text": FRANKENSTEIN_EXCERPT,
        "queries": [
            "the creature opens its eyes for the first time",
            "playing god with science",
            "a threat on the wedding night",
            "a frozen arctic wasteland",
        ],
        "piece_size": 512,
        "sleep": 0.3,
    },
}
460
+
461
+
462
def render_grid(received, n_pieces, highlight=None):
    """Render download progress as an HTML strip of coloured cells.

    With up to 60 pieces every piece gets its own cell; beyond that the pieces
    are pooled into 60 buckets coloured by the share of received pieces.

    Args:
        received: sequence of truthy/falsy flags, one per piece.
        n_pieces: total number of pieces.
        highlight: index of the piece to emphasise, or None.

    Returns:
        HTML string with the cell strip and a "Piece x/y (z%)" caption.
    """
    max_width = 60

    def piece_color(i):
        # One cell per piece: highlighted, received, or still missing.
        if i == highlight:
            return '#00ff41'
        return '#28a745' if received[i] else '#3a3a3a'

    def bucket_color(col):
        # One cell per bucket of consecutive pieces.
        s = col * n_pieces // max_width
        e = (col + 1) * n_pieces // max_width
        ratio = sum(received[s:e]) / max(1, e - s)
        if highlight is not None and s <= highlight < e:
            return '#00ff41'
        if ratio > 0.8:
            return '#28a745'
        if ratio > 0.3:
            return '#17a2b8'
        if ratio > 0:
            return '#6c757d'
        return '#3a3a3a'

    if n_pieces <= max_width:
        colors = [piece_color(i) for i in range(n_pieces)]
    else:
        colors = [bucket_color(col) for col in range(max_width)]

    grid = ''.join(
        f'<span style="display:inline-block;width:14px;height:22px;'
        f'background:{bg};margin:1px;border-radius:2px"></span>'
        for bg in colors
    )

    n_recv = sum(received)
    pct = n_recv / max(n_pieces, 1) * 100
    return (
        f'<div style="font-family:monospace;line-height:1.4;padding:8px 0">'
        f'<div style="display:flex;flex-wrap:wrap;gap:0">{grid}</div>'
        f'<div style="margin-top:8px;color:#aaa">'
        f'Piece {n_recv}/{n_pieces} ({pct:.0f}%)</div></div>'
    )
508
+
509
+
510
def render_search(results_dict, peak_scores=None):
    """Render per-query search results as an HTML fragment.

    Args:
        results_dict: maps query text to a list of result dicts with keys
            ``'score'``, ``'preview'``, ``'span'`` (start, end) and
            ``'n_chunks'``; at most the first three results are shown.
        peak_scores: optional map from query text to
            ``{'score': float, 'preview': str}`` holding the best hit seen so
            far. The peak is shown in the header and, when it beats the current
            best score by more than 0.01, as an extra highlighted entry.

    Returns:
        HTML string. Query text and previews are HTML-escaped because both
        originate from user-supplied input (queries, downloaded text).
    """
    from html import escape  # local import keeps this block self-contained

    if not results_dict:
        return '<p style="color:#888">Waiting for data...</p>'

    def _score_color(score):
        if score > 0.5:
            return '#28a745'
        elif score > 0.4:
            return '#ffc107'
        return '#aaa'

    def _safe_preview(text):
        # Truncate first, then escape, so an entity is never cut in half.
        return escape(text[:300], quote=False)

    parts = []
    for query, results in results_dict.items():
        peak = peak_scores.get(query) if peak_scores else None
        header = f'&quot;{escape(query)}&quot;'
        if peak:
            header += (f' <span style="color:#888;font-size:0.85em">'
                       f'(peak: {peak["score"]:.3f})</span>')
        parts.append(
            f'<div style="margin-bottom:16px">'
            f'<div style="font-weight:bold;color:#58a6ff;margin-bottom:6px">'
            f'{header}</div>'
        )

        cur_best = results[0]['score'] if results else 0
        if peak and peak['score'] > cur_best + 0.01:
            psc = _score_color(peak['score'])
            pp = _safe_preview(peak['preview'])
            parts.append(
                f'<div style="padding:4px 0 4px 12px;border-left:3px solid {psc};'
                f'background:rgba(40,167,69,0.08);margin-bottom:2px">'
                f'<span style="color:{psc};font-weight:bold">{peak["score"]:.3f}</span> '
                f'<span style="color:#888;font-size:0.85em">peak</span><br>'
                f'<span style="color:#ccc;font-size:0.9em">{pp}...</span>'
                f'</div>'
            )

        if not results:
            parts.append('<div style="color:#888;padding-left:12px">No results yet</div>')
        else:
            for r in results[:3]:
                sc = _score_color(r['score'])
                preview = _safe_preview(r['preview'])
                parts.append(
                    f'<div style="padding:4px 0 4px 12px;border-left:3px solid {sc}">'
                    f'<span style="color:{sc};font-weight:bold">{r["score"]:.3f}</span> '
                    f'<span style="color:#888">[{r["span"][0]}-{r["span"][1]}]'
                    f' ({r["n_chunks"]}ch)</span><br>'
                    f'<span style="color:#ccc;font-size:0.9em">{preview}...</span>'
                    f'</div>'
                )
        parts.append('</div>')
    return ''.join(parts)
563
+
564
+
565
+ def _state_color(intensity):
566
+ h = int(220 - intensity * 170)
567
+ s = int(20 + intensity * 70)
568
+ light = int(12 + intensity * 38)
569
+ return f'hsl({h},{s}%,{light}%)'
570
+
571
+
572
def render_state_viz(state_history, n_layers=14):
    """Render the recurrent-state history as an HTML heatmap.

    Args:
        state_history: list with one entry per processed piece; each entry is a
            list of per-layer state magnitudes.
        n_layers: number of heatmap rows to draw.

    Returns:
        HTML string: one row per layer (labelled R1..Rn), one cell per step,
        with each row scaled by its own maximum, followed by a summary line
        and a colour legend. A placeholder paragraph is returned when the
        history is empty.
    """
    if not state_history:
        return ('<p style="color:#888">Recurrent state evolution will appear '
                'as pieces are processed...</p>')

    n_steps = len(state_history)
    # Shrink cells as the history grows, within 4..14 px.
    cell_w = max(4, min(14, 600 // max(n_steps, 1)))

    rows = []
    for li in range(n_layers):
        # Steps that recorded fewer layers simply contribute no cell here.
        series = [step[li] for step in state_history if li < len(step)]
        row_peak = max(series) if series else 1.0
        scale = max(row_peak, 1e-6)
        cells = ''.join(
            f'<span style="display:inline-block;width:{cell_w}px;'
            f'height:12px;background:{_state_color(min(value / scale, 1.0))};'
            f'margin:0 1px"></span>'
            for value in series)
        rows.append(
            f'<div style="display:flex;align-items:center;margin:0">'
            f'<span style="width:24px;color:#666;font-size:9px;'
            f'text-align:right;margin-right:3px;flex-shrink:0">R{li+1}</span>'
            f'<div style="display:flex">{cells}</div>'
            f'</div>')

    latest = state_history[-1]
    avg_norm = sum(latest) / len(latest) if latest else 0

    # Layer whose magnitude changed most between the last two steps
    # (first layer wins ties).
    most_active = 0
    max_delta = 0
    has_previous = n_steps >= 2
    if has_previous:
        for li, (now, before) in enumerate(zip(latest, state_history[-2])):
            delta = abs(now - before)
            if delta > max_delta:
                max_delta = delta
                most_active = li
    active_note = f" | Most active: R{most_active+1}" if has_previous else ""

    legend = ''.join(
        f'<span style="display:inline-block;width:16px;height:8px;'
        f'background:{_state_color(i / 4)};margin:0 1px"></span>'
        for i in range(5))

    body = "".join(rows)
    return (
        f'<div style="font-family:monospace;line-height:1.1">'
        f'{body}'
        f'<div style="color:#777;font-size:10px;margin-top:6px">'
        f'{n_layers} RWKV layers \u00d7 {n_steps} pieces | '
        f'Avg state magnitude: {avg_norm:.1f}'
        f'{active_note}'
        f'</div>'
        f'<div style="color:#666;font-size:9px;margin-top:2px">'
        f'{legend} low \u2192 high state magnitude'
        f'</div></div>')
633
+
634
+
635
def load_text(url):
    """Download a UTF-8 text file and strip Project Gutenberg boilerplate.

    Everything up to and including the line containing ``*** START OF`` and
    everything from ``*** END OF`` onwards is removed when those markers exist.

    Args:
        url: ``http`` or ``https`` URL of the text file.

    Returns:
        The decoded text; undecodable bytes are replaced.

    Raises:
        ValueError: if ``url`` is not an http(s) URL.
        urllib.error.URLError (an ``OSError``): on network or HTTP failures.
    """
    from urllib.parse import urlsplit  # local import keeps this block self-contained

    # The URL comes straight from the UI, so refuse schemes such as file://
    # that would let a visitor read files from the host.
    scheme = urlsplit(url).scheme.lower()
    if scheme not in ('http', 'https'):
        raise ValueError(
            f'Unsupported URL scheme {scheme!r}; only http and https are allowed')

    # Context manager guarantees the connection is closed, also on errors.
    with urllib.request.urlopen(url, timeout=30) as resp:
        text = resp.read().decode('utf-8', errors='replace')

    start = text.find('*** START OF')
    if start != -1:
        line_end = text.find('\n', start)
        # A marker on the very last line means there is no body after it.
        text = text[line_end + 1:] if line_end != -1 else ''
    end = text.find('*** END OF')
    if end != -1:
        text = text[:end]
    return text
645
+
646
+
647
def _run_queries(encoder, queries, q_embs, spans, peak_scores=None):
    """Search ``spans`` for every query and format the hits for rendering.

    When ``peak_scores`` is given it is updated in place with any new best hit.
    """
    formatted = {}
    for q in queries:
        hits = encoder.search(q_embs[q], spans, top_k=3)
        formatted[q] = [
            {'span': (s, e), 'score': sc, 'preview': pv,
             'n_chunks': nc, 'n_tokens': nt}
            for s, e, sc, pv, nt, nc in hits
        ]
        if peak_scores is not None and hits:
            best = hits[0]
            if q not in peak_scores or best[2] > peak_scores[q]['score']:
                peak_scores[q] = {'score': best[2], 'preview': best[3]}
    return formatted


def _stats_markdown(encoder, hare_tokens, baseline_tokens, counts, complete=False):
    """Build the efficiency, token and strategy status lines."""
    total_chunks = sum(len(d['chunk_embs']) for d in encoder.span_data.values())
    eff = baseline_tokens / max(hare_tokens, 1)
    saved = baseline_tokens - hare_tokens
    eff_md = f"**Efficiency: {eff:.1f}x** | {total_chunks} chunks"
    if complete:
        eff_md += " | COMPLETE"
    tok_md = f"Tokens: {hare_tokens:,} processed | {saved:,} saved via state carry"
    strat_md = (f"Right-ext: {counts['right']} | Smart-merge: {counts['merge']} | "
                f"Full: {counts['full']} | Merges: {counts['joined']}")
    return eff_md, tok_md, strat_md


def streaming_loop(provider, encoder, queries, q_embs, sleep_time=0):
    """Drive the demo: ingest pieces one by one and yield UI updates.

    For every span that changed after a piece arrived, the cheapest strategy is
    chosen: extend the span that ends one piece earlier, continue from the
    longest stored span sharing the same start, or encode from scratch.

    ``hare`` tokens count what was actually encoded; ``baseline`` tokens count
    the full length of each changed span, i.e. the cost of re-encoding it.

    Yields:
        ``(grid_html, eff_md, tok_md, strat_md, search_html, state_html)``
        after every piece, plus a final tuple once the provider is complete.
    """
    known_keys = set()
    hare_tokens = 0
    baseline_tokens = 0
    counts = {'right': 0, 'merge': 0, 'full': 0, 'joined': 0}
    pieces_processed = 0
    pending = []
    peak_scores = {}
    state_history = []
    n_rwkv_layers = len(encoder.birwkv_layers)

    while not provider.is_complete():
        arrived = provider.poll_pieces()
        if arrived:
            pending.extend(arrived)
            random.shuffle(pending)

        if not pending:
            continue

        idx = pending.pop(0)
        provider.received[idx] = True
        pieces_processed += 1

        spans = provider.get_spans()
        current_keys = {(s, e) for s, e, _ in spans}

        for s, e, span_body in spans:
            key = (s, e)
            if key in known_keys:
                continue  # span unchanged since the previous piece

            # Cheapest case: the same span minus its last piece already exists.
            shorter = (s, e - 1)
            if shorter in encoder.span_data:
                hare_tokens += encoder.extend_right(
                    provider.piece_text(e - 1), shorter, key)
                counts['right'] += 1
                baseline_tokens += encoder.span_data[key]['n_tokens']
                continue

            # Otherwise continue from the longest stored span with the same start.
            prefixes = [k for k in list(encoder.span_data.keys())
                        if k[0] == s and k[1] < e]
            if prefixes:
                anchor = max(prefixes, key=lambda k: k[1])
                tail_text = provider.span_text(anchor[1], e)
                hare_tokens += encoder.smart_merge(tail_text, anchor, key)
                counts['merge'] += 1
                baseline_tokens += encoder.span_data[key]['n_tokens']
                continue

            # No reusable prefix: drop anything this span swallowed and re-encode.
            encoder.remove_old(key)
            encoded = encoder.encode_span(span_body, key)
            hare_tokens += encoded
            counts['full'] += 1
            baseline_tokens += encoded

        # Fewer spans than before means the new piece joined two of them.
        if len(current_keys) < len(known_keys) and pieces_processed > 1:
            counts['joined'] += 1
        known_keys = current_keys

        eff_md, tok_md, strat_md = _stats_markdown(
            encoder, hare_tokens, baseline_tokens, counts)

        if encoder.span_data:
            # Track the state magnitudes of the span covering the most pieces.
            widest = max(encoder.span_data.keys(), key=lambda k: k[1] - k[0])
            layer_states = encoder.span_data[widest].get('layer_states', [])
            state_history.append([
                st['wkv_state'].norm().item()
                if st is not None and 'wkv_state' in st else 0.0
                for st in layer_states
            ])

        search_results = _run_queries(encoder, queries, q_embs, spans, peak_scores)

        grid_html = render_grid(provider.received, provider.n_pieces, highlight=idx)
        search_html = render_search(search_results, peak_scores)
        state_html = render_state_viz(state_history, n_rwkv_layers)

        yield grid_html, eff_md, tok_md, strat_md, search_html, state_html

        if sleep_time > 0:
            time.sleep(sleep_time)

    # Final frame: no highlighted piece, "COMPLETE" marker, peaks left untouched.
    eff_md, tok_md, strat_md = _stats_markdown(
        encoder, hare_tokens, baseline_tokens, counts, complete=True)
    grid_html = render_grid(provider.received, provider.n_pieces)
    search_results = _run_queries(encoder, queries, q_embs, provider.get_spans())
    search_html = render_search(search_results, peak_scores)
    state_html = render_state_viz(state_history, n_rwkv_layers)
    yield grid_html, eff_md, tok_md, strat_md, search_html, state_html
778
+
779
+
780
@spaces.GPU
def start_demo(source_mode, demo_choice, url_input, queries_text, chunk_size):
    """Gradio click handler: stream a text through the encoder and yield UI updates.

    Args:
        source_mode: ``"Quick Demo"`` or ``"URL"``; any other value yields nothing.
        demo_choice: key into ``QUICK_DEMOS`` (used in Quick Demo mode).
        url_input: URL of the text to download (used in URL mode).
        queries_text: comma-separated queries (used in URL mode).
        chunk_size: tokens per embedding chunk; may arrive as a float.

    Yields:
        6-tuples ``(grid_html, eff_md, tok_md, strat_md, search_html, state_html)``.
    """
    from html import escape  # local import keeps this block self-contained

    def _notice(message):
        # Warning shown in the first output; the other five are cleared.
        return (f'<p style="color:#ffc107">{message}</p>', '', '', '', '', '')

    model.cuda()
    # Slider values can be floats; the chunking logic needs an int step.
    encoder = SpanEncoder(model, tokenizer, chunk_size=int(chunk_size))

    if source_mode == "Quick Demo":
        config = QUICK_DEMOS[demo_choice]
        provider = TextProvider(config['text'],
                                piece_size=config['piece_size'], seed=42)
        queries = config['queries']
        sleep_time = config['sleep']
    elif source_mode == "URL":
        url = (url_input or '').strip()
        if not url:
            yield _notice('Enter a URL to a text file.')
            return
        try:
            text = load_text(url=url)
        except (OSError, ValueError) as exc:
            # urllib's network/HTTP errors derive from OSError; malformed or
            # unsupported URLs raise ValueError. Report instead of crashing.
            yield _notice(f'Could not load URL: {escape(str(exc))}')
            return
        provider = TextProvider(text, piece_size=4096, seed=42)
        queries = [q.strip() for q in (queries_text or '').split(',') if q.strip()]
        sleep_time = 0
    else:
        return

    if not queries:
        queries = ["search query"]

    q_embs = {q: encoder.encode_query(q) for q in queries}

    yield from streaming_loop(provider, encoder, queries, q_embs, sleep_time)
809
+
810
+
811
def toggle_inputs(source_mode):
    """Return visibility updates for the source-dependent inputs.

    Returns:
        Updates for (demo dropdown, URL box, queries box). The dropdown shows
        only in "Quick Demo" mode, the URL box only in "URL" mode, and the
        queries box in every mode except "Quick Demo"; the queries box is also
        reset to the default Frankenstein queries.
    """
    frankenstein_q = "on a dreary night the creature first opened its eyes, an innocent woman is wrongly executed, playing god with science"
    is_quick = source_mode == "Quick Demo"
    is_url = source_mode == "URL"

    dropdown_update = gr.update(visible=is_quick)
    url_update = gr.update(visible=is_url)
    queries_update = gr.update(visible=not is_quick, value=frankenstein_q)
    return (dropdown_update, url_update, queries_update)
819
+
820
+
821
def update_queries(demo_choice):
    """Return the selected demo's queries joined with ', ' ('' if unknown)."""
    try:
        queries = QUICK_DEMOS[demo_choice]['queries']
    except KeyError:
        # Unknown demo, or a demo entry without a 'queries' list.
        queries = []
    return ', '.join(queries)
825
+
826
+
827
def build_demo():
    """Assemble the Gradio Blocks UI and wire its events; return the Blocks object.

    Layout: a narrow left column with the inputs and start button, and a wide
    right column with progress grid, efficiency stats, search results and the
    recurrent-state heatmap.
    """
    default_url = "https://gutenberg.org/files/84/84-0.txt"
    default_queries = "on a dreary night the creature first opened its eyes, an innocent woman is wrongly executed, playing god with science"
    demo_names = list(QUICK_DEMOS.keys())
    intro = (
        "Watch [HARE](https://huggingface.co/SixOpen/HARE) build a "
        "semantic search index in real-time as content streams in "
        "piece by piece. Unlike standard embedding models, HARE's "
        "recurrent state carries forward full context without "
        "re-encoding, allowing for search over live transcripts, "
        "distributed content, and streaming files without "
        "needing to download them in full."
    )

    with gr.Blocks(title="HARE Streaming Demo") as demo:
        gr.Markdown("# HARE: Streaming Semantic Search")
        gr.Markdown(intro)

        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1, min_width=280):
                source_mode = gr.Radio(
                    ["URL", "Quick Demo"], value="URL", label="Source")
                demo_choice = gr.Dropdown(
                    demo_names, value=demo_names[0],
                    label="Demo Content", visible=False)
                url_input = gr.Textbox(
                    label="Text URL", value=default_url,
                    placeholder=default_url, visible=True)
                queries_input = gr.Textbox(
                    label="Search Queries (comma-separated)",
                    value=default_queries, visible=True)

                with gr.Accordion("Settings", open=False):
                    chunk_size = gr.Slider(
                        128, 1024, value=512, step=64,
                        label="Chunk Size (tokens)")

                start_btn = gr.Button("Start Demo", variant="primary", size="lg")

            # Right column: live outputs.
            with gr.Column(scale=2):
                gr.Markdown("### Download Progress")
                piece_grid = gr.HTML(
                    '<div style="padding:20px;color:#666;text-align:center">'
                    'Click "Start Demo" to begin</div>')

                gr.Markdown("### Encoding Efficiency")
                with gr.Row():
                    efficiency_md = gr.Markdown("**Efficiency: --**")
                with gr.Row():
                    tokens_md = gr.Markdown("Tokens: --")
                    strategy_md = gr.Markdown("Right-ext: -- | Smart-merge: -- | Full: --")

                gr.Markdown("### Search Results")
                search_html = gr.HTML(
                    '<p style="color:#888">Results will appear here as '
                    'pieces are processed...</p>')

                gr.Markdown("### Recurrent State Evolution")
                state_viz = gr.HTML(
                    '<p style="color:#888">State heatmap will appear as '
                    'pieces are processed...</p>')

        # Event wiring.
        source_mode.change(
            toggle_inputs,
            inputs=[source_mode],
            outputs=[demo_choice, url_input, queries_input])
        demo_choice.change(
            update_queries,
            inputs=[demo_choice],
            outputs=[queries_input])
        start_btn.click(
            start_demo,
            inputs=[source_mode, demo_choice, url_input, queries_input,
                    chunk_size],
            outputs=[piece_grid, efficiency_md, tokens_md, strategy_md,
                     search_html, state_viz])

    return demo
920
+
921
+
922
# Build the UI at import time and start serving immediately. ``queue()`` is
# enabled before launch; start_demo is a generator that streams its updates.
demo = build_demo()
demo.queue().launch()