rairo commited on
Commit
08c1d78
·
verified ·
1 Parent(s): c24393d

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +68 -14
main.py CHANGED
@@ -8,7 +8,7 @@ from datetime import date, datetime
8
  from google import genai
9
  from google.genai import types
10
 
11
- warnings.filterwarnings("ignore")
12
 
13
  # ───────────────────────────────────────────────────────────────────────────────
14
  # CONFIG
@@ -381,27 +381,81 @@ def keyword_search():
381
  rows = cur.fetchall()
382
  return jsonify({"ok": True, "query": q, "data": rows})
383
 
 
384
  @app.post("/api/similar")
385
  def similar_search():
386
  payload = request.get_json(force=True) or {}
387
  text = (payload.get("text") or "").strip()
388
  limit = max(1, min(int(payload.get("limit", 20)), 100))
 
 
 
389
  if not text:
390
  return jsonify({"ok": False, "error": "text required"}), 400
391
- vec = embed_text_to_vec1536(text)
392
- vec_json = json.dumps(vec)
393
- sql = (
394
- "SELECT ps.object_id, ps.seq, ps.sentence, o.source, o.title, o.creator, "
395
- f"VEC_COSINE_DISTANCE(ps.embedding, CAST(%s AS VECTOR({VEC_DIM}))) AS distance "
396
- "FROM provenance_sentences ps "
397
- "JOIN objects o ON o.object_id = ps.object_id "
398
- "ORDER BY distance ASC "
399
- "LIMIT %s"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  )
401
- with cursor() as cur:
402
- cur.execute(sql, (vec_json, limit))
403
- rows = cur.fetchall()
404
- return jsonify({"ok": True, "device": _DEVICE_INFO, "query": text, "data": rows})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
 
406
  @app.get("/api/vocab")
407
  def vocab():
 
8
  from google import genai
9
  from google.genai import types
10
 
11
+ from pymysql.err import OperationalError
+ warnings.filterwarnings("ignore")
12
 
13
  # ───────────────────────────────────────────────────────────────────────────────
14
  # CONFIG
 
381
  rows = cur.fetchall()
382
  return jsonify({"ok": True, "query": q, "data": rows})
383
 
384
+
385
@app.post("/api/similar")
def similar_search():
    """POST /api/similar — ANN similarity search over provenance sentences.

    JSON body:
        text       (str, required): query text to embed and match.
        limit      (int): rows to return, clamped to 1..100 (default 20).
        candidates (int): ANN pre-top-K pool size (default max(200, limit*10)).
        source     (str): optional exact-match filter on objects.source, e.g. "AIC".

    Returns:
        400 on missing text or non-integer limit/candidates,
        503 if the embedding model cannot produce a vector,
        500 on query failure — with one automatic retry at a smaller
        candidate pool when the DB reports error 1105 (TiDB OOM on the
        ANN scan), otherwise the error is surfaced as-is.
    """
    payload = request.get_json(force=True) or {}
    text = (payload.get("text") or "").strip()

    # Validate integer parameters explicitly: a non-numeric value in the
    # JSON body should be a 400, not an uncaught ValueError -> 500.
    try:
        limit = max(1, min(int(payload.get("limit", 20)), 100))
        candidates = int(payload.get("candidates", max(200, limit * 10)))  # ANN pre-topK
    except (TypeError, ValueError):
        return jsonify({"ok": False, "error": "limit/candidates must be integers"}), 400
    # A candidate pool smaller than the final limit can only truncate results
    # (and a non-positive value would make the SQL LIMIT invalid).
    candidates = max(candidates, limit)

    source_filter = (payload.get("source") or "").strip().upper()  # e.g., "AIC"

    if not text:
        return jsonify({"ok": False, "error": "text required"}), 400

    # Embed without going through NumPy: ask for a torch tensor and convert
    # straight to a plain Python list.
    try:
        import torch
        enc = _load_model().encode(
            [text], batch_size=1, show_progress_bar=False, convert_to_tensor=True
        )
        if isinstance(enc, torch.Tensor):
            vec = enc[0].detach().cpu().tolist()
        else:
            vec = list(enc[0])
    except Exception as e:
        return jsonify({"ok": False, "error": f"embedding_unavailable: {e}"}), 503

    vec_json = json.dumps(_pad(vec, VEC_DIM))

    # Staged query: the CTE does the ANN scan on the HNSW index for the top
    # `candidates` sentences; the outer query joins object metadata,
    # optionally filters by source, and keeps the closest `limit` rows.
    where_src = "WHERE o.source = %s" if source_filter else ""
    sql = f"""
    WITH nn AS (
        SELECT /*+ USE_INDEX(ps, hnsw_vec) */
               ps.sent_id, ps.object_id, ps.seq, ps.sentence,
               VEC_COSINE_DISTANCE(ps.embedding, CAST(%s AS VECTOR({VEC_DIM}))) AS distance
        FROM provenance_sentences ps
        ORDER BY distance
        LIMIT %s
    )
    SELECT nn.object_id, nn.seq, nn.sentence, o.source, o.title, o.creator, nn.distance
    FROM nn
    JOIN objects o ON o.object_id = nn.object_id
    {where_src}
    ORDER BY nn.distance
    LIMIT %s
    """

    def _params(pool):
        # Placeholder order must match the SQL: vector, ANN pool size,
        # optional source filter, final limit.
        p = [vec_json, pool]
        if source_filter:
            p.append(source_filter)
        p.append(limit)
        return p

    def _run(pool):
        # Execute the staged query for a given candidate-pool size.
        with cursor() as cur:
            cur.execute(sql, _params(pool))
            return cur.fetchall()

    try:
        rows = _run(candidates)
        return jsonify({"ok": True, "device": _DEVICE_INFO, "query": text, "data": rows,
                        "meta": {"limit": limit, "candidates": candidates,
                                 "source": source_filter or None}})
    except OperationalError as e:
        # TiDB OOM (error 1105) -> retry once with a smaller candidate set.
        if e.args and e.args[0] == 1105 and candidates > max(100, limit * 4):
            smaller = max(100, limit * 4)
            try:
                rows = _run(smaller)
                return jsonify({"ok": True, "device": _DEVICE_INFO, "query": text, "data": rows,
                                "meta": {"limit": limit, "candidates": smaller,
                                         "source": source_filter or None,
                                         "note": "retried with smaller candidate set"}})
            except Exception as e2:
                return jsonify({"ok": False, "error": f"oom_retry_failed: {e2}"}), 500
        # Not OOM, or retry not applicable / still failed.
        return jsonify({"ok": False, "error": f"query_failed: {e}"}), 500
459
 
460
  @app.get("/api/vocab")
461
  def vocab():