#!/usr/bin/env python3
"""HARE streaming semantic-search demo.

Streams a text in shuffled pieces and incrementally builds a semantic
search index with the HARE bidirectional-RWKV encoder.  The recurrent
WKV state of each layer is carried forward so that contiguous spans can
be extended without re-encoding already-seen tokens.

NOTE(review): this file was recovered from a whitespace-mangled copy in
which the HTML markup inside the render_* f-strings had been stripped.
The program logic below is a faithful transcription; the inline-styled
HTML payloads are a reconstruction of the original intent and should be
visually reviewed in the running app.
"""
import random
import time
import urllib.request

import gradio as gr
import spaces
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from transformers import AutoModel, AutoTokenizer

MODEL_ID = "SixOpen/HARE"
model = AutoModel.from_pretrained(MODEL_ID, trust_remote_code=True).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


@triton.jit
def _wkv7_fwd_kernel(
    R, K, V, DECAY, A, O, STATE_OUT, STATE_IN,
    sab_scale,
    T,
    stride_b, stride_t, stride_h,
    H: tl.constexpr,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
    RETURN_STATE: tl.constexpr,
    HAS_INIT_STATE: tl.constexpr,
):
    """Sequential WKV-7 forward scan; one program per (batch, head).

    Maintains a (D, D) recurrent state per head, optionally seeded from
    STATE_IN and optionally written back to STATE_OUT after the last step.
    """
    pid = tl.program_id(0)
    b_idx = pid // H
    h_idx = pid % H
    base = b_idx * stride_b + h_idx * stride_h
    di = tl.arange(0, BLOCK_D)
    dj = tl.arange(0, BLOCK_D)
    mask_i = di < D
    mask_j = dj < D
    if HAS_INIT_STATE:
        # Resume from a carried-over state (streaming span extension).
        s_off = b_idx * (H * D * D) + h_idx * (D * D)
        state_ptrs = STATE_IN + s_off + di[:, None] * D + dj[None, :]
        state_mask = mask_i[:, None] & mask_j[None, :]
        state = tl.load(state_ptrs, mask=state_mask, other=0.0).to(tl.float32)
    else:
        state = tl.zeros((BLOCK_D, BLOCK_D), dtype=tl.float32)
    for t in range(T):
        t_off = base + t * stride_t
        kt = tl.load(K + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
        vt = tl.load(V + t_off + di, mask=mask_i, other=0.0).to(tl.float32)
        rt = tl.load(R + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
        dt = tl.load(DECAY + t_off + dj, mask=mask_j, other=1.0).to(tl.float32)
        at = tl.load(A + t_off + dj, mask=mask_j, other=0.0).to(tl.float32)
        # In-context "subtract-then-add-back" correction term (WKV-7 style).
        sa = tl.sum(state * (-kt)[None, :], axis=1)
        ka = kt * at
        sab = sa[:, None] * ka[None, :]
        state = state * dt[None, :] + sab_scale * sab + vt[:, None] * kt[None, :]
        # Clamp to keep the recurrence numerically bounded over long streams.
        state = tl.minimum(tl.maximum(state, -10.0), 10.0)
        out_t = tl.sum(state * rt[None, :], axis=1)
        tl.store(O + t_off + di, out_t, mask=mask_i)
    if RETURN_STATE:
        s_off = b_idx * (H * D * D) + h_idx * (D * D)
        state_ptrs = STATE_OUT + s_off + di[:, None] * D + dj[None, :]
        state_mask = mask_i[:, None] & mask_j[None, :]
        tl.store(state_ptrs, state, mask=state_mask)


def wkv7_scan_triton(r, decay, k, v, a, sab_scale,
                     return_state=False, init_state=None):
    """Launch the WKV-7 kernel over (B, T, H, D) tensors.

    Returns the per-token output, plus the final (B, H, D, D) recurrent
    state when ``return_state`` is set.
    """
    B, T, H, D = r.shape
    r, k, v, decay, a = [x.contiguous() for x in (r, k, v, decay, a)]
    o = torch.empty_like(r)
    state_out = None
    if return_state:
        state_out = torch.empty(B, H, D, D, dtype=torch.float32,
                                device=r.device)
    has_init = init_state is not None
    if has_init:
        init_state = init_state.contiguous().float()
    stride_b = T * H * D
    stride_t = H * D
    stride_h = D
    BLOCK_D = triton.next_power_of_2(D)
    _wkv7_fwd_kernel[(B * H,)](
        r, k, v, decay, a, o, state_out, init_state,
        float(sab_scale),
        T,
        stride_b, stride_t, stride_h,
        H=H, D=D, BLOCK_D=BLOCK_D,
        RETURN_STATE=return_state,
        HAS_INIT_STATE=has_init,
    )
    if return_state:
        return o, state_out
    return o


def find_birwkv_layers(model):
    """Collect the model's BiRWKV7Layer modules (matched by class name).

    Returns the layer list plus an id(layer) -> index map used when
    monkey-patching forward methods.
    """
    layers = []
    ids = {}
    for m in model.modules():
        if type(m).__name__ == 'BiRWKV7Layer':
            ids[id(m)] = len(layers)
            layers.append(m)
    return layers, ids


class SpanEncoder:
    """Incremental span encoder with recurrent-state carry-over.

    Hooks every BiRWKV layer so that the forward WKV state at the end of
    a span can seed the encoding of the next piece, avoiding re-encoding
    of already-seen text.
    """

    def __init__(self, model, tokenizer, chunk_size=512):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.chunk_size = chunk_size
        self.birwkv_layers, self.birwkv_ids = find_birwkv_layers(model)
        self._originals = {}        # id(layer) -> original forward
        self._hooked = False
        self._active_states = [None] * len(self.birwkv_layers)
        self.span_data = {}         # (start_piece, end_piece) -> span record

    def _hook(self):
        """Replace each BiRWKV layer's forward with the stateful variant."""
        if self._hooked:
            return
        for layer in self.birwkv_layers:
            self._originals[id(layer)] = layer.forward
            layer.forward = self._make_fwd(layer)
        self._hooked = True

    def _unhook(self):
        """Restore the original forward methods."""
        if not self._hooked:
            return
        for layer in self.birwkv_layers:
            layer.forward = self._originals[id(layer)]
        self._originals.clear()
        self._hooked = False

    def _make_fwd(self, layer):
        """Build a replacement forward for *layer*.

        The forward pass resumes from the carried WKV state (if any) and
        stashes the final state + last hidden vector for the next call;
        the backward direction is recomputed fresh each time by flipping
        the sequence.
        """
        enc = self
        idx = self.birwkv_ids[id(layer)]

        def fwd(x, attention_mask=None, **kwargs):
            B, T, C_ = x.shape
            H, D = layer.num_heads, layer.head_size
            prev = enc._active_states[idx]
            if prev is not None:
                # Token-shift across the span boundary: prepend the last
                # hidden vector of the previous piece.
                x_prev = torch.cat([prev['last_x'], x[:, :-1]], dim=1)
            else:
                x_prev = F.pad(x[:, :-1], (0, 0, 1, 0))

            def mix(mu):
                return x + (x_prev - x) * torch.sigmoid(mu)

            r = layer.W_r(mix(layer.mu_r)).view(B, T, H, D)
            w = layer.W_w(mix(layer.mu_w)).view(B, T, H, D)
            k = layer.W_k(mix(layer.mu_k)).view(B, T, H, D)
            v = layer.W_v(mix(layer.mu_v)).view(B, T, H, D)
            a = layer.W_a(mix(layer.mu_a)).view(B, T, H, D)
            g = torch.sigmoid(layer.W_g(mix(layer.mu_g)))
            sab_scale = torch.sigmoid(layer.sab_gate)
            init_st = prev['wkv_state'] if prev else None
            r_f, k_f, v_f = r.float(), k.float() * (D ** -0.5), v.float()
            a_f = torch.sigmoid(a.float())
            # 0.6065... = exp(-0.5); decay in (exp(-0.5), 1).
            decay = torch.exp(-0.6065306597633104 * torch.sigmoid(w.float()))
            out_fwd, wkv_state = wkv7_scan_triton(
                r_f, decay, k_f, v_f, a_f, sab_scale,
                return_state=True, init_state=init_st)
            out_bwd = wkv7_scan_triton(
                r_f.flip(1), decay.flip(1), k_f.flip(1), v_f.flip(1),
                a_f.flip(1), sab_scale, return_state=False).flip(1)
            enc._active_states[idx] = {
                'wkv_state': wkv_state,
                'last_x': x[:, -1:].detach().clone(),
            }
            out = ((out_fwd + out_bwd) * 0.5).reshape(B, T, C_)
            out = layer.group_norm(out.transpose(1, 2)).transpose(1, 2)
            out = layer.W_o(out * g)
            return out, None

        return fwd

    @torch.no_grad()
    def _forward_encode_raw(self, text, init_states=None, max_length=8192):
        """Encode *text* with hooks installed, optionally resuming states.

        Returns (content_hidden_states, n_content_tokens, final_states);
        content excludes the special tokens at both ends.
        """
        self._hook()
        if init_states is not None:
            self._active_states = [
                {k: v.clone() for k, v in s.items()} if s else None
                for s in init_states
            ]
        else:
            self._active_states = [None] * len(self.birwkv_layers)
        enc = self.tokenizer(text, return_tensors='pt', truncation=True,
                             max_length=max_length)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        content = h[0, 1:-1, :].cpu()
        n_content = content.shape[0]
        final_states = [
            {k: v.clone() for k, v in s.items()} if s else None
            for s in self._active_states
        ]
        self._unhook()
        return content, n_content, final_states

    def _chunk_hidden(self, content, return_residual=False):
        """Mean-pool hidden states into normalized fixed-size chunk embeddings.

        Trailing fragments shorter than 32 tokens become the residual,
        kept so they can be merged with the next incoming piece.
        """
        T = content.shape[0]
        chunks = []
        last_end = 0
        for start in range(0, T, self.chunk_size):
            end = min(start + self.chunk_size, T)
            if end - start < 32:
                break
            emb = F.normalize(content[start:end].mean(0, keepdim=True),
                              p=2, dim=-1)
            chunks.append(emb)
            last_end = end
        if not chunks and T > 0:
            # Span shorter than one chunk: pool everything.
            chunks.append(F.normalize(content.mean(0, keepdim=True),
                                      p=2, dim=-1))
            last_end = T
        if return_residual:
            residual = content[last_end:] if last_end < T else None
            return chunks, residual
        return chunks

    @torch.no_grad()
    def encode_query(self, query):
        """Encode a search query (mask-weighted mean pool, L2-normalized)."""
        assert not self._hooked
        enc = self.tokenizer(query, return_tensors='pt', truncation=True,
                             max_length=512)
        ids = enc['input_ids'].to(self.device)
        mask = enc['attention_mask'].to(self.device)
        h = self.model(input_ids=ids, attention_mask=mask).last_hidden_state
        m = mask.unsqueeze(-1).float()
        emb = (h * m).sum(1) / m.sum(1).clamp(min=1e-9)
        return F.normalize(emb, p=2, dim=-1).cpu()

    def encode_span(self, text, key):
        """Encode a brand-new span from scratch; returns tokens processed."""
        content, n_tok, states = self._forward_encode_raw(text)
        chunks, residual = self._chunk_hidden(content, return_residual=True)
        self.span_data[key] = {
            'layer_states': states,
            'chunk_embs': chunks,
            'n_tokens': n_tok,
            'residual_hidden': residual,
        }
        return n_tok

    def extend_right(self, piece_text, old_key, new_key):
        """Grow a span rightward by one piece, reusing its carried state."""
        old = self.span_data.pop(old_key)
        content, n_new, states = self._forward_encode_raw(
            piece_text, init_states=old['layer_states'])
        if old.get('residual_hidden') is not None:
            content = torch.cat([old['residual_hidden'], content], dim=0)
        new_chunks, residual = self._chunk_hidden(
            content, return_residual=True)
        self.span_data[new_key] = {
            'layer_states': states,
            'chunk_embs': old['chunk_embs'] + new_chunks,
            'n_tokens': old['n_tokens'] + n_new,
            'residual_hidden': residual,
        }
        return n_new

    def smart_merge(self, new_text, left_key, new_key):
        """Merge a left span with newly-available text to its right.

        Only the new portion is encoded (seeded by the left span's state);
        spans fully contained in the merged range are dropped.
        """
        left = self.span_data.pop(left_key)
        self.remove_old(new_key)
        content, n_new, states = self._forward_encode_raw(
            new_text, init_states=left['layer_states'])
        if left.get('residual_hidden') is not None:
            content = torch.cat([left['residual_hidden'], content], dim=0)
        new_chunks, residual = self._chunk_hidden(
            content, return_residual=True)
        self.span_data[new_key] = {
            'layer_states': states,
            'chunk_embs': left['chunk_embs'] + new_chunks,
            'n_tokens': left['n_tokens'] + n_new,
            'residual_hidden': residual,
        }
        return n_new

    def remove_old(self, new_key):
        """Drop any stored span fully contained within *new_key*'s range."""
        s, e = new_key
        for old in list(self.span_data.keys()):
            if old[0] >= s and old[1] <= e:
                del self.span_data[old]

    def search(self, q_emb, spans, top_k=5):
        """Rank spans by best chunk cosine similarity against *q_emb*.

        Returns up to *top_k* tuples of
        (start, end, score, preview, n_tokens, n_chunks).
        """
        results = []
        for s, e, text in spans:
            key = (s, e)
            data = self.span_data.get(key)
            if not data or not data['chunk_embs']:
                continue
            chunk_mat = torch.cat(data['chunk_embs'], dim=0)
            sims = (q_emb @ chunk_mat.T).squeeze(0)
            if sims.dim() == 0:
                sims = sims.unsqueeze(0)
            max_sim = sims.max().item()
            best_idx = sims.argmax().item()
            n_chunks = len(data['chunk_embs'])
            # Approximate the char offset of the best chunk, then snap
            # back to a whitespace boundary for a clean preview.
            chars_per_chunk = len(text) // max(n_chunks, 1)
            offset = min(best_idx * chars_per_chunk, len(text) - 1)
            while offset > 0 and text[offset - 1] not in ' \n\t':
                offset -= 1
            preview = text[offset:offset + 300].replace('\n', ' ').strip()
            results.append((s, e, max_sim, preview,
                            data['n_tokens'], n_chunks))
        results.sort(key=lambda x: x[2], reverse=True)
        return results[:top_k]


class TextProvider:
    """Serves a text as fixed-size pieces arriving in random order."""

    def __init__(self, text, piece_size=4096, seed=42):
        self.text = text
        self.piece_size = piece_size
        self.n_pieces = (len(text) + piece_size - 1) // piece_size
        self.received = [False] * self.n_pieces
        rng = random.Random(seed)
        self.arrival_order = list(range(self.n_pieces))
        rng.shuffle(self.arrival_order)
        self.next_idx = 0

    def poll_pieces(self):
        """Deliver the next piece index (one per poll), or [] when done."""
        if self.next_idx >= self.n_pieces:
            return []
        idx = self.arrival_order[self.next_idx]
        self.received[idx] = True
        self.next_idx += 1
        return [idx]

    def get_spans(self):
        """Return maximal contiguous received runs as (start, end, text)."""
        spans = []
        i = 0
        while i < self.n_pieces:
            if self.received[i]:
                j = i
                while j < self.n_pieces and self.received[j]:
                    j += 1
                s_byte = i * self.piece_size
                e_byte = min(j * self.piece_size, len(self.text))
                spans.append((i, j, self.text[s_byte:e_byte]))
                i = j
            else:
                i += 1
        return spans

    def piece_text(self, idx):
        """Text of a single piece."""
        s = idx * self.piece_size
        return self.text[s:min(s + self.piece_size, len(self.text))]

    def span_text(self, start_piece, end_piece):
        """Text covering pieces [start_piece, end_piece)."""
        s = start_piece * self.piece_size
        e = min(end_piece * self.piece_size, len(self.text))
        return self.text[s:e]

    def progress(self):
        """Fraction of pieces delivered so far."""
        return self.next_idx / self.n_pieces

    def is_complete(self):
        """True once every piece has been delivered."""
        return self.next_idx >= self.n_pieces


FRANKENSTEIN_EXCERPT = """\
I am by birth a Genevese; and my family is one of the most distinguished \
of that republic. My ancestors had been for many years counsellors and \
syndics; and my father had filled several public situations with honour \
and reputation.

When I was thirteen years of age, we all went on a party of pleasure to \
the baths near Thonon: the inclemency of the weather obliged us to remain \
a day confined to the inn. In this house I found a volume of the works of \
Cornelius Agrippa. I opened it with apathy; the theory which he attempts \
to demonstrate, and the wonderful facts which he relates, soon changed \
this feeling into enthusiasm. A new light seemed to dawn upon my mind.

When I returned home, my first care was to procure the whole works of \
this author. My father was not scientific, and I was left to struggle \
with a child's blindness, added to a student's thirst for knowledge. \
Under the guidance of my new preceptors, I entered with the greatest \
diligence into the search of the philosopher's stone and the elixir \
of life. What glory would attend the discovery, if I could banish \
disease from the human frame, and render man invulnerable to any but \
a violent death!

It was on a dreary night of November that I beheld the accomplishment \
of my toils. With an anxiety that almost amounted to agony, I collected \
the instruments of life around me, that I might infuse a spark of being \
into the lifeless thing that lay at my feet. It was already one in the \
morning; the rain pattered dismally against the panes, and my candle was \
nearly burnt out, when, by the glimmer of the half-extinguished light, \
I saw the dull yellow eye of the creature open; it breathed hard, and \
a convulsive motion agitated its limbs.

How can I describe my emotions at this catastrophe, or how delineate the \
wretch whom with such infinite pains and care I had endeavoured to form? \
I had selected his features as beautiful. Beautiful!--Great God! His \
yellow skin scarcely covered the work of muscles and arteries beneath; \
his hair was of a lustrous black, and flowing; his teeth of a pearly \
whiteness; but these luxuriances only formed a more horrid contrast with \
his watery eyes, that seemed almost of the same colour as the dun white \
sockets in which they were set, his shrivelled complexion, and straight \
black lips.

I had worked hard for nearly two years, for the sole purpose of infusing \
life into an inanimate body. For this I had deprived myself of rest and \
health. I had desired it with an ardour that far exceeded moderation; but \
now that I had finished, the beauty of the dream vanished, and breathless \
horror and disgust filled my heart.

I did not dare return to the apartment which I inhabited, but felt \
impelled to hurry on, although drenched by the rain which poured from a \
black and comfortless sky. I passed the night wretchedly. Morning, \
dismal and wet, at length dawned, and discovered to my sleepless and \
aching eyes the church of Ingolstadt, its white steeple and clock, \
which indicated the sixth hour.

"I shall satiate my ardour for destruction," the creature said, "and \
make you so wretched that the light of day will be hateful to you. I \
will be with you on your wedding-night." I started forward, and \
exclaimed, "Villain! before you sign my death-warrant, be sure that \
you are yourself safe." My rage was without bounds; I would have seized \
him; but he eluded me, and quitted the house with precipitation.

Great God! why did I not then expire! But I am a wretch, and none ever \
conceived of the horrors of my secret toil, whilst I dabbled among the \
unhallowed damps of the grave, or tortured the living animal to animate \
the lifeless clay.

I was soon borne away by the waves, and lost in darkness and distance. \
Immense and rugged mountains of ice often barred up my passage, and I \
heard the thunder of the ground sea beneath. The cold is excessive, and \
many of my unfortunate comrades have already found a grave amidst this \
scene of desolation. Frankenstein! he is not here: I will not rest; I \
pursue him still over the untrodden snow and frozen ocean.
"""

QUICK_DEMOS = {
    "Frankenstein (excerpt)": {
        "text": FRANKENSTEIN_EXCERPT,
        "queries": [
            "the creature opens its eyes for the first time",
            "playing god with science",
            "a threat on the wedding night",
            "a frozen arctic wasteland",
        ],
        "piece_size": 512,
        "sleep": 0.3,
    },
}


# NOTE(review): the HTML markup in the render_* helpers below was lost to
# sanitization in the recovered source; the inline styles are reconstructed
# from the surviving color constants and layout math.
def render_grid(received, n_pieces, highlight=None):
    """Render the piece-arrival progress grid as inline-styled HTML."""
    max_width = 60
    if n_pieces <= max_width:
        # One cell per piece.
        cells = []
        for i in range(n_pieces):
            if i == highlight:
                bg = '#00ff41'
            elif received[i]:
                bg = '#28a745'
            else:
                bg = '#3a3a3a'
            cells.append(
                f'<span style="display:inline-block;width:9px;height:16px;'
                f'margin:1px;border-radius:2px;background:{bg};"></span>'
            )
    else:
        # Bucket pieces into max_width columns; shade by fill ratio.
        cells = []
        for col in range(max_width):
            s = col * n_pieces // max_width
            e = (col + 1) * n_pieces // max_width
            ratio = sum(received[s:e]) / max(1, e - s)
            hl = highlight is not None and s <= highlight < e
            if hl:
                bg = '#00ff41'
            elif ratio > 0.8:
                bg = '#28a745'
            elif ratio > 0.3:
                bg = '#17a2b8'
            elif ratio > 0:
                bg = '#6c757d'
            else:
                bg = '#3a3a3a'
            cells.append(
                f'<span style="display:inline-block;width:9px;height:16px;'
                f'margin:1px;border-radius:2px;background:{bg};"></span>'
            )
    n_recv = sum(received)
    pct = n_recv / max(n_pieces, 1) * 100
    grid = ''.join(cells)
    return (
        f'<div style="font-family:monospace;">'
        f'<div style="line-height:0;">{grid}</div>'
        f'<div style="color:#888;font-size:12px;margin-top:6px;">'
        f'Piece {n_recv}/{n_pieces} ({pct:.0f}%)</div>'
        f'</div>'
    )


def render_search(results_dict, peak_scores=None):
    """Render per-query search results (plus sticky peak hits) as HTML."""
    if not results_dict:
        return ('<div style="color:#888;text-align:center;padding:30px;">'
                'Waiting for data...'
                '</div>')

    def _score_color(score):
        if score > 0.5:
            return '#28a745'
        elif score > 0.4:
            return '#ffc107'
        return '#aaa'

    parts = []
    for query, results in results_dict.items():
        peak = peak_scores.get(query) if peak_scores else None
        header = f'"{query}"'
        if peak:
            header += (f' <span style="color:#888;font-size:11px;">'
                       f'(peak: {peak["score"]:.3f})</span>')
        parts.append(
            f'<div style="margin-bottom:14px;">'
            f'<div style="font-weight:bold;margin-bottom:4px;">'
            f'{header}</div>'
        )
        cur_best = results[0]['score'] if results else 0
        # Show the historical best hit when the current results regressed.
        if peak and peak['score'] > cur_best + 0.01:
            psc = _score_color(peak['score'])
            pp = peak['preview'][:300].replace('<', '&lt;').replace('>', '&gt;')
            parts.append(
                f'<div style="font-size:12px;margin:2px 0;color:#ccc;">'
                f'<span style="color:{psc};font-weight:bold;">'
                f'{peak["score"]:.3f}</span> '
                f'<span style="color:#888;">peak</span> '
                f'{pp}...'
                f'</div>'
            )
        if not results:
            parts.append('<div style="color:#888;font-size:12px;">'
                         'No results yet</div>')
        else:
            for rank, r in enumerate(results[:3], 1):
                sc = _score_color(r['score'])
                preview = r['preview'][:300].replace('<', '&lt;').replace('>', '&gt;')
                parts.append(
                    f'<div style="font-size:12px;margin:2px 0;color:#ccc;">'
                    f'<span style="color:{sc};font-weight:bold;">'
                    f'{r["score"]:.3f}</span> '
                    f'<span style="color:#888;">'
                    f'[{r["span"][0]}-{r["span"][1]}]'
                    f' ({r["n_chunks"]}ch)</span> '
                    f'{preview}...'
                    f'</div>'
                )
        parts.append('</div>')
    return ''.join(parts)


def _state_color(intensity):
    """Map a 0..1 intensity to an HSL heatmap color (blue -> warm)."""
    h = int(220 - intensity * 170)
    s = int(20 + intensity * 70)
    light = int(12 + intensity * 38)
    return f'hsl({h},{s}%,{light}%)'


def render_state_viz(state_history, n_layers=14):
    """Render the per-layer recurrent-state-magnitude heatmap as HTML."""
    if not state_history:
        return ('<div style="color:#888;text-align:center;padding:30px;">'
                'Recurrent state evolution will appear '
                'as pieces are processed...'
                '</div>')
    n_steps = len(state_history)
    cell_w = max(4, min(14, 600 // max(n_steps, 1)))
    # Normalize each layer row by its own maximum.
    layer_maxes = []
    for li in range(n_layers):
        vals = [state_history[t][li] for t in range(n_steps)
                if li < len(state_history[t])]
        layer_maxes.append(max(vals) if vals else 1.0)
    rows = []
    for li in range(n_layers):
        cells = []
        for t in range(n_steps):
            if li < len(state_history[t]):
                norm = state_history[t][li]
                intensity = min(norm / max(layer_maxes[li], 1e-6), 1.0)
                cells.append(
                    f'<span style="display:inline-block;width:{cell_w}px;'
                    f'height:12px;background:{_state_color(intensity)};'
                    f'"></span>')
        rows.append(
            f'<div style="display:flex;align-items:center;margin:1px 0;">'
            f'<span style="width:28px;color:#888;font-size:10px;'
            f'font-family:monospace;">R{li+1}</span>'
            f'<span style="line-height:0;">{"".join(cells)}</span>'
            f'</div>')
    latest = state_history[-1]
    avg_norm = sum(latest) / len(latest) if latest else 0
    most_active = 0
    max_delta = 0
    if len(state_history) >= 2:
        prev = state_history[-2]
        for li in range(min(len(latest), len(prev))):
            d = abs(latest[li] - prev[li])
            if d > max_delta:
                max_delta = d
                most_active = li
    legend = ''.join(
        f'<span style="display:inline-block;width:14px;height:10px;'
        f'background:{_state_color(i / 4)};"></span>'
        for i in range(5))
    return (
        f'<div style="font-family:monospace;">'
        f'{"".join(rows)}'
        f'<div style="color:#888;font-size:11px;margin-top:6px;">'
        f'{n_layers} RWKV layers \u00d7 {n_steps} pieces | '
        f'Avg state magnitude: {avg_norm:.1f}'
        f'{f" | Most active: R{most_active+1}" if len(state_history) >= 2 else ""}'
        f'</div>'
        f'<div style="color:#888;font-size:11px;margin-top:2px;">'
        f'{legend} low \u2192 high state magnitude'
        f'</div>'
        f'</div>')


def load_text(url):
    """Fetch a text file; strip Project Gutenberg header/footer if present."""
    resp = urllib.request.urlopen(url, timeout=30)
    text = resp.read().decode('utf-8', errors='replace')
    start = text.find('*** START OF')
    if start != -1:
        text = text[text.find('\n', start) + 1:]
    end = text.find('*** END OF')
    if end != -1:
        text = text[:end]
    return text


def streaming_loop(provider, encoder, queries, q_embs, sleep_time=0):
    """Core demo loop: ingest pieces, update spans, yield UI payloads.

    Yields 6-tuples of
    (grid_html, efficiency_md, tokens_md, strategy_md, search_html,
    state_html) after each processed piece and once more at completion.
    """
    prev_span_keys = set()
    hare_tokens = 0        # tokens actually encoded (state carry-over)
    baseline_tokens = 0    # tokens a from-scratch re-encoder would process
    right_extends = 0
    smart_merges = 0
    full_reencodes = 0
    merge_events = 0
    pieces_processed = 0
    piece_queue = []
    peak_scores = {}
    state_history = []
    n_rwkv_layers = len(encoder.birwkv_layers)
    while not provider.is_complete():
        new_pieces = provider.poll_pieces()
        if new_pieces:
            piece_queue.extend(new_pieces)
            random.shuffle(piece_queue)
        if not piece_queue:
            continue
        idx = piece_queue.pop(0)
        provider.received[idx] = True
        pieces_processed += 1
        new_spans = provider.get_spans()
        new_keys = {(s, e) for s, e, _ in new_spans}
        for s, e, span_text_val in new_spans:
            key = (s, e)
            if key in prev_span_keys:
                continue
            # Strategy 1: the span grew by exactly one piece on the right.
            right_key = (s, e - 1)
            if right_key in encoder.span_data:
                n = encoder.extend_right(provider.piece_text(e - 1),
                                         right_key, key)
                hare_tokens += n
                right_extends += 1
                baseline_tokens += encoder.span_data[key]['n_tokens']
                continue
            # Strategy 2: extend the longest stored prefix of this span.
            best_left = None
            for (os_, oe) in list(encoder.span_data.keys()):
                if os_ == s and oe < e:
                    if best_left is None or oe > best_left[1]:
                        best_left = (os_, oe)
            if best_left:
                new_portion = provider.span_text(best_left[1], e)
                n = encoder.smart_merge(new_portion, best_left, key)
                hare_tokens += n
                smart_merges += 1
                baseline_tokens += encoder.span_data[key]['n_tokens']
                continue
            # Strategy 3: no reusable state; encode from scratch.
            encoder.remove_old(key)
            n = encoder.encode_span(span_text_val, key)
            hare_tokens += n
            full_reencodes += 1
            baseline_tokens += n
        if len(new_keys) < len(prev_span_keys) and pieces_processed > 1:
            merge_events += 1
        prev_span_keys = new_keys
        total_chunks = sum(len(d['chunk_embs'])
                           for d in encoder.span_data.values())
        eff = baseline_tokens / max(hare_tokens, 1)
        if encoder.span_data:
            # Track per-layer state norms of the largest span for the viz.
            largest_key = max(encoder.span_data.keys(),
                              key=lambda k: k[1] - k[0])
            states = encoder.span_data[largest_key].get('layer_states', [])
            norms = []
            for st in states:
                if st is not None and 'wkv_state' in st:
                    norms.append(st['wkv_state'].norm().item())
                else:
                    norms.append(0.0)
            state_history.append(norms)
        search_results = {}
        for q in queries:
            results = encoder.search(q_embs[q], new_spans, top_k=3)
            search_results[q] = [
                {'span': (s, e), 'score': sc, 'preview': pv,
                 'n_chunks': nc, 'n_tokens': nt}
                for s, e, sc, pv, nt, nc in results
            ]
            if results:
                top = results[0]
                sc_top = top[2]
                if q not in peak_scores or sc_top > peak_scores[q]['score']:
                    peak_scores[q] = {'score': sc_top, 'preview': top[3]}
        grid_html = render_grid(provider.received, provider.n_pieces,
                                highlight=idx)
        saved = baseline_tokens - hare_tokens
        eff_md = f"**Efficiency: {eff:.1f}x** | {total_chunks} chunks"
        tok_md = (f"Tokens: {hare_tokens:,} processed | "
                  f"{saved:,} saved via state carry")
        strat_md = (f"Right-ext: {right_extends} | "
                    f"Smart-merge: {smart_merges} | "
                    f"Full: {full_reencodes} | Merges: {merge_events}")
        search_html = render_search(search_results, peak_scores)
        state_html = render_state_viz(state_history, n_rwkv_layers)
        yield grid_html, eff_md, tok_md, strat_md, search_html, state_html
        if sleep_time > 0:
            time.sleep(sleep_time)
    # Final summary frame after the stream completes.
    eff = baseline_tokens / max(hare_tokens, 1)
    total_chunks = sum(len(d['chunk_embs'])
                       for d in encoder.span_data.values())
    saved = baseline_tokens - hare_tokens
    grid_html = render_grid(provider.received, provider.n_pieces)
    eff_md = f"**Efficiency: {eff:.1f}x** | {total_chunks} chunks | COMPLETE"
    tok_md = (f"Tokens: {hare_tokens:,} processed | "
              f"{saved:,} saved via state carry")
    strat_md = (f"Right-ext: {right_extends} | Smart-merge: {smart_merges} | "
                f"Full: {full_reencodes} | Merges: {merge_events}")
    final_spans = provider.get_spans()
    search_results = {}
    for q in queries:
        results = encoder.search(q_embs[q], final_spans, top_k=3)
        search_results[q] = [
            {'span': (s, e), 'score': sc, 'preview': pv,
             'n_chunks': nc, 'n_tokens': nt}
            for s, e, sc, pv, nt, nc in results
        ]
    search_html = render_search(search_results, peak_scores)
    state_html = render_state_viz(state_history, n_rwkv_layers)
    yield grid_html, eff_md, tok_md, strat_md, search_html, state_html


@spaces.GPU
def start_demo(source_mode, demo_choice, url_input, queries_text, chunk_size):
    """Gradio entry point: build provider + encoder and stream UI updates."""
    model.cuda()
    encoder = SpanEncoder(model, tokenizer, chunk_size=chunk_size)
    if source_mode == "Quick Demo":
        config = QUICK_DEMOS[demo_choice]
        provider = TextProvider(config['text'],
                                piece_size=config['piece_size'], seed=42)
        queries = config['queries']
        sleep_time = config['sleep']
    elif source_mode == "URL":
        if not url_input:
            yield ('<div style="color:#888;text-align:center;padding:30px;">'
                   'Enter a URL to a text file.'
                   '</div>', '', '', '', '', '')
            return
        text = load_text(url=url_input)
        provider = TextProvider(text, piece_size=4096, seed=42)
        queries = [q.strip() for q in queries_text.split(',') if q.strip()]
        sleep_time = 0
    else:
        return
    if not queries:
        queries = ["search query"]
    q_embs = {q: encoder.encode_query(q) for q in queries}
    yield from streaming_loop(provider, encoder, queries, q_embs, sleep_time)


def toggle_inputs(source_mode):
    """Show/hide source-specific inputs when the source radio changes."""
    frankenstein_q = ("on a dreary night the creature first opened its eyes, "
                      "an innocent woman is wrongly executed, "
                      "playing god with science")
    return (
        gr.update(visible=(source_mode == "Quick Demo")),
        gr.update(visible=(source_mode == "URL")),
        gr.update(visible=(source_mode != "Quick Demo"),
                  value=frankenstein_q),
    )


def update_queries(demo_choice):
    """Fill the query box with the selected quick demo's preset queries."""
    config = QUICK_DEMOS.get(demo_choice, {})
    queries = config.get('queries', [])
    return ', '.join(queries)


def build_demo():
    """Assemble the Gradio Blocks UI and wire up event handlers."""
    with gr.Blocks(title="HARE Streaming Demo") as demo:
        gr.Markdown(
            "# HARE: Streaming Semantic Search",
        )
        gr.Markdown(
            "Watch [HARE](https://huggingface.co/SixOpen/HARE) build a "
            "semantic search index in real-time as content streams in "
            "piece by piece. Unlike standard embedding models, HARE's "
            "recurrent state carries forward full context without "
            "re-encoding, allowing for search over live transcripts, "
            "distributed content, and streaming files without "
            "needing to download them in full.",
        )
        with gr.Row():
            with gr.Column(scale=1, min_width=280):
                source_mode = gr.Radio(
                    ["URL", "Quick Demo"],
                    value="URL",
                    label="Source",
                )
                demo_choice = gr.Dropdown(
                    list(QUICK_DEMOS.keys()),
                    value=list(QUICK_DEMOS.keys())[0],
                    label="Demo Content",
                    visible=False,
                )
                url_input = gr.Textbox(
                    label="Text URL",
                    value="https://gutenberg.org/files/84/84-0.txt",
                    placeholder="https://gutenberg.org/files/84/84-0.txt",
                    visible=True,
                )
                queries_input = gr.Textbox(
                    label="Search Queries (comma-separated)",
                    value=("on a dreary night the creature first opened "
                           "its eyes, an innocent woman is wrongly "
                           "executed, playing god with science"),
                    visible=True,
                )
                with gr.Accordion("Settings", open=False):
                    chunk_size = gr.Slider(
                        128, 1024, value=512, step=64,
                        label="Chunk Size (tokens)",
                    )
                start_btn = gr.Button("Start Demo", variant="primary",
                                      size="lg")
            with gr.Column(scale=2):
                gr.Markdown("### Download Progress")
                piece_grid = gr.HTML(
                    '<div style="color:#888;text-align:center;padding:30px;">'
                    'Click "Start Demo" to begin</div>'
                )
                gr.Markdown("### Encoding Efficiency")
                with gr.Row():
                    efficiency_md = gr.Markdown("**Efficiency: --**")
                with gr.Row():
                    tokens_md = gr.Markdown("Tokens: --")
                strategy_md = gr.Markdown(
                    "Right-ext: -- | Smart-merge: -- | Full: --")
                gr.Markdown("### Search Results")
                search_html = gr.HTML(
                    '<div style="color:#888;text-align:center;padding:30px;">'
                    'Results will appear here as '
                    'pieces are processed...'
                    '</div>'
                )
                gr.Markdown("### Recurrent State Evolution")
                state_viz = gr.HTML(
                    '<div style="color:#888;text-align:center;padding:30px;">'
                    'State heatmap will appear as '
                    'pieces are processed...'
                    '</div>'
                )
        source_mode.change(
            toggle_inputs,
            inputs=[source_mode],
            outputs=[demo_choice, url_input, queries_input],
        )
        demo_choice.change(
            update_queries,
            inputs=[demo_choice],
            outputs=[queries_input],
        )
        start_btn.click(
            start_demo,
            inputs=[source_mode, demo_choice, url_input, queries_input,
                    chunk_size],
            outputs=[piece_grid, efficiency_md, tokens_md, strategy_md,
                     search_html, state_viz],
        )
    return demo


demo = build_demo()
demo.queue().launch()