rairo committed on
Commit
326f3b6
·
verified ·
1 Parent(s): 7d1d875

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +110 -37
main.py CHANGED
@@ -9,6 +9,7 @@ from google import genai
9
  from google.genai import types
10
 
11
  from pymysql.err import OperationalError
 
12
  warnings.filterwarnings("ignore")
13
 
14
  # ───────────────────────────────────────────────────────────────────────────────
@@ -52,20 +53,13 @@ app = Flask(__name__)
52
  CORS(app)
53
 
54
  # ───────────────────────────────────────────────────────────────────────────────
55
- # DB CONNECTION (autocommit + TLS + auto-reconnect)
56
  # ───────────────────────────────────────────────────────────────────────────────
57
- _CONN = None
58
 
59
- def _connect():
60
- """Enhanced connection with better error handling"""
61
- global _CONN
62
- try:
63
- if _CONN:
64
- _CONN.close()
65
- except Exception:
66
- pass
67
-
68
- _CONN = pymysql.connect(
69
  host=TIDB_HOST,
70
  port=TIDB_PORT,
71
  user=TIDB_USER,
@@ -77,36 +71,90 @@ def _connect():
77
  autocommit=True,
78
  charset="utf8mb4",
79
  cursorclass=pymysql.cursors.DictCursor,
80
- # Add these timeouts:
81
  connect_timeout=10,
82
- read_timeout=30,
83
  write_timeout=30,
 
 
 
84
  )
85
 
86
- def _ensure_conn():
87
- global _CONN
 
 
88
  max_retries = 3
 
89
  for attempt in range(max_retries):
90
  try:
91
- if _CONN is None:
92
- _connect()
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  else:
94
- _CONN.ping(reconnect=False) # Test without auto-reconnect first
95
- return _CONN
96
  except Exception as e:
97
- log.warning(f"Connection attempt {attempt + 1} failed: {e}")
98
- _CONN = None
99
- if attempt < max_retries - 1:
100
- time.sleep(0.5 * (attempt + 1)) # Backoff
101
- else:
102
- raise
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
- @contextmanager
105
- def cursor():
106
- """DictCursor with auto-ping; use in each route."""
107
- conn = _ensure_conn()
108
- with conn.cursor() as cur:
109
- yield cur
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  # ───────────────────────────────────────────────────────────────────────────────
112
  # EMBEDDINGS (lazy-load; same model as ingest; pad to 1536)
@@ -298,23 +346,41 @@ def root():
298
  return jsonify({"ok": True, "service": "provenance-radar-api", "device": _DEVICE_INFO})
299
 
300
  @app.get("/api/health")
 
301
  def health():
302
  try:
 
303
  with cursor() as cur:
304
  cur.execute("SELECT COUNT(*) AS c FROM objects"); objects = cur.fetchone()["c"]
305
  cur.execute("SELECT COUNT(*) AS c FROM provenance_sentences"); sentences = cur.fetchone()["c"]
306
  cur.execute("SELECT COUNT(*) AS c FROM risk_signals"); risks = cur.fetchone()["c"]
307
- return jsonify({"ok": True, "device": _DEVICE_INFO, "counts": {
308
- "objects": objects, "sentences": sentences, "risk_signals": risks}})
 
 
 
 
 
 
 
 
 
 
 
309
  except Exception as e:
310
  log.exception("health failed")
311
- return jsonify({"ok": False, "error": str(e)}), 500
 
 
 
 
312
 
313
  @app.get("/api/policy/windows")
314
  def policy_windows():
315
  return jsonify({"ok": True, "windows": POLICY_WINDOWS})
316
 
317
  @app.get("/api/leads")
 
318
  def leads():
319
  limit = max(1, min(int(request.args.get("limit", 50)), 200))
320
  min_score = float(request.args.get("min_score", 0))
@@ -335,6 +401,7 @@ def leads():
335
  return jsonify({"ok": True, "data": rows})
336
 
337
  @app.get("/api/object/<int:object_id>")
 
338
  def object_detail(object_id: int):
339
  with cursor() as cur:
340
  cur.execute("SELECT * FROM objects WHERE object_id=%s", (object_id,))
@@ -352,6 +419,7 @@ def object_detail(object_id: int):
352
  return jsonify({"ok": True, "object": obj, "sentences": sents, "events": events, "risks": risks})
