# Hugging Face Spaces page-header residue from extraction ("Spaces: Sleeping").
| import functools | |
| import json | |
| import os | |
| import textwrap | |
| from typing import List, Dict, Any | |
| import gradio as gr | |
| import requests | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import AutoTokenizer, AutoModel | |
| # ----------------------------- | |
| # Embedding utilities (from your snippet, with a couple of safety tweaks) | |
| # ----------------------------- | |
def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pool each sequence down to the hidden state of its last real token.

    When every row has a non-pad token in the final position (left-padded
    batch), the last column is valid for all rows and is returned directly.
    Otherwise the last valid index is located per row from the mask.
    """
    batch = last_hidden_states.shape[0]
    # All rows attend at the last position -> tokenizer left-padded.
    if attention_mask[:, -1].sum() == batch:
        return last_hidden_states[:, -1]
    last_positions = attention_mask.sum(dim=1) - 1
    rows = torch.arange(batch, device=last_hidden_states.device)
    return last_hidden_states[rows, last_positions]
def get_detailed_instruct(task_description: str, query: str) -> str:
    """Wrap a raw query in the instruction template the embedder expects."""
    pieces = ["Instruct: ", task_description, "\nQuery: ", query]
    return "".join(pieces)
class Qwen3Embedding:
    """Thin wrapper around a Qwen3-Embedding checkpoint.

    Loads tokenizer and model once in the constructor and exposes
    `generate` to embed a single sentence, optionally formatted as a
    retrieval query with the fixed instruction prompt.
    """

    def __init__(self, device: str, size: str = "0.6B"):
        assert size in ["0.6B", "4B", "8B"]
        self.device = device
        model_id = "Qwen/Qwen3-Embedding-" + size
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        # Use bfloat16 on GPU, float32 on CPU (safer on Spaces CPU).
        wanted_dtype = torch.float32 if device == "cpu" else torch.bfloat16
        self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device, dtype=wanted_dtype)
        self.prompt_query = (
            "Given a natural language query, retrieve formal Rocq elements whose docstrings "
            "match the intent of the query."
        )

    def generate(self, sentence: str, is_query: bool = False) -> torch.Tensor:
        """Return an L2-normalized embedding (1 x dim tensor) for `sentence`."""
        if is_query:
            text = get_detailed_instruct(self.prompt_query, sentence)
        else:
            text = sentence
        encoded = self.tokenizer(text, padding=True, truncation=True, return_tensors="pt").to(self.device)
        hidden = self.model(**encoded)
        pooled = last_token_pool(hidden.last_hidden_state, encoded["attention_mask"])
        return F.normalize(pooled, p=2, dim=1)

    def name(self) -> str:
        """Stable identifier for this embedder configuration."""
        return "qwen_embedding_base"
@functools.lru_cache(maxsize=1)
def get_embedder() -> Qwen3Embedding:
    """Build (once) and return the shared Qwen3 embedder instance.

    BUG FIX: previously every call constructed a brand-new Qwen3Embedding,
    so each search() reloaded the 4B checkpoint from scratch and the
    warmup() hook accomplished nothing. Caching makes warmup and all
    subsequent searches share one instance (`functools` was already
    imported but unused — this is evidently the intended use).
    """
    return Qwen3Embedding(device="cpu", size="4B")
| # ----------------------------- | |
| # Backend call | |
| # ----------------------------- | |
def call_retrieval_service(
    server_url: str, embedding: List[float], top_k: int, timeout: int = 60
) -> List[Dict[str, Any]]:
    """POST the query embedding to the retrieval backend and return its hits.

    Raises requests.HTTPError for non-2xx responses and ValueError when
    the response body is not the expected JSON list of entries.
    """
    # Tolerate a single trailing slash in the configured base URL.
    base = server_url[:-1] if server_url.endswith("/") else server_url
    resp = requests.post(
        f"{base}/query",
        json={"query": [embedding], "top_k": int(top_k)},
        timeout=timeout,
    )
    resp.raise_for_status()
    data = resp.json()
    if not isinstance(data, list):
        raise ValueError("Unexpected response format: expected a list of entries.")
    return data
| # ----------------------------- | |
| # Formatting helpers | |
| # ----------------------------- | |
| def _html_escape(s: str) -> str: | |
| return ( | |
| s.replace("&", "&") | |
| .replace("<", "<") | |
| .replace(">", ">") | |
| ) | |
def render_results(items: List[Dict[str, Any]]) -> str:
    """Render backend hits as self-contained HTML (result cards + inline CSS).

    Each entry is assumed to expose "score", "name", "kind", "docstring"
    and "location" keys — TODO confirm against the backend schema; missing
    keys fall back to 0.0 / empty strings. Field values are passed through
    _html_escape before interpolation into the markup.
    """
    if not items:
        return "<div>No results.</div>"
    rows = []
    for idx, it in enumerate(items, start=1):
        # Defensive .get() on every field: the schema is assumed, not validated.
        score = it.get("score", 0.0)
        name = it.get("name", "")
        kind = it.get("kind", "")
        doc = it.get("docstring", "") or ""
        location = it.get("location", "") or ""
        # Trim long docstrings for the summary line
        summary = " ".join(doc.strip().split())
        if len(summary) > 240:
            summary = summary[:240].rstrip() + "…"
        # Card markup: collapsed <details> shows the summary; expanding it
        # reveals the full docstring in a <pre>.
        block = f"""
