# Hugging Face Space app.py by RikkaBotan
# Last commit: "Update app.py" (0fdbb47, verified)
import gradio as gr
import torch
from sentence_transformers import SentenceTransformer
from ddgs import DDGS
import time
import numpy as np
# Load Model
# Run on the GPU when torch reports CUDA support, otherwise stay on the CPU.
if torch.cuda.is_available():
    _device = "cuda"
else:
    _device = "cpu"

# NOTE(review): trust_remote_code=True presumably lets the model repository
# execute its own Python code while loading - confirm the repository is trusted.
model = SentenceTransformer(
    "RikkaBotan/stable-static-embedding-fast-retrieval-mrl-ja",
    device=_device,
    trust_remote_code=True,
)
# Web Search with error handling
def web_search(query, max_results=100):
    """Fetch up to *max_results* DuckDuckGo text hits for *query*.

    Each hit is reduced to a dict with the keys ``"title"``, ``"body"`` and
    ``"href"`` (missing keys become empty strings). A failure while converting
    one hit, or while iterating the result stream, is printed and skipped, so
    the function returns whatever was collected up to that point.
    """
    hits = []
    with DDGS() as client:
        try:
            position = 0
            for raw in client.text(query, max_results=max_results):
                position += 1  # 1-based index, only used in the skip message
                try:
                    entry = {
                        "title": raw.get("title", ""),
                        "body": raw.get("body", ""),
                        "href": raw.get("href", ""),
                    }
                    hits.append(entry)
                except Exception as e:
                    print(f"Skip doc {position}: {e}")
        except Exception as e:
            print(f"Skip web batch (max={max_results}): {e}")
    return hits
# Standard Semantic Search
def semantic_web_search(query):
    """Search the web for *query* and return the hits re-ranked as Markdown.

    Up to 100 results are fetched with ``web_search``; at most 256 of them are
    embedded together with the query. Embeddings are normalized, so the dot
    product used for ranking is the cosine similarity. The 30 best-scoring
    documents are rendered, best first.

    Returns a Markdown string: the ranked results, a prompt when the query is
    blank, or a notice when the search produced no documents.
    """
    if not query.strip():
        return "Please enter a search query."

    # Cap the documents themselves (not only their texts) so that scores and
    # documents stay aligned one-to-one.
    docs = web_search(query, max_results=100)[:256]
    if not docs:
        # Without this the UI would just show an empty result card.
        return "No results found. Please try a different query."

    texts = [d["title"] + " " + d["body"] for d in docs]
    with torch.no_grad():
        embeddings = model.encode(
            [query] + texts,
            convert_to_tensor=True,
            normalize_embeddings=True,
        )
    query_emb = embeddings[0]
    doc_embs = embeddings[1:]
    scores = (query_emb @ doc_embs.T).cpu().numpy().flatten().tolist()

    ranked = sorted(zip(scores, docs), key=lambda pair: pair[0], reverse=True)[:30]

    sections = []
    for rank, (score, doc) in enumerate(ranked, start=1):
        # Blank lines keep each element its own Markdown block; without them a
        # "---" directly under the body text is read as a setext heading
        # underline instead of a horizontal rule.
        sections.append(
            f"\n#### 💎 Rank {rank}\n\n"
            f"[{doc['title']}]({doc['href']})\n\n"
            f"**Score:** `{score:.4f}`\n\n"
            f"{doc['body']}\n\n"
            "---\n"
        )
    return "".join(sections)
def _render_top_hits(heading, scores, docs, limit=5):
    """Return *heading* followed by one Markdown section per top-scoring document.

    ``scores`` and ``docs`` are parallel sequences; the ``limit`` best pairs
    are rendered in descending score order.
    """
    top = sorted(zip(scores, docs), key=lambda pair: pair[0], reverse=True)[:limit]
    sections = [heading]
    for rank, (score, doc) in enumerate(top, start=1):
        # Blank lines keep each element its own Markdown block; without them a
        # "---" directly under the body text is read as a setext heading
        # underline instead of a horizontal rule.
        sections.append(
            f"\n#### Rank {rank}\n\n"
            f"[{doc['title']}]({doc['href']})\n\n"
            f"**Score:** `{score:.4f}`\n\n"
            f"{doc['body']}\n\n"
            "---\n"
        )
    return "".join(sections)