353
 
354
  @app.get("/api/graph/<int:object_id>")
 
355
  def graph(object_id: int):
356
  with cursor() as cur:
357
  cur.execute("SELECT object_id, source, title FROM objects WHERE object_id=%s", (object_id,))
@@ -365,6 +433,7 @@ def graph(object_id: int):
365
  return jsonify({"ok": True, **build_graph_from_events(obj, events)})
366
 
367
  @app.get("/api/timeline/<int:object_id>")
 
368
  def timeline(object_id: int):
369
  with cursor() as cur:
370
  cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
@@ -377,6 +446,7 @@ def timeline(object_id: int):
377
  return jsonify({"ok": True, "items": items})
378
 
379
  @app.get("/api/keyword")
 
380
  def keyword_search():
381
  q = (request.args.get("q") or "").strip()
382
  limit = max(1, min(int(request.args.get("limit", 50)), 200))
@@ -396,6 +466,7 @@ def keyword_search():
396
 
397
 
398
  @app.post("/api/similar")
 
399
  def similar_search():
400
  payload = request.get_json(force=True) or {}
401
  text = (payload.get("text") or "").strip()
@@ -471,6 +542,7 @@ def similar_search():
471
 
472
 
473
  @app.get("/api/vocab")
 
474
  def vocab():
475
  field = (request.args.get("field") or "").strip().lower()
476
  limit = max(1, min(int(request.args.get("limit", 100)), 500))
@@ -490,6 +562,7 @@ def vocab():
490
  # ── Gemini-powered explanations ────────────────────────────────────────────────
491
 
492
  @app.get("/api/explain/object/<int:object_id>")
 
493
  def explain_object(object_id: int):
494
  """Generate a concise, policy-aware research note for an object."""
495
  with cursor() as cur:
@@ -533,7 +606,7 @@ def explain_text():
533
  return jsonify({"ok": False, "error": "text required"}), 400
534
  sys = ("Explain this text as a provenance note for curators. "
535
  "Be precise and cautious; highlight possible red flags tied to 1933–1945 and post-1970 export rules.")
536
- prompt = f"Explain and contextualize this provenance fragment:\n\nβ€œ{sentence}”."
537
  text = gemini_explain(prompt, sys=sys)
538
  return jsonify({"ok": True, "model": EXPLAIN_MODEL, "explanation": text})
539
 
@@ -542,4 +615,4 @@ def explain_text():
542
  # ───────────────────────────────────────────────────────────────────────────────
543
  if __name__ == "__main__":
544
  port = int(os.environ.get("PORT", "7860"))
545
- app.run(host="0.0.0.0", port=port, debug=False)
 
9
  from google.genai import types
10
 
11
  from pymysql.err import OperationalError
12
+ import threading
13
  warnings.filterwarnings("ignore")
14
 
15
  # ───────────────────────────────────────────────────────────────────────────────
 
53
  CORS(app)
54
 
55
  # ───────────────────────────────────────────────────────────────────────────────
56
+ # DB CONNECTION (refactored for better connection management)
57
  # ───────────────────────────────────────────────────────────────────────────────
58
+ _connection_lock = threading.Lock()
59
 
60
+ def _create_connection():
61
+ """Create a new database connection with optimized settings"""
62
+ return pymysql.connect(
 
 
 
 
 
 
 
63
  host=TIDB_HOST,
64
  port=TIDB_PORT,
65
  user=TIDB_USER,
 
71
  autocommit=True,
72
  charset="utf8mb4",
73
  cursorclass=pymysql.cursors.DictCursor,
 
74
  connect_timeout=10,
75
+ read_timeout=60, # Increased for vector operations
76
  write_timeout=30,
77
+ # TiDB-specific optimizations:
78
+ init_command="SET SESSION sql_mode='STRICT_TRANS_TABLES,NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO'",
79
+ client_flag=pymysql.constants.CLIENT.MULTI_STATEMENTS,
80
  )
81
 
