Spaces:
Paused
Paused
Update main.py
Browse files
main.py
CHANGED
|
@@ -8,7 +8,7 @@ from datetime import date, datetime
|
|
| 8 |
from google import genai
|
| 9 |
from google.genai import types
|
| 10 |
|
| 11 |
-
|
| 12 |
|
| 13 |
# ──────────────────────────────────────────────────────────────────────────────
|
| 14 |
# CONFIG
|
|
@@ -381,27 +381,81 @@ def keyword_search():
|
|
| 381 |
rows = cur.fetchall()
|
| 382 |
return jsonify({"ok": True, "query": q, "data": rows})
|
| 383 |
|
|
|
|
| 384 |
@app.post("/api/similar")
|
| 385 |
def similar_search():
|
| 386 |
payload = request.get_json(force=True) or {}
|
| 387 |
text = (payload.get("text") or "").strip()
|
| 388 |
limit = max(1, min(int(payload.get("limit", 20)), 100))
|
|
|
|
|
|
|
|
|
|
| 389 |
if not text:
|
| 390 |
return jsonify({"ok": False, "error": "text required"}), 400
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
)
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 405 |
|
| 406 |
@app.get("/api/vocab")
|
| 407 |
def vocab():
|
|
|
|
| 8 |
from google import genai
|
| 9 |
from google.genai import types
|
| 10 |
|
| 11 |
+
from pymysql.err import OperationalError
warnings.filterwarnings("ignore")
|
| 12 |
|
| 13 |
# ──────────────────────────────────────────────────────────────────────────────
|
| 14 |
# CONFIG
|
|
|
|
| 381 |
rows = cur.fetchall()
|
| 382 |
return jsonify({"ok": True, "query": q, "data": rows})
|
| 383 |
|
| 384 |
+
|
| 385 |
@app.post("/api/similar")
def similar_search():
    """Semantic similarity search over provenance sentences.

    JSON body:
        text (str, required): query text to embed and match.
        limit (int): max rows returned, clamped to 1..100 (default 20).
        candidates (int): ANN pre-top-K fetched before the metadata join and
            source filter, clamped to limit..10000
            (default max(200, limit * 10)).
        source (str): optional exact-match filter on objects.source
            (e.g., "AIC"); upper-cased before use.

    Returns:
        JSON {ok, device, query, data, meta} on success;
        400 for missing/invalid input, 503 if the embedding model is
        unavailable, 500 on query failure.
    """
    payload = request.get_json(force=True) or {}
    text = (payload.get("text") or "").strip()
    if not text:
        return jsonify({"ok": False, "error": "text required"}), 400

    # Validate numeric knobs up front: a malformed value is a client error
    # (400), not an unhandled ValueError surfacing as a 500.
    try:
        limit = max(1, min(int(payload.get("limit", 20)), 100))
        candidates = int(payload.get("candidates", max(200, limit * 10)))  # ANN pre-topK
    except (TypeError, ValueError):
        return jsonify({"ok": False, "error": "limit/candidates must be integers"}), 400
    # Clamp the client-controlled pre-top-K so a request cannot force an
    # arbitrarily large ANN scan.
    candidates = max(limit, min(candidates, 10_000))
    source_filter = (payload.get("source") or "").strip().upper()  # e.g., "AIC"

    # Embed without the NumPy path (keep the result as a tensor, then to list).
    try:
        import torch
        vec_t = _load_model().encode(
            [text], batch_size=1, show_progress_bar=False, convert_to_tensor=True
        )
        if isinstance(vec_t, torch.Tensor):
            vec = vec_t[0].detach().cpu().tolist()
        else:
            vec = list(vec_t[0])
    except Exception as e:
        return jsonify({"ok": False, "error": f"embedding_unavailable: {e}"}), 503

    vec_json = json.dumps(_pad(vec, VEC_DIM))

    # Staged query: the CTE performs the ANN scan with an explicit HNSW index
    # hint; the outer query joins object metadata and applies the optional
    # source filter. Only the constant VEC_DIM and the server-built where_src
    # are interpolated -- all client values go through %s placeholders.
    where_src = "WHERE o.source = %s" if source_filter else ""
    sql = f"""
        WITH nn AS (
            SELECT /*+ USE_INDEX(ps, hnsw_vec) */
                   ps.sent_id, ps.object_id, ps.seq, ps.sentence,
                   VEC_COSINE_DISTANCE(ps.embedding, CAST(%s AS VECTOR({VEC_DIM}))) AS distance
            FROM provenance_sentences ps
            ORDER BY distance
            LIMIT %s
        )
        SELECT nn.object_id, nn.seq, nn.sentence, o.source, o.title, o.creator, nn.distance
        FROM nn
        JOIN objects o ON o.object_id = nn.object_id
        {where_src}
        ORDER BY nn.distance
        LIMIT %s
    """

    def _run_query(cand):
        # Placeholder order matches the SQL above: vector, CTE LIMIT,
        # optional source filter, final LIMIT.
        params = [vec_json, cand]
        if source_filter:
            params.append(source_filter)
        params.append(limit)
        with cursor() as cur:
            cur.execute(sql, params)
            return cur.fetchall()

    def _ok(rows, cand, note=None):
        meta = {"limit": limit, "candidates": cand, "source": source_filter or None}
        if note:
            meta["note"] = note
        return jsonify({"ok": True, "device": _DEVICE_INFO, "query": text,
                        "data": rows, "meta": meta})

    try:
        return _ok(_run_query(candidates), candidates)
    except OperationalError as e:
        # TiDB OOM (error 1105) -- retry once with a smaller candidate set.
        if e.args and e.args[0] == 1105 and candidates > max(100, limit * 4):
            smaller = max(100, limit * 4)
            try:
                return _ok(_run_query(smaller), smaller,
                           note="retried with smaller candidate set")
            except Exception as e2:
                return jsonify({"ok": False, "error": f"oom_retry_failed: {e2}"}), 500
        # Not OOM, or the non-retryable path failed.
        return jsonify({"ok": False, "error": f"query_failed: {e}"}), 500
|
| 458 |
+
|
| 459 |
|
| 460 |
@app.get("/api/vocab")
|
| 461 |
def vocab():
|