def progressive_search(query, threshold=0.7, step=50, max_cap=999):
    """Widen a web search step by step until a document scores >= *threshold*.

    This is a generator that yields Markdown status/result strings. Each round
    requests ``current_k`` results (starting at *step*, growing by *step*, never
    above *max_cap*), embeds the documents whose URL has not been seen before,
    and scores them against the query by dot product of normalized embeddings.

    Yields:
        A prompt for a blank query; per-round progress (or a skip/empty
        notice); and finally the five best documents, headed either
        "Threshold reached!" or "Threshold not reached in max search range.".

    Raises:
        ValueError: if *step* is not positive (the loop could never advance).
    """
    if not query.strip():
        yield "Please enter a search query."
        return
    if step <= 0:
        raise ValueError("step must be a positive number")

    encode_cap = 256  # maximum number of new documents embedded per round
    all_scores = []
    all_docs = []
    seen_urls = set()
    total_examined = 0

    current_k = step
    while current_k <= max_cap:
        batch_size = current_k
        current_k += step  # advance up front so every `continue` moves on

        try:
            docs = web_search(query, max_results=batch_size)
        except Exception as e:
            yield f"Skipped batch {batch_size} due to error: {e}"
            continue

        if not docs:
            yield f"No documents fetched for {batch_size} results"
            continue

        total_examined += len(docs)

        # Keep only unseen URLs, and stop at the encode cap. Documents beyond
        # the cap are NOT marked as seen, so a later round can still score
        # them. Capping the documents (rather than just the texts) keeps
        # all_scores and all_docs aligned one-to-one.
        new_docs = []
        for doc in docs:
            if len(new_docs) == encode_cap:
                break
            url = doc["href"]
            if url in seen_urls:
                continue
            seen_urls.add(url)
            new_docs.append(doc)

        if not new_docs:
            continue

        texts = [d["title"] + " " + d["body"] for d in new_docs]
        with torch.no_grad():
            embeddings = model.encode(
                [query] + texts,
                convert_to_tensor=True,
                normalize_embeddings=True,
            )
        query_emb = embeddings[0]
        doc_embs = embeddings[1:]
        scores = (query_emb @ doc_embs.T).cpu().numpy().flatten()

        all_scores.extend(scores.tolist())
        all_docs.extend(new_docs)
        best_score = float(np.max(all_scores))

        yield (
            f"### Searching…\n"
            f"- Documents examined (with duplicates): `{total_examined}`\n"
            f"- Best score so far: `{best_score:.4f}`\n"
        )

        if best_score >= threshold:
            yield _render_top_hits("### Threshold reached!\n", all_scores, all_docs)
            return

        # Pause between rounds, but not after the final one.
        if current_k <= max_cap:
            time.sleep(1)

    yield _render_top_hits(
        "### Threshold not reached in max search range.\n", all_scores, all_docs
    )