82
@contextmanager
def cursor():
    """Yield a cursor on a fresh connection and close the connection on exit.

    Only the *connection attempt* is retried (up to three tries, sleeping
    0.5s and then 1.0s between them). The ``yield`` is reached exactly once:
    a ``@contextmanager`` generator must not yield again after an exception
    has been thrown into it, otherwise contextlib raises
    ``RuntimeError("generator didn't stop after throw()")`` and hides the
    real database error. Exceptions raised inside the caller's ``with`` body
    therefore propagate unchanged.

    Raises:
        OperationalError, pymysql.err.InternalError: if every connection
            attempt fails; the error from the last attempt is re-raised.
    """
    max_retries = 3
    conn = None

    for attempt in range(max_retries):
        try:
            conn = _create_connection()
            break
        except (OperationalError, pymysql.err.InternalError) as e:
            if attempt == max_retries - 1:
                log.error(f"Database connection failed after {max_retries} attempts: {e}")
                raise
            log.warning(f"Database connection failed (attempt {attempt + 1}): {e}")
            time.sleep(0.5 * (attempt + 1))  # linear backoff: 0.5s, then 1.0s
        except Exception as e:
            # Non-retryable failure while connecting: log and surface it.
            log.error(f"Database connection failed: {e}")
            raise

    # The loop above either broke out with a live connection or raised.
    try:
        with conn.cursor() as cur:
            yield cur
    finally:
        # Best-effort close; a failure here must not mask the body's outcome.
        try:
            conn.close()
        except Exception as e:
            log.debug(f"Ignoring error while closing DB connection: {e}")
