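"""Gradio front-end for MathComp/Rocq retrieval.

Embeds a natural-language query with Qwen3-Embedding and posts the vector to a
cosine-similarity retrieval service (see call_retrieval_service), which returns
matching Rocq elements rendered as HTML result cards.
"""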
import functools
import json
import os
import textwrap
from typing import List, Dict, Any
import gradio as gr
import requests
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
# -----------------------------
# Embedding utilities (standard Qwen3-Embedding usage, with a couple of safety tweaks)
# -----------------------------
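# Last-token pooling: with left padding every sequence ends at index -1; with right
# padding, pick each row's hidden state at its last non-padded position.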
def last_token_pool(last_hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
if left_padding:
return last_hidden_states[:, -1]
else:
sequence_lengths = attention_mask.sum(dim=1) - 1
batch_size = last_hidden_states.shape[0]
return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
def get_detailed_instruct(task_description: str, query: str) -> str:
return f"Instruct: {task_description}\nQuery: {query}"
class Qwen3Embedding:
def __init__(self, device: str, size: str = "0.6B"):
assert size in ["0.6B", "4B", "8B"]
model_id = "Qwen/Qwen3-Embedding-" + size
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# Use bfloat16 on GPU, float32 on CPU (safer on Spaces CPU)
dtype = torch.bfloat16 if device != "cpu" else torch.float32
self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True).to(device, dtype=dtype)
self.prompt_query = (
"Given a natural language query, retrieve formal Rocq elements whose docstrings "
"match the intent of the query."
)
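    # Embeddings are L2-normalized, so downstream cosine similarity reduces to a dot product.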
@torch.inference_mode()
def generate(self, sentence: str, is_query: bool = False) -> torch.Tensor:
input_text = get_detailed_instruct(self.prompt_query, sentence) if is_query else sentence
batch_dict = self.tokenizer(input_text, padding=True, truncation=True, return_tensors="pt").to(self.device)
outputs = self.model(**batch_dict)
embeddings = last_token_pool(outputs.last_hidden_state, batch_dict["attention_mask"])
embeddings = F.normalize(embeddings, p=2, dim=1)
return embeddings
def name(self) -> str:
return "qwen_embedding_base"
@functools.lru_cache(maxsize=3)
def get_embedder() -> Qwen3Embedding:
return Qwen3Embedding(device="cpu", size="4B")
# -----------------------------
# Backend call
# -----------------------------
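# The backend exposes POST {server_url}/query and accepts {"query": [[...floats...]], "top_k": int}.
# Each returned entry is expected to look roughly like (field names inferred from render_results):
#   {"score": 0.87, "name": "addrC", "kind": "lemma", "docstring": "...", "location": "..."}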
def call_retrieval_service(
server_url: str, embedding: List[float], top_k: int, timeout: int = 60
) -> List[Dict[str, Any]]:
if server_url.endswith("/"):
server_url = server_url[:-1]
url = f"{server_url}/query"
payload = {"query": [embedding], "top_k": int(top_k)}
resp = requests.post(url, json=payload, timeout=timeout)
resp.raise_for_status()
data = resp.json()
if not isinstance(data, list):
raise ValueError("Unexpected response format: expected a list of entries.")
return data
# -----------------------------
# Formatting helpers
# -----------------------------
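# Minimal HTML escaping for backend-provided strings rendered into the results panel.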
def _html_escape(s: str) -> str:
return (
        s.replace("&", "&amp;")  # escape "&" first so the entities added below stay intact
        .replace("<", "&lt;")
        .replace(">", "&gt;")
)
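# Render one collapsible HTML card per retrieved entry (docstring hidden behind <details>).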
def render_results(items: List[Dict[str, Any]]) -> str:
if not items:
return "<div>No results.</div>"
rows = []
for idx, it in enumerate(items, start=1):
score = it.get("score", 0.0)
name = it.get("name", "")
kind = it.get("kind", "")
doc = it.get("docstring", "") or ""
location = it.get("location", "") or ""
# Trim long docstrings for the summary line
summary = " ".join(doc.strip().split())
if len(summary) > 240:
summary = summary[:240].rstrip() + "…"
block = f"""
<div class="result-card">
<div class="header">
<span class="rank">#{idx}</span>
<code class="name">{_html_escape(name)}</code>
<span class="meta">[{_html_escape(kind)}] · score={score:.4f}</span>
</div>
<div class="location">in {_html_escape(location)}</div>
<details class="doc">
<summary>{_html_escape(summary or "(no docstring)")}</summary>
<pre>{_html_escape(doc)}</pre>
</details>
</div>
"""
rows.append(block)
style = """
<style>
.result-card {border: 1px solid rgba(0,0,0,.08); padding: 12px 14px; border-radius: 12px; margin-bottom: 12px;}
.header {display:flex; gap:10px; align-items:center; flex-wrap:wrap;}
.rank {font-weight: 700;}
.name {font-size: 0.95rem; background: rgba(0,0,0,.03); padding: 2px 6px; border-radius: 6px;}
.meta {opacity: 0.7;}
.location {font-size: 0.9rem; opacity: 0.8; margin: 4px 0 8px;}
details.doc summary {cursor: pointer; font-weight: 500; margin-bottom: 6px;}
details.doc pre {white-space: pre-wrap; background: rgba(0,0,0,.02); padding: 10px; border-radius: 8px;}
</style>
"""
return style + "\n".join(rows)
# -----------------------------
# Gradio app
# -----------------------------
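# The retrieval backend URL can be overridden via the COSIM_SERVER_URL environment variable.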
DEFAULT_SERVER = os.environ.get("COSIM_SERVER_URL", "https://theostos-llm4docq-cosim.hf.space")
def search(
query: str,
top_k: int,
server_url: str,
show_raw: bool,
) -> List[Any]:
query = (query or "").strip()
if not query:
return [gr.update(value="<div>Please enter a query.</div>"), None]
try:
embedder = get_embedder()
with torch.inference_mode():
emb = embedder.generate(query, is_query=True)
# Convert to plain list[float]
emb_list = emb[0].detach().to(torch.float32).cpu().tolist()
items = call_retrieval_service(server_url, emb_list, top_k)
html = render_results(items)
if show_raw:
return [html, items]
else:
return [html, None]
except requests.exceptions.RequestException as e:
msg = f"<div style='color:#b00020'>Request error: {_html_escape(str(e))}</div>"
return [msg, None]
except RuntimeError as e:
msg = f"<div style='color:#b00020'>Runtime error: {_html_escape(str(e))}</div>"
return [msg, None]
except Exception as e:
msg = f"<div style='color:#b00020'>Unexpected error: {_html_escape(str(e))}</div>"
return [msg, None]
with gr.Blocks(title="MathComp Retrieval (Qwen3 Embedding 4B)", theme=gr.themes.Soft()) as demo:
gr.Markdown("# 🔎 MathComp Retrieval")
status_md = gr.Markdown("⏳ Loading model… (first time may take a bit)")
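    # Warm up: load the embedding model once when the page first loads so the first search is fast.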
def warmup():
try:
            _ = get_embedder()  # load and cache the embedding model
return "✅ Model ready."
except Exception as e:
return f"⚠️ Warmup failed: {e}"
demo.load(fn=warmup, inputs=None, outputs=status_md)
with gr.Row():
with gr.Column(scale=3):
query = gr.Textbox(
label="Query",
placeholder="e.g., commutative group morphisms",
lines=3,
autofocus=True,
)
with gr.Row():
top_k = gr.Slider(1, 50, value=5, step=1, label="top_k")
with gr.Accordion("Advanced", open=False):
server_url = gr.Textbox(value=DEFAULT_SERVER, label="Retrieval server URL")
show_raw = gr.Checkbox(value=False, label="Also show raw JSON response")
with gr.Row():
run_btn = gr.Button("Search", variant="primary")
clear_btn = gr.Button("Clear")
with gr.Column(scale=4):
pretty = gr.HTML(label="Results")
raw_json = gr.JSON(label="Raw JSON", visible=False)
def on_toggle_raw(show: bool):
return gr.update(visible=show)
show_raw.change(fn=on_toggle_raw, inputs=show_raw, outputs=raw_json)
run_btn.click(
fn=search,
inputs=[query, top_k, server_url, show_raw],
outputs=[pretty, raw_json],
api_name="search",
)
    clear_btn.click(
        # Return one value per output component: query, top_k, server_url, show_raw, pretty, raw_json.
        lambda: ("", 5, DEFAULT_SERVER, False, "<div/>", None),
        inputs=None,
        outputs=[query, top_k, server_url, show_raw, pretty, raw_json],
    )
gr.Examples(
examples=[
["polynomial division lemma for ringType"],
["matrix rank properties over finite fields"],
["group homomorphism kernel characterization"],
["bigop lemmas about summation reindexing"],
],
inputs=[query],
label="Try these",
)
if __name__ == "__main__":
demo.launch()
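# Example programmatic call via gradio_client (a sketch; "<user>/<space>" is a placeholder
# for whatever Space id ends up hosting this app):
#   from gradio_client import Client
#   client = Client("<user>/<space>")
#   html, raw = client.predict(
#       "group homomorphism kernel characterization",
#       5,
#       "https://theostos-llm4docq-cosim.hf.space",
#       True,
#       api_name="/search",
#   )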