<div class="result-card">
<div class="header">
<span class="rank">#{idx}</span>
<code class="name">{_html_escape(name)}</code>
<span class="meta">[{_html_escape(kind)}] · score={score:.4f}</span>
</div>
<div class="location">in {_html_escape(location)}</div>
<details class="doc">
<summary>{_html_escape(summary or "(no docstring)")}</summary>
<pre>{_html_escape(doc)}</pre>
</details>
</div>
"""
        rows.append(block)
    # Inline <style> so the cards render correctly inside a gr.HTML component.
    style = """
<style>
.result-card {border: 1px solid rgba(0,0,0,.08); padding: 12px 14px; border-radius: 12px; margin-bottom: 12px;}
.header {display:flex; gap:10px; align-items:center; flex-wrap:wrap;}
.rank {font-weight: 700;}
.name {font-size: 0.95rem; background: rgba(0,0,0,.03); padding: 2px 6px; border-radius: 6px;}
.meta {opacity: 0.7;}
.location {font-size: 0.9rem; opacity: 0.8; margin: 4px 0 8px;}
details.doc summary {cursor: pointer; font-weight: 500; margin-bottom: 6px;}
details.doc pre {white-space: pre-wrap; background: rgba(0,0,0,.02); padding: 10px; border-radius: 8px;}
</style>
"""
    return style + "\n".join(rows)
| # ----------------------------- | |
| # Gradio app | |
| # ----------------------------- | |
# Retrieval backend base URL; override with the COSIM_SERVER_URL env var.
DEFAULT_SERVER = os.environ.get("COSIM_SERVER_URL", "https://theostos-llm4docq-cosim.hf.space")
def search(
    query: str,
    top_k: int,
    server_url: str,
    show_raw: bool,
) -> List[Any]:
    """Embed the query, call the retrieval backend, and format the hits.

    Returns [html, raw] where raw is the backend JSON list when show_raw
    is set, else None. Failures never propagate to Gradio: each error
    class is rendered as inline red HTML instead.
    """
    query = (query or "").strip()
    if not query:
        return [gr.update(value="<div>Please enter a query.</div>"), None]
    try:
        embedder = get_embedder()
        with torch.inference_mode():
            emb = embedder.generate(query, is_query=True)
        # Plain list[float] so the payload is JSON-serializable.
        vector = emb[0].detach().to(torch.float32).cpu().tolist()
        items = call_retrieval_service(server_url, vector, top_k)
        html = render_results(items)
        return [html, items if show_raw else None]
    except requests.exceptions.RequestException as e:
        return [f"<div style='color:#b00020'>Request error: {_html_escape(str(e))}</div>", None]
    except RuntimeError as e:
        return [f"<div style='color:#b00020'>Runtime error: {_html_escape(str(e))}</div>", None]
    except Exception as e:
        return [f"<div style='color:#b00020'>Unexpected error: {_html_escape(str(e))}</div>", None]
# ----- Gradio UI: query box + controls on the left, results on the right -----
with gr.Blocks(title="MathComp Retrieval (Qwen3 Embedding 4B)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔎 MathComp Retrieval")
    status_md = gr.Markdown("⏳ Loading model… (first time may take a bit)")

    def warmup():
        """Load the embedder at page load so the first search is responsive."""
        try:
            _ = get_embedder()  # safe default
            return "✅ Model ready."
        except Exception as e:
            return f"⚠️ Warmup failed: {e}"

    demo.load(fn=warmup, inputs=None, outputs=status_md)
    with gr.Row():
        with gr.Column(scale=3):
            query = gr.Textbox(
                label="Query",
                placeholder="e.g., commutative group morphisms",
                lines=3,
                autofocus=True,
            )
            with gr.Row():
                top_k = gr.Slider(1, 50, value=5, step=1, label="top_k")
            with gr.Accordion("Advanced", open=False):
                server_url = gr.Textbox(value=DEFAULT_SERVER, label="Retrieval server URL")
                show_raw = gr.Checkbox(value=False, label="Also show raw JSON response")
            with gr.Row():
                run_btn = gr.Button("Search", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=4):
            pretty = gr.HTML(label="Results")
            raw_json = gr.JSON(label="Raw JSON", visible=False)

    def on_toggle_raw(show: bool):
        """Mirror the checkbox state onto the raw-JSON panel's visibility."""
        return gr.update(visible=show)

    show_raw.change(fn=on_toggle_raw, inputs=show_raw, outputs=raw_json)
    run_btn.click(
        fn=search,
        inputs=[query, top_k, server_url, show_raw],
        outputs=[pretty, raw_json],
        api_name="search",
    )
    # BUG FIX: the reset lambda previously returned 8 values ("0.6B" and
    # True were leftovers from removed size/device controls) for only 6
    # output components, so Gradio raised an error on every "Clear" click.
    # Return exactly one value per output, restoring each default.
    clear_btn.click(
        lambda: ("", 5, DEFAULT_SERVER, False, "<div/>", None),
        inputs=None,
        outputs=[query, top_k, server_url, show_raw, pretty, raw_json],
    )
    gr.Examples(
        examples=[
            ["polynomial division lemma for ringType"],
            ["matrix rank properties over finite fields"],
            ["group homomorphism kernel characterization"],
            ["bigop lemmas about summation reindexing"],
        ],
        inputs=[query],
        label="Try these",
    )

if __name__ == "__main__":
    demo.launch()