# UI
# Custom stylesheet handed to gr.Blocks(css=...) below.
# - Light theme by default; the `prefers-color-scheme: dark` media query
#   overrides page, text and card colors for dark mode.
# - `.model-card` / `.result-card` are the classes attached through
#   `elem_classes` on the model description column and the result columns.
# - NOTE(review): the `.gr-markdown, .prose` rule appears twice inside the
#   dark-mode block and `.model-card, .result-card` repeats the background
#   already set on each card; the duplicates look redundant - confirm before
#   removing, since this text is used verbatim at runtime.
pastel_css = """
body {
background: linear-gradient(180deg, #f5f9ff 0%, #eaf3ff 40%, #dbeafe 100%);
}
/* gradient headings */
h1, h2, h3, h4 {
background: linear-gradient(135deg, #0b1f5e 0%, #1e3a8a 15%, #3b82f6 30%, #93c5fd 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
font-weight: 800;
letter-spacing: 0.4px;
padding: 4px;
}
/* optional: slightly softer subtitle tone */
h2, h3 {
opacity: 0.9;
}
.gradio-container {
font-family: 'Helvetica Neue', sans-serif;
color: #1e3a8a;
}
/* model card */
.model-card {
background: #ffffff;
border-radius: 18px;
padding: 22px;
border: 1px solid #dbeafe;
box-shadow: 0 12px 20px rgba(60,120,255,0.18);
margin-bottom: 20px;
}
/* result card */
.result-card {
background: #ffffff;
border-radius: 18px;
padding: 22px;
border: 1px solid #dbeafe;
box-shadow: 0 12px 20px rgba(60,120,255,0.18);
}
.gr-markdown, .prose {
border: none !important;
box-shadow: none !important;
padding: 0 !important;
color: #1e3a8a !important;
}
.model-card, .result-card {
background: #ffffff;
color: #1e3a8a;
}
@media (prefers-color-scheme: dark) {
body {
background: linear-gradient(180deg, #0f172a 0%, #1e293b 40%, #334155 100%);
}
.gradio-container {
color: #dbeafe;
}
.gr-markdown, .prose {
color: #dbeafe !important;
}
.model-card, .result-card {
background: #1a1a1a;
color: #dbeafe;
border: 1px solid #3b82f6;
box-shadow: 0 12px 20px rgba(60,120,255,0.18);
}
.gr-markdown, .prose {
color: #dbeafe !important;
}
}
textarea, input {
border-radius: 12px !important;
border: 1px solid #c7ddff !important;
background-color: #f5f9ff !important;
color: #1e3a8a !important;
}
button {
background: linear-gradient(135deg, #1e3a8a 0%, #3b82f6 40%, #93c5fd 100%) !important;
color: #ffffff !important;
border-radius: 14px !important;
border: 1px solid #93c5fd !important;
font-weight: 600;
letter-spacing: 0.3px;
box-shadow:
0 6px 14px rgba(60,120,255,0.28),
inset 0 1px 0 rgba(255,255,255,0.6);
transition: all 0.25s ease;
}
button:hover {
background: linear-gradient(135deg, #1b3380 0%, #2563eb 40%, #7fb8ff 100%) !important;
box-shadow:
0 8px 18px rgba(60,120,255,0.35),
inset 0 1px 0 rgba(255,255,255,0.7);
transform: translateY(-1px);
}
button:active {
transform: translateY(1px);
box-shadow:
0 3px 8px rgba(60,120,255,0.2),
inset 0 2px 4px rgba(0,0,0,0.08);
}
"""
# Query pre-filled in both search tabs.
_DEFAULT_QUERY = "安定性静的埋め込みモデルとは何ですか?"

# Markdown shown in the model description card.
_MODEL_CARD_MD = """
## 使用モデル
**[RikkaBotan/stable-static-embedding-fast-retrieval-mrl-ja](https://huggingface.co/RikkaBotan/stable-static-embedding-fast-retrieval-mrl-ja)**
### 性能
* **NanoBEIR_ja において NDCG@10 = 0.4507 を達成**
* 他の静的埋め込みモデルよりも高い性能
### 効率性
* 512次元
* 約2倍高速な検索
* Separable Dynamic Tanh を採用
"""


def _query_box():
    """Create the query textbox used by each tab (call inside the tab context)."""
    return gr.Textbox(
        value=_DEFAULT_QUERY,
        label="検索クエリを入力してください。",
    )


# Page layout: two headings, the model card, one tab per search mode, footer.
with gr.Blocks(css=pastel_css) as demo:
    gr.Markdown('# Semantic Web Search and Deep Web Search')
    gr.Markdown('## Fast Retrieval with Stable Static Embedding')

    with gr.Column(elem_classes="model-card"):
        gr.Markdown(_MODEL_CARD_MD)

    with gr.Tabs():
        # Standard: one fetch, one ranking pass.
        with gr.Tab("Standard Search"):
            standard_query = _query_box()
            standard_button = gr.Button("Search")
            with gr.Column(elem_classes="result-card"):
                standard_output = gr.Markdown()
            standard_button.click(
                fn=semantic_web_search,
                inputs=standard_query,
                outputs=standard_output,
            )

        # Deep: progressive_search is a generator, so the output updates per round.
        with gr.Tab("Deep Search"):
            deep_query = _query_box()
            deep_threshold = gr.Slider(
                minimum=0.3,
                maximum=0.95,
                value=0.7,
                step=0.05,
                label="Score Threshold",
            )
            deep_button = gr.Button("Run Deep Search")
            with gr.Column(elem_classes="result-card"):
                deep_output = gr.Markdown()
            deep_button.click(
                fn=progressive_search,
                inputs=[deep_query, deep_threshold],
                outputs=deep_output,
                show_progress=True,
            )

    gr.Markdown("© 2026 Rikka Botan")

demo.launch()