122
+
123
def with_db_retry(func):
    """Decorator that retries ``func`` when a database connection error occurs.

    ``func`` is called up to three times. After a failed attempt that raised
    ``OperationalError`` or ``pymysql.err.InternalError`` the wrapper sleeps
    0.5s * attempt number and tries again; the error from the final attempt
    is re-raised. Any other exception propagates immediately.

    The wrapper is built with ``functools.wraps`` so it keeps the decorated
    function's ``__name__``/``__doc__``. Without that, every decorated view
    would be named ``wrapper`` and route registration would see several
    views sharing one endpoint name.

    Args:
        func: the callable to protect.

    Returns:
        A callable with the same signature and metadata as ``func``.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    from functools import wraps

    max_retries = 3

    @wraps(func)
    def wrapper(*args, **kwargs):
        for attempt in range(max_retries):
            try:
                return func(*args, **kwargs)
            except (OperationalError, pymysql.err.InternalError) as e:
                if attempt == max_retries - 1:
                    log.error(f"Database operation failed after {max_retries} attempts: {e}")
                    raise
                log.warning(f"Database operation failed (attempt {attempt + 1}): {e}")
                time.sleep(0.5 * (attempt + 1))

    return wrapper
137
 
138
+ # ───────────────────────────────────────────────────────────────────────────────
139
+ # ERROR HANDLERS
140
+ # ───────────────────────────────────────────────────────────────────────────────
141
@app.errorhandler(OperationalError)
def handle_db_error(e):
    """Log an ``OperationalError`` and answer with a JSON body and HTTP 503."""
    log.error(f"Database error: {e}")
    body = {
        "ok": False,
        "error": "database_unavailable",
        "message": "Database connection issue. Please try again.",
    }
    return jsonify(body), 503
149
+
150
@app.errorhandler(pymysql.err.InternalError)
def handle_internal_error(e):
    """Log a ``pymysql.err.InternalError`` and answer with a JSON body and HTTP 500."""
    log.error(f"Database internal error: {e}")
    body = {
        "ok": False,
        "error": "database_error",
        "message": "Database operation failed. Please try again.",
    }
    return jsonify(body), 500
158
 
159
  # ───────────────────────────────────────────────────────────────────────────────
160
  # EMBEDDINGS (lazy-load; same model as ingest; pad to 1536)
 
346
  return jsonify({"ok": True, "service": "provenance-radar-api", "device": _DEVICE_INFO})
347
 
348
@app.get("/api/health")
@with_db_retry
def health():
    """Report service health: row counts of the three tables plus DB latency.

    Returns a JSON body with ``ok``, ``device``, ``db_latency_ms`` and
    ``counts`` on success. Any exception is logged with its traceback and
    turned into a JSON error body with HTTP 503, so this view never raises.
    """
    try:
        # perf_counter() is monotonic; time.time() can jump when the system
        # clock is adjusted and would then report a bogus latency.
        started = time.perf_counter()
        with cursor() as cur:
            cur.execute("SELECT COUNT(*) AS c FROM objects")
            objects = cur.fetchone()["c"]
            cur.execute("SELECT COUNT(*) AS c FROM provenance_sentences")
            sentences = cur.fetchone()["c"]
            cur.execute("SELECT COUNT(*) AS c FROM risk_signals")
            risks = cur.fetchone()["c"]

        db_latency = round((time.perf_counter() - started) * 1000, 2)

        return jsonify({
            "ok": True,
            "device": _DEVICE_INFO,
            "db_latency_ms": db_latency,
            "counts": {
                "objects": objects,
                "sentences": sentences,
                "risk_signals": risks,
            },
        })
    except Exception as e:
        log.exception("health failed")
        return jsonify({
            "ok": False,
            "error": str(e),
            "db_status": "unavailable",
        }), 503
377
 
378
@app.get("/api/policy/windows")
def policy_windows():
    """Return the module-level ``POLICY_WINDOWS`` value wrapped in a JSON envelope."""
    payload = {"ok": True, "windows": POLICY_WINDOWS}
    return jsonify(payload)
381
 
382
  @app.get("/api/leads")
383
+ @with_db_retry
384
  def leads():
385
  limit = max(1, min(int(request.args.get("limit", 50)), 200))
386
  min_score = float(request.args.get("min_score", 0))
 
401
  return jsonify({"ok": True, "data": rows})
402
 
403
  @app.get("/api/object/<int:object_id>")
404
+ @with_db_retry
405
  def object_detail(object_id: int):
406
  with cursor() as cur:
407
  cur.execute("SELECT * FROM objects WHERE object_id=%s", (object_id,))
 
419
  return jsonify({"ok": True, "object": obj, "sentences": sents, "events": events, "risks": risks})
420
 
421
  @app.get("/api/graph/<int:object_id>")
422
+ @with_db_retry
423
  def graph(object_id: int):
424
  with cursor() as cur:
425
  cur.execute("SELECT object_id, source, title FROM objects WHERE object_id=%s", (object_id,))
 
433
  return jsonify({"ok": True, **build_graph_from_events(obj, events)})
434
 
435
  @app.get("/api/timeline/<int:object_id>")
436
+ @with_db_retry
437
  def timeline(object_id: int):
438
  with cursor() as cur:
439
  cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
 
446
  return jsonify({"ok": True, "items": items})
447
 
448
  @app.get("/api/keyword")
449
+ @with_db_retry
450
  def keyword_search():
451
  q = (request.args.get("q") or "").strip()
452
  limit = max(1, min(int(request.args.get("limit", 50)), 200))
 
466
 
467
 
468
  @app.post("/api/similar")
469
+ @with_db_retry
470
  def similar_search():
471
  payload = request.get_json(force=True) or {}
472
  text = (payload.get("text") or "").strip()
 
542
 
543
 
544
  @app.get("/api/vocab")
545
+ @with_db_retry
546
  def vocab():
547
  field = (request.args.get("field") or "").strip().lower()
548
  limit = max(1, min(int(request.args.get("limit", 100)), 500))
 
562
  # ── Gemini-powered explanations ────────────────────────────────────────────────
563
 
564
  @app.get("/api/explain/object/<int:object_id>")
565
+ @with_db_retry
566
  def explain_object(object_id: int):
567
  """Generate a concise, policy-aware research note for an object."""
568
  with cursor() as cur:
 
606
  return jsonify({"ok": False, "error": "text required"}), 400
607
  sys = ("Explain this text as a provenance note for curators. "
608
  "Be precise and cautious; highlight possible red flags tied to 1933–1945 and post-1970 export rules.")
609
+ prompt = f"Explain and contextualize this provenance fragment:\n\n"{sentence}"."
610
  text = gemini_explain(prompt, sys=sys)
611
  return jsonify({"ok": True, "model": EXPLAIN_MODEL, "explanation": text})
612
 
 
615
  # ───────────────────────────────────────────────────────────────────────────────
616
if __name__ == "__main__":
    # Script entry point: listen on every interface; the port comes from the
    # PORT environment variable and falls back to 7860.
    listen_port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=listen_port, debug=False)