"""Gradio front-end for semantic retrieval over MathComp/Rocq elements.

A query is embedded locally with a Qwen3 embedding model, then sent to a
remote cosine-similarity retrieval service, and the hits are rendered as
HTML cards.

NOTE(review): this file was recovered from a whitespace/HTML-mangled copy.
Every HTML fragment (result-card markup, error banners, the inline
stylesheet) had been stripped by a sanitizer and is reconstructed here as
plausible markup — confirm against the originally deployed version.
"""

import functools
import json
import os
import textwrap
from typing import Any, Dict, List

import gradio as gr
import requests
import torch
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer


# -----------------------------
# Embedding utilities
# -----------------------------


def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Pool per-token hidden states down to one vector per sequence.

    Uses the hidden state of the last *real* (non-padding) token, the
    pooling scheme recommended for Qwen3 embedding models.

    Args:
        last_hidden_states: (batch, seq_len, hidden) model output.
        attention_mask: (batch, seq_len) mask of real tokens (1) vs padding (0).

    Returns:
        (batch, hidden) pooled embeddings.
    """
    # Left padding <=> every row has a real token in the final column.
    left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0]
    if left_padding:
        return last_hidden_states[:, -1]
    # Right padding: index each row at its last real token.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[
        torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths
    ]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Format a query with its task instruction, as Qwen3 embedding expects."""
    return f"Instruct: {task_description}\nQuery: {query}"


class Qwen3Embedding:
    """Thin wrapper around a Qwen/Qwen3-Embedding-* checkpoint."""

    def __init__(self, device: str, size: str = "0.6B"):
        assert size in ["0.6B", "4B", "8B"]
        model_id = "Qwen/Qwen3-Embedding-" + size
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        # Use bfloat16 on GPU, float32 on CPU (safer on Spaces CPU).
        dtype = torch.bfloat16 if device != "cpu" else torch.float32
        self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device, dtype=dtype)
        # Task instruction prepended to queries (not to documents).
        self.prompt_query = (
            "Given a natural language query, retrieve formal Rocq elements whose docstrings "
            "match the intent of the query."
        )

    @torch.inference_mode()
    def generate(self, sentence: str, is_query: bool = False) -> torch.Tensor:
        """Embed one sentence; returns an L2-normalized (1, hidden) tensor.

        Queries get the instruction prefix; documents are embedded as-is.
        """
        input_text = get_detailed_instruct(self.prompt_query, sentence) if is_query else sentence
        batch_dict = self.tokenizer(
            input_text, padding=True, truncation=True, return_tensors="pt"
        ).to(self.device)
        outputs = self.model(**batch_dict)
        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
        embeddings = F.normalize(embeddings, p=2, dim=1)
        return embeddings

    def name(self) -> str:
        return "qwen_embedding_base"


@functools.lru_cache(maxsize=3)
def get_embedder() -> Qwen3Embedding:
    """Lazily construct (and cache) the embedding model used by the app."""
    return Qwen3Embedding(device="cpu", size="4B")


# -----------------------------
# Backend call
# -----------------------------


def call_retrieval_service(
    server_url: str,
    embedding: List[float],
    top_k: int,
    timeout: int = 60,
) -> List[Dict[str, Any]]:
    """POST a single query embedding to the retrieval server's /query endpoint.

    Args:
        server_url: base URL of the retrieval service (trailing slash tolerated).
        embedding: the query embedding as a plain list of floats.
        top_k: number of hits to request.
        timeout: request timeout in seconds.

    Returns:
        The server's list of result entries (dicts).

    Raises:
        requests.exceptions.RequestException: on network/HTTP failure.
        ValueError: if the response body is not a JSON list.
    """
    if server_url.endswith("/"):
        server_url = server_url[:-1]
    url = f"{server_url}/query"
    payload = {"query": [embedding], "top_k": int(top_k)}
    resp = requests.post(url, json=payload, timeout=timeout)
    resp.raise_for_status()
    data = resp.json()
    if not isinstance(data, list):
        raise ValueError("Unexpected response format: expected a list of entries.")
    return data


# -----------------------------
# Formatting helpers
# -----------------------------


def _html_escape(s: str) -> str:
    """Minimal HTML escaping for untrusted text interpolated into markup.

    BUGFIX: the previous version replaced each character with itself
    (the entity names had been stripped), so nothing was escaped.
    '&' must be handled first so it does not double-escape the others.
    """
    return (
        s.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )


def render_results(items: List[Dict[str, Any]]) -> str:
    """Render retrieval hits as an HTML fragment (stylesheet + result cards).

    NOTE(review): the original markup was lost to sanitization; the card
    structure and CSS below are a faithful-in-spirit reconstruction.
    """
    if not items:
        return "<div class='empty'>No results.</div>"
    rows = []
    for idx, it in enumerate(items, start=1):
        score = it.get("score", 0.0)
        name = it.get("name", "")
        kind = it.get("kind", "")
        doc = it.get("docstring", "") or ""
        location = it.get("location", "") or ""
        # Trim long docstrings for the summary line.
        summary = " ".join(doc.strip().split())
        if len(summary) > 240:
            summary = summary[:240].rstrip() + "…"
        block = f"""
<div class="result">
  <div class="result-head">#{idx} <b>{_html_escape(name)}</b>
    <span class="kind">[{_html_escape(kind)}]</span> · score={score:.4f}</div>
  <div class="result-loc">in {_html_escape(location)}</div>
  <div class="result-summary">{_html_escape(summary or "(no docstring)")}</div>
  <details><summary>Full docstring</summary>
    <pre class="result-doc">{_html_escape(doc)}</pre>
  </details>
</div>
"""
        rows.append(block)
    style = """
<style>
.result { border: 1px solid #ddd; border-radius: 8px; padding: 10px; margin-bottom: 8px; }
.result-head { font-size: 1.05em; }
.result-loc { color: #666; font-size: 0.9em; margin: 2px 0; }
.result-summary { margin: 4px 0; }
.result-doc { white-space: pre-wrap; }
.kind { color: #888; }
.empty, .err { color: #666; padding: 8px; }
</style>
"""
    return style + "\n".join(rows)


# -----------------------------
# Gradio app
# -----------------------------

DEFAULT_SERVER = os.environ.get("COSIM_SERVER_URL", "https://theostos-llm4docq-cosim.hf.space")


def search(
    query: str,
    top_k: int,
    server_url: str,
    show_raw: bool,
) -> List[Any]:
    """Embed the query, call the retrieval service, and render the results.

    Returns a two-element list: [HTML for the results pane, raw JSON or None].
    All failures are caught and surfaced as an HTML error banner so the UI
    never raises.
    """
    query = (query or "").strip()
    if not query:
        return [gr.update(value="<div class='err'>Please enter a query.</div>"), None]
    try:
        embedder = get_embedder()
        with torch.inference_mode():
            emb = embedder.generate(query, is_query=True)
        # Convert to plain list[float] for JSON serialization.
        emb_list = emb[0].detach().to(torch.float32).cpu().tolist()
        items = call_retrieval_service(server_url, emb_list, top_k)
        html = render_results(items)
        if show_raw:
            return [html, items]
        return [html, None]
    except requests.exceptions.RequestException as e:
        msg = f"<div class='err'>Request error: {_html_escape(str(e))}</div>"
        return [msg, None]
    except RuntimeError as e:
        msg = f"<div class='err'>Runtime error: {_html_escape(str(e))}</div>"
        return [msg, None]
    except Exception as e:
        msg = f"<div class='err'>Unexpected error: {_html_escape(str(e))}</div>"
        return [msg, None]


with gr.Blocks(title="MathComp Retrieval (Qwen3 Embedding 4B)", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🔎 MathComp Retrieval")
    status_md = gr.Markdown("⏳ Loading model… (first time may take a bit)")

    def warmup():
        """Instantiate the embedder at page load so the first search is fast."""
        try:
            _ = get_embedder()  # safe default
            return "✅ Model ready."
        except Exception as e:
            return f"⚠️ Warmup failed: {e}"

    demo.load(fn=warmup, inputs=None, outputs=status_md)

    with gr.Row():
        with gr.Column(scale=3):
            query = gr.Textbox(
                label="Query",
                placeholder="e.g., commutative group morphisms",
                lines=3,
                autofocus=True,
            )
            with gr.Row():
                top_k = gr.Slider(1, 50, value=5, step=1, label="top_k")
            with gr.Accordion("Advanced", open=False):
                server_url = gr.Textbox(value=DEFAULT_SERVER, label="Retrieval server URL")
                show_raw = gr.Checkbox(value=False, label="Also show raw JSON response")
            with gr.Row():
                run_btn = gr.Button("Search", variant="primary")
                clear_btn = gr.Button("Clear")
        with gr.Column(scale=4):
            pretty = gr.HTML(label="Results")
            raw_json = gr.JSON(label="Raw JSON", visible=False)

    def on_toggle_raw(show: bool):
        return gr.update(visible=show)

    show_raw.change(fn=on_toggle_raw, inputs=show_raw, outputs=raw_json)

    run_btn.click(
        fn=search,
        inputs=[query, top_k, server_url, show_raw],
        outputs=[pretty, raw_json],
        api_name="search",
    )

    # BUGFIX: the lambda previously returned 8 values (including leftover
    # "0.6B"/True from removed model-size controls) for 6 declared outputs,
    # which makes Gradio error at click time. Now one value per output.
    clear_btn.click(
        lambda: ("", 5, DEFAULT_SERVER, False, "", None),
        inputs=None,
        outputs=[query, top_k, server_url, show_raw, pretty, raw_json],
    )

    gr.Examples(
        examples=[
            ["polynomial division lemma for ringType"],
            ["matrix rank properties over finite fields"],
            ["group homomorphism kernel characterization"],
            ["bigop lemmas about summation reindexing"],
        ],
        inputs=[query],
        label="Try these",
    )


if __name__ == "__main__":
    demo.launch()