Spaces:
Paused
Paused
| import os, json, logging, warnings, time, certifi, pymysql, requests | |
| from contextlib import contextmanager | |
| from datetime import date | |
| from flask import Flask, request, jsonify | |
| from flask_cors import CORS | |
| from datetime import date, datetime | |
| # ---- Optional Google GenAI (Gemini) ---- | |
| from google import genai | |
| from google.genai import types | |
| from pymysql.err import OperationalError | |
| import threading | |
| warnings.filterwarnings("ignore") | |
| # ββ NEW: lightweight event inference from sentences βββββββββββββββββββββββββββ | |
| import re | |
| from typing import List, Dict, Any, Optional | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CONFIG | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| DB_NAME = os.getenv("TIDB_DB") | |
| TIDB_HOST = os.getenv("TIDB_HOST") | |
| TIDB_PORT = int(os.getenv("TIDB_PORT")) | |
| TIDB_USER = os.getenv("TIDB_USER") | |
| TIDB_PASS = os.getenv("TIDB_PASS") | |
| VEC_DIM = int(os.getenv("VEC_DIM", "1536")) | |
| EMBED_MODEL = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2") | |
| USE_GPU = os.getenv("USE_GPU", "0") == "1" # Spaces are usually CPU; works either way | |
| # Policy windows (server is single source of truth for the client) | |
| POLICY_WINDOWS = [ | |
| { | |
| "code": "NAZI_ERA", | |
| "label": "Washington Conference Principles (1933β1945)", | |
| "from": "1933-01-01", | |
| "to": "1945-12-31", | |
| "ref": "https://www.state.gov/washington-conference-principles-on-nazi-confiscated-art" | |
| }, | |
| { | |
| "code": "UNESCO_1970", | |
| "label": "UNESCO 1970 Convention", | |
| "from": "1970-11-14", | |
| "to": None, | |
| "ref": "https://www.unesco.org/en/legal-affairs/convention-means-prohibiting-and-preventing-illicit-import-export-and-transfer-ownership-cultural" | |
| } | |
| ] | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # APP + LOGGING | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| logging.basicConfig(level=logging.INFO) | |
| log = logging.getLogger("provenance-api") | |
| app = Flask(__name__) | |
| CORS(app) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # DB CONNECTION (refactored for better connection management) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _connection_lock = threading.Lock() | |
| def _create_connection(): | |
| """Create a new database connection with optimized settings""" | |
| return pymysql.connect( | |
| host=TIDB_HOST, | |
| port=TIDB_PORT, | |
| user=TIDB_USER, | |
| password=TIDB_PASS, | |
| database=DB_NAME, | |
| ssl={"ca": certifi.where()}, | |
| ssl_verify_cert=True, | |
| ssl_verify_identity=True, | |
| autocommit=True, | |
| charset="utf8mb4", | |
| cursorclass=pymysql.cursors.DictCursor, | |
| connect_timeout=10, | |
| read_timeout=60, # Increased for vector operations | |
| write_timeout=30, | |
| # TiDB-specific optimizations: | |
| init_command="SET SESSION sql_mode='STRICT_TRANS_TABLES,NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO'", | |
| client_flag=pymysql.constants.CLIENT.MULTI_STATEMENTS, | |
| ) | |
| def cursor(): | |
| """Create a fresh connection for each request context with retry logic""" | |
| conn = None | |
| max_retries = 3 | |
| for attempt in range(max_retries): | |
| try: | |
| conn = _create_connection() | |
| with conn.cursor() as cur: | |
| yield cur | |
| break | |
| except (OperationalError, pymysql.err.InternalError) as e: | |
| if conn: | |
| try: | |
| conn.close() | |
| except Exception: | |
| pass | |
| conn = None | |
| if attempt == max_retries - 1: | |
| log.error(f"Database connection failed after {max_retries} attempts: {e}") | |
| raise | |
| else: | |
| log.warning(f"Database connection failed (attempt {attempt + 1}): {e}") | |
| time.sleep(0.5 * (attempt + 1)) # Exponential backoff | |
| except Exception as e: | |
| if conn: | |
| try: | |
| conn.close() | |
| except Exception: | |
| pass | |
| log.error(f"Database connection failed: {e}") | |
| raise | |
| finally: | |
| if conn: | |
| try: | |
| conn.close() | |
| except Exception: | |
| pass | |
| def with_db_retry(func): | |
| """Decorator to retry database operations on connection failures""" | |
| import functools | |
| # This preserves the original function's metadata | |
| def wrapper(*args, **kwargs): | |
| max_retries = 3 | |
| for attempt in range(max_retries): | |
| try: | |
| return func(*args, **kwargs) | |
| except (OperationalError, pymysql.err.InternalError) as e: | |
| if attempt == max_retries - 1: | |
| log.error(f"Database operation failed after {max_retries} attempts: {e}") | |
| raise | |
| log.warning(f"Database operation failed (attempt {attempt + 1}): {e}") | |
| time.sleep(0.5 * (attempt + 1)) | |
| return wrapper | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ERROR HANDLERS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def handle_db_error(e): | |
| log.error(f"Database error: {e}") | |
| return jsonify({ | |
| "ok": False, | |
| "error": "database_unavailable", | |
| "message": "Database connection issue. Please try again." | |
| }), 503 | |
| def handle_internal_error(e): | |
| log.error(f"Database internal error: {e}") | |
| return jsonify({ | |
| "ok": False, | |
| "error": "database_error", | |
| "message": "Database operation failed. Please try again." | |
| }), 500 | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # EMBEDDINGS (lazy-load; same model as ingest; pad to 1536) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _MODEL = None | |
| _DEVICE_INFO = "cpu" | |
| def _pad(vec, dim=VEC_DIM): | |
| return vec[:dim] + [0.0] * max(0, dim - len(vec)) | |
| def _load_model(): | |
| global _MODEL, _DEVICE_INFO | |
| if _MODEL is not None: | |
| return _MODEL | |
| if USE_GPU: | |
| try: | |
| import torch | |
| if torch.cuda.is_available(): | |
| _DEVICE_INFO = "cuda" | |
| except Exception: | |
| _DEVICE_INFO = "cpu" | |
| from sentence_transformers import SentenceTransformer | |
| _MODEL = SentenceTransformer(EMBED_MODEL, device=_DEVICE_INFO) | |
| log.info(f"Loaded embedding model on '{_DEVICE_INFO}': {EMBED_MODEL}") | |
| return _MODEL | |
| def embed_text_to_vec1536(text: str): | |
| model = _load_model() | |
| # Use Torch tensors to avoid NumPy code path entirely | |
| import torch | |
| t = model.encode([text], batch_size=1, show_progress_bar=False, convert_to_tensor=True) | |
| if isinstance(t, torch.Tensor): | |
| vec = t[0].detach().cpu().tolist() | |
| else: | |
| # very defensive fallback | |
| vec = list(t[0]) | |
| return _pad(vec, VEC_DIM) | |
| def to_iso(d): | |
| """Return YYYY-MM-DD for date/datetime/str; None for empty.""" | |
| if d is None: | |
| return None | |
| if isinstance(d, (date, datetime)): | |
| return d.isoformat()[:10] | |
| if isinstance(d, str): | |
| return d[:10] if d else None | |
| # fallback | |
| try: | |
| return str(d)[:10] | |
| except Exception: | |
| return None | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GEMINI (explanations / descriptions) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| GEMINI_KEY = os.environ.get("Gemini") | |
| _gclient = None | |
| def _gemini(): | |
| global _gclient | |
| if _gclient is not None: | |
| return _gclient | |
| if not GEMINI_KEY: | |
| return None | |
| try: | |
| _gclient = genai.Client(api_key=GEMINI_KEY) | |
| log.info("Gemini client initialized.") | |
| return _gclient | |
| except Exception as e: | |
| log.warning(f"Gemini init failed: {e}") | |
| return None | |
| EXPLAIN_MODEL = "gemini-2.0-flash" | |
| def gemini_explain(prompt: str, sys: str = None, model: str = EXPLAIN_MODEL) -> str: | |
| g = _gemini() | |
| if g is None: | |
| # Graceful fallback so the API still works without a key | |
| return "(Gemini not configured) " + prompt[:180] | |
| # chat-style to mirror your original pattern | |
| chat = g.chats.create(model=model) | |
| # Add a light system preamble for style/constraints | |
| if sys: | |
| chat.send_message(f"[SYSTEM]\n{sys}") | |
| resp = chat.send_message(prompt) | |
| return getattr(resp, "text", "").strip() or "" | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # UTIL: Build risk scores, graph & timeline from events (+ risk overlays) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # | |
| # Targets: | |
| # raw 100 -> ~55 | |
| # raw 200 -> ~80 | |
| # raw 2000 -> ~99 (slow approach to 99 beyond this) | |
| # BLOCK 1 β Helpers (drop-in) | |
| # - Piecewise normalize_risk() curve | |
| # - _to_float() coercion | |
| # - _apply_normalized_risk_inplace(): overwrites 'risk_score' and keeps 'risk_score_raw' | |
| import math | |
| from decimal import Decimal | |
| def _to_float(x): | |
| if x is None: return None | |
| if isinstance(x, (int, float)): return float(x) | |
| if isinstance(x, Decimal): return float(x) | |
| if isinstance(x, str): | |
| try: return float(x.strip().replace("%","")) | |
| except Exception: return None | |
| try: return float(x) | |
| except Exception: return None | |
| def _piecewise_0_99_from_percent(pct: float) -> float: | |
| """Piecewise curve on a 0β99 scale using 'percent' inputs (100, 200, ...).""" | |
| x = max(float(pct), 0.0) | |
| if x <= 100.0: | |
| out = 55.0 * ((x / 100.0) ** 0.7) # ~55 at 100 | |
| elif x <= 200.0: | |
| out = 55.0 + 25.0 * (((x - 100.0) / 100.0) ** 0.8) # 55β80 between 100β200 | |
| else: | |
| k = math.log(100.0) / 1800.0 # ~98.8 at 2000 | |
| out = 99.0 - 19.0 * math.exp(-k * (x - 200.0)) | |
| return max(0.0, min(out, 99.0)) | |
| def normalize_risk(score_ratio: float) -> float: | |
| """ | |
| INPUT: raw ratio (1.0=100%, 2.0=200%, 6.0=600%) | |
| OUTPUT: normalized ratio on 0β1 scale (e.g., 0.8 for 80%) | |
| """ | |
| r = _to_float(score_ratio) | |
| if r is None: return None | |
| pct_in = r * 100.0 # convert to percent domain for mapping | |
| pct_out = _piecewise_0_99_from_percent(pct_in) | |
| return round(pct_out / 100.0, 6) # send back as 0β1 for the UI | |
| def _apply_normalized_risk_inplace(row: dict): | |
| if not isinstance(row, dict): | |
| return | |
| raw_ratio = _to_float(row.get("risk_score")) | |
| if raw_ratio is None: | |
| return | |
| norm_ratio = normalize_risk(raw_ratio) # 0β1 | |
| norm_0_99 = None if norm_ratio is None else round(norm_ratio * 100.0, 2) | |
| row["risk_score_raw"] = raw_ratio # raw ratio (e.g., 2.0) | |
| row["risk_score_norm_0_99"] = norm_0_99 # 0β99 reference (e.g., 80.0) | |
| row["risk_score"] = norm_ratio # **what client already uses** (0β1) | |
| row["risk_score_normalized"]= norm_ratio # alias if client checks this too | |
| EVENT_VERBS = { | |
| "sold": "SOLD", | |
| "purchased": "PURCHASED", | |
| "bought": "PURCHASED", | |
| "acquired": "ACQUIRED", | |
| "donated": "DONATED", | |
| "gifted": "DONATED", | |
| "bequeathed": "BEQUEATHED", | |
| "consigned": "CONSIGNED", | |
| "exhibited": "EXHIBITED", | |
| "exported": "EXPORTED", | |
| "imported": "IMPORTED", | |
| } | |
| YEAR_RE = re.compile(r"\b(1[6-9]\d{2}|20\d{2})\b") # 1600β2099 | |
| def _clean(s: Optional[str]) -> Optional[str]: | |
| if not s: return None | |
| s = re.sub(r"\s+", " ", s).strip(" ,.;:-ββ") | |
| return s or None | |
| def _infer_from_sentence(txt: str) -> Optional[Dict[str, Any]]: | |
| """ | |
| Very pragmatic patterns that cover most catalogue phrasing: | |
| - 'sold to X, <place>, 2000' | |
| - 'sold to X, by 2000' | |
| - 'purchased from Y in 1965' | |
| - 'donated by X, <place>, 1971' | |
| Returns a dict compatible with provenance_events rows. | |
| """ | |
| if not txt: | |
| return None | |
| low = txt.lower() | |
| # find verb | |
| verb = next((EVENT_VERBS[v] for v in EVENT_VERBS if v in low), None) | |
| if not verb: | |
| return None | |
| # pull a year (prefers the last year in the string) | |
| years = YEAR_RE.findall(txt) | |
| year = years[-1] if years else None | |
| actor = None | |
| place = None | |
| # Common pattern: 'sold to X, place, 2000' | |
| m = re.search(r"\b(sold|purchased|bought|acquired|donated|gifted|bequeathed|consigned)\s+(to|by|from)\s+(.*)$", low) | |
| if m: | |
| # Take the fragment after 'to/by/from' | |
| frag = txt[m.end(2)+1:].strip() | |
| # Trim trailing year or 'by 2000' | |
| frag = re.sub(r"(,\s*)?(by\s*)?\b(1[6-9]\d{2}|20\d{2})\b.*$", "", frag, flags=re.IGNORECASE).strip(" ,.;") | |
| # Split on commas: first token is actor; the rest (if any) is place | |
| parts = [p.strip() for p in re.split(r",(?![^()]*\))", frag) if p.strip()] | |
| if parts: | |
| actor = parts[0] | |
| if len(parts) > 1: | |
| place = ", ".join(parts[1:]) | |
| # Fallback simple 'sold to X' without commas | |
| if not actor: | |
| m2 = re.search(r"\bsold\s+to\s+([^,.;]+)", low) | |
| if m2: | |
| actor = _clean(txt[m2.start(1):m2.end(1)]) | |
| return { | |
| "event_type": verb, | |
| "date_from": f"{year}-01-01" if year else None, | |
| "date_to": None, | |
| "place": _clean(place), | |
| "actor": _clean(actor), | |
| "method": None, | |
| "source_ref": "inferred:sentence" | |
| } | |
| def infer_events_from_sentences(sentences: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
| out: List[Dict[str, Any]] = [] | |
| for s in sentences: | |
| ev = _infer_from_sentence(s.get("sentence", "")) | |
| if ev and (ev.get("actor") or ev.get("place")): | |
| ev["seq"] = s.get("seq") | |
| out.append(ev) | |
| # Deduplicate (actor+place+event_type+date_from) | |
| seen = set() | |
| uniq = [] | |
| for e in out: | |
| key = (e.get("actor"), e.get("place"), e.get("event_type"), e.get("date_from")) | |
| if key in seen: | |
| continue | |
| seen.add(key) | |
| uniq.append(e) | |
| return uniq | |
| # ββ OPTIONAL: simple geocode cache for map pins βββββββββββββββββββββββββββββββ | |
| def geocode_place_cached(place: str): | |
| """Cache in DB: places_cache(place TEXT PRIMARY KEY, lat DOUBLE, lon DOUBLE, updated_at TIMESTAMP)""" | |
| if not place: | |
| return None | |
| with cursor() as cur: | |
| cur.execute("CREATE TABLE IF NOT EXISTS places_cache (place VARCHAR(255) PRIMARY KEY, lat DOUBLE, lon DOUBLE, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)") | |
| cur.execute("SELECT lat, lon FROM places_cache WHERE place=%s", (place,)) | |
| row = cur.fetchone() | |
| if row and row.get("lat") is not None and row.get("lon") is not None: | |
| return row | |
| # Try Nominatim (best effort). If outbound HTTP is blocked, just skip. | |
| try: | |
| r = requests.get( | |
| "https://nominatim.openstreetmap.org/search", | |
| params={"q": place, "format": "json", "limit": 1}, | |
| headers={"User-Agent": "provenance-radar/1.0"}, | |
| timeout=6, | |
| ) | |
| j = r.json() | |
| if j: | |
| lat, lon = float(j[0]["lat"]), float(j[0]["lon"]) | |
| else: | |
| lat, lon = None, None | |
| except Exception: | |
| lat, lon = None, None | |
| with cursor() as cur: | |
| cur.execute( | |
| "INSERT INTO places_cache (place, lat, lon) VALUES (%s,%s,%s) ON DUPLICATE KEY UPDATE lat=VALUES(lat), lon=VALUES(lon), updated_at=CURRENT_TIMESTAMP", | |
| (place, lat, lon), | |
| ) | |
| if lat is None or lon is None: | |
| return None | |
| return {"lat": lat, "lon": lon} | |
| def _policy_hits_for_date(d: str): | |
| """Return policy codes a given ISO date falls into.""" | |
| if not d: | |
| return [] | |
| hits = [] | |
| for w in POLICY_WINDOWS: | |
| start_ok = (d >= w["from"]) if w["from"] else True | |
| end_ok = (d <= w["to"]) if w["to"] else True | |
| if start_ok and end_ok: | |
| hits.append(w["code"]) | |
| return hits | |
| def build_graph_from_events(obj_row, events): | |
| """Cytoscape.js-style graph: nodes+edges.""" | |
| nodes = [] | |
| edges = [] | |
| # center object node | |
| onode = { | |
| "id": f"obj:{obj_row['object_id']}", | |
| "label": f"{obj_row.get('title') or 'Untitled'} ({obj_row.get('source')})", | |
| "type": "object" | |
| } | |
| nodes_map = {onode["id"]: onode} | |
| def add_node(kind, label): | |
| if not label: | |
| return None | |
| nid = f"{kind}:{label}" | |
| if nid not in nodes_map: | |
| nodes_map[nid] = {"id": nid, "label": label, "type": kind} | |
| return nid | |
| for ev in events: | |
| actor = ev.get("actor") | |
| place = ev.get("place") | |
| etype = ev.get("event_type") or "UNKNOWN" | |
| d_iso = to_iso(ev.get("date_from")) | |
| actor_id = add_node("actor", actor) if actor else None | |
| place_id = add_node("place", place) if place else None | |
| # Edge semantics: actor -> object; place is context (not endpoint) | |
| if actor_id: | |
| edges.append({ | |
| "source": actor_id, | |
| "target": onode["id"], | |
| "label": etype, | |
| "date": d_iso, | |
| "weight": 1.0, # client may recompute with risk overlays | |
| "source_ref": ev.get("source_ref"), | |
| "policy": _policy_hits_for_date(d_iso) | |
| }) | |
| # Optional: object -> place (to visualize locations) | |
| if place_id and place: | |
| edges.append({ | |
| "source": onode["id"], | |
| "target": place_id, | |
| "label": "LOCATED", | |
| "date": d_iso, | |
| "weight": 0.5, | |
| "source_ref": ev.get("source_ref"), | |
| "policy": _policy_hits_for_date(d_iso) | |
| }) | |
| return {"nodes": list(nodes_map.values()), "edges": edges} | |
| def build_timeline_from_events_and_sentences(events, sentences): | |
| """Simple list items for any timeline widget.""" | |
| items = [] | |
| s_by_seq = {s["seq"]: s["sentence"] for s in sentences} | |
| for ev in events: | |
| start = to_iso(ev.get("date_from")) | |
| end = to_iso(ev.get("date_to")) | |
| title = ev.get("event_type") or "Event" | |
| txt = None | |
| # Try to pull the nearest sentence by seq if present | |
| # (ingest stored seq starting at 0) | |
| for k in (0, 1, 2, 3): | |
| if k in s_by_seq: | |
| txt = s_by_seq[k]; break | |
| items.append({ | |
| "title": title, | |
| "start_date": start, | |
| "end_date": end, | |
| "text": txt or "", | |
| "source_ref": ev.get("source_ref") | |
| }) | |
| return items | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ROUTES | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def root(): | |
| return jsonify({"ok": True, "service": "provenance-radar-api", "device": _DEVICE_INFO}) | |
| def health(): | |
| try: | |
| start_time = time.time() | |
| with cursor() as cur: | |
| cur.execute("SELECT COUNT(*) AS c FROM objects"); objects = cur.fetchone()["c"] | |
| cur.execute("SELECT COUNT(*) AS c FROM provenance_sentences"); sentences = cur.fetchone()["c"] | |
| cur.execute("SELECT COUNT(*) AS c FROM risk_signals"); risks = cur.fetchone()["c"] | |
| db_latency = round((time.time() - start_time) * 1000, 2) | |
| return jsonify({ | |
| "ok": True, | |
| "device": _DEVICE_INFO, | |
| "db_latency_ms": db_latency, | |
| "counts": { | |
| "objects": objects, | |
| "sentences": sentences, | |
| "risk_signals": risks | |
| } | |
| }) | |
| except Exception as e: | |
| log.exception("health failed") | |
| return jsonify({ | |
| "ok": False, | |
| "error": str(e), | |
| "db_status": "unavailable" | |
| }), 503 | |
| def policy_windows(): | |
| return jsonify({"ok": True, "windows": POLICY_WINDOWS}) | |
| def get_leads(): | |
| limit = max(1, min(int(request.args.get("limit", 50)), 200)) | |
| min_score = float(request.args.get("min_score", 0)) | |
| source = request.args.get("source") | |
| sql = ( | |
| "SELECT object_id, source, title, creator, risk_score, top_signals " | |
| "FROM flagged_leads WHERE risk_score >= %s " | |
| ) | |
| args = [min_score] | |
| if source: | |
| sql += " AND source = %s " | |
| args.append(source) | |
| sql += " LIMIT %s" | |
| args.append(limit) | |
| with cursor() as cur: | |
| cur.execute(sql, args) | |
| rows = cur.fetchall() | |
| for r in rows: | |
| _apply_normalized_risk_inplace(r) | |
| log.info("[RISK] /api/leads called | fetched=%s limit=%s min_score=%s source=%s", | |
| len(rows), limit, min_score, source or "ALL") | |
| for i, r in enumerate(rows[:5], start=1): | |
| raw_ratio = _to_float(r.get("risk_score_raw")) | |
| raw_pct = None if raw_ratio is None else round(raw_ratio * 100.0, 2) | |
| norm_ratio= _to_float(r.get("risk_score")) # 0β1 | |
| norm_pct = None if norm_ratio is None else round(norm_ratio * 100.0) # shown by UI | |
| log.info( | |
| "[RISK] lead %d/%d | object_id=%s | title=%s | raw_ratio=%.3f | raw_pct=%s | norm_ratio=%.3f | norm_pctβ%s%%", | |
| i, min(5, len(rows)), | |
| r.get("object_id"), | |
| (r.get("title") or "")[:80], | |
| raw_ratio if raw_ratio is not None else -1.0, | |
| f"{raw_pct:.0f}" if raw_pct is not None else "NA", | |
| norm_ratio if norm_ratio is not None else -1.0, | |
| f"{norm_pct:.0f}" if norm_pct is not None else "NA", | |
| ) | |
| resp = jsonify({"ok": True, "data": rows}) | |
| resp.headers["Cache-Control"] = "no-store, max-age=0" | |
| return resp | |
| def object_detail(object_id: int): | |
| with cursor() as cur: | |
| cur.execute("SELECT *, image_url FROM objects WHERE object_id=%s", (object_id,)) | |
| obj = cur.fetchone() | |
| if not obj: | |
| return jsonify({"ok": False, "error": "not_found"}), 404 | |
| # --- Normalize + overwrite the field the client reads (0..1) ----------- | |
| raw_ratio = _to_float(obj.get("risk_score")) # e.g., 2.0 = 200% | |
| norm_ratio = normalize_risk(raw_ratio) if raw_ratio is not None else None # 0..1 | |
| norm_0_99 = None if norm_ratio is None else round(norm_ratio * 100.0, 2) # reference | |
| obj["risk_score_raw"] = raw_ratio | |
| obj["risk_score_norm_0_99"] = norm_0_99 | |
| obj["risk_score"] = norm_ratio # what the UI already reads | |
| obj["risk_score_normalized"]= norm_ratio # alias | |
| # --- Log one line per object fetch (visible on HF console) ------------- | |
| log.info( | |
| "[RISK] /api/object | object_id=%s | raw_ratio=%s | raw_pct=%s | norm_ratio=%s | norm_pctβ%s%%", | |
| object_id, | |
| f"{raw_ratio:.3f}" if raw_ratio is not None else "NA", | |
| f"{raw_ratio*100:.0f}" if raw_ratio is not None else "NA", | |
| f"{norm_ratio:.3f}" if norm_ratio is not None else "NA", | |
| f"{norm_ratio*100:.0f}" if norm_ratio is not None else "NA", | |
| ) | |
| # ----------------------------------------------------------------------- | |
| cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,)) | |
| sents = cur.fetchall() | |
| cur.execute("""SELECT event_type, date_from, date_to, place, actor, method, source_ref | |
| FROM provenance_events WHERE object_id=%s | |
| ORDER BY COALESCE(date_from,'0001-01-01')""", (object_id,)) | |
| events = cur.fetchall() | |
| cur.execute("SELECT code, detail, weight FROM risk_signals WHERE object_id=%s ORDER BY weight DESC", (object_id,)) | |
| risks = cur.fetchall() | |
| resp = jsonify({"ok": True, "object": obj, "sentences": sents, "events": events, "risks": risks}) | |
| resp.headers["Cache-Control"] = "no-store, max-age=0" | |
| return resp | |
| def graph(object_id: int): | |
| with cursor() as cur: | |
| cur.execute("SELECT object_id, source, title FROM objects WHERE object_id=%s", (object_id,)) | |
| obj = cur.fetchone() | |
| if not obj: | |
| return jsonify({"ok": False, "error": "not_found"}), 404 | |
| cur.execute("""SELECT event_type, date_from, date_to, place, actor, source_ref | |
| FROM provenance_events WHERE object_id=%s | |
| ORDER BY COALESCE(date_from,'0001-01-01')""", (object_id,)) | |
| events = cur.fetchall() | |
| cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,)) | |
| sents = cur.fetchall() | |
| inferred = infer_events_from_sentences(sents) | |
| # Prefer stored events; fill with inferred where stored is thin | |
| merged = list(events) | |
| if not merged or all((not e.get("actor") and not e.get("place")) for e in merged): | |
| merged = inferred | |
| else: | |
| # add inferred items that add missing actor/place for the same year | |
| have = {(e.get("actor"), e.get("place"), e.get("event_type"), to_iso(e.get("date_from"))): True for e in merged} | |
| for e in inferred: | |
| key = (e.get("actor"), e.get("place"), e.get("event_type"), to_iso(e.get("date_from"))) | |
| if key not in have: | |
| merged.append(e) | |
| g = build_graph_from_events(obj, merged) | |
| # NEW: link successive actors to show chain of custody | |
| actors_in_time = [ (to_iso(e.get("date_from")) or "0001-01-01", e.get("actor")) for e in merged if e.get("actor") ] | |
| actors_in_time.sort(key=lambda x: x[0]) | |
| for i in range(len(actors_in_time) - 1): | |
| a1 = actors_in_time[i][1]; a2 = actors_in_time[i+1][1] | |
| if a1 and a2 and a1 != a2: | |
| g["edges"].append({ | |
| "source": f"actor:{a1}", | |
| "target": f"actor:{a2}", | |
| "label": "TRANSFER", | |
| "date": actors_in_time[i+1][0], | |
| "weight": 0.8, | |
| "policy": _policy_hits_for_date(actors_in_time[i+1][0]), | |
| "source_ref": "link:sequence" | |
| }) | |
| return jsonify({"ok": True, **g}) | |
| def places(object_id: int): | |
| with cursor() as cur: | |
| cur.execute("""SELECT place, date_from FROM provenance_events WHERE object_id=%s""", (object_id,)) | |
| ev = cur.fetchall() | |
| cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,)) | |
| sents = cur.fetchall() | |
| inferred = infer_events_from_sentences(sents) | |
| all_places = [] | |
| for e in ev + inferred: | |
| p = _clean(e.get("place")) | |
| if p: | |
| all_places.append({"place": p, "date": to_iso(e.get("date_from"))}) | |
| # unique by place, keep earliest date | |
| agg = {} | |
| for r in all_places: | |
| d = r["date"] or "9999-12-31" | |
| if r["place"] not in agg or d < (agg[r["place"]].get("date") or "9999-12-31"): | |
| agg[r["place"]] = r | |
| out = [] | |
| for p, info in agg.items(): | |
| geo = geocode_place_cached(p) # may be None if geocoding blocked | |
| out.append({"place": p, "date": info.get("date"), "lat": (geo or {}).get("lat"), "lon": (geo or {}).get("lon")}) | |
| # order chronologically for path drawing | |
| out.sort(key=lambda x: x.get("date") or "9999-12-31") | |
| return jsonify({"ok": True, "places": out}) | |
| def timeline(object_id: int): | |
| with cursor() as cur: | |
| cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,)) | |
| sents = cur.fetchall() | |
| cur.execute("""SELECT event_type, date_from, date_to, place, actor, source_ref | |
| FROM provenance_events WHERE object_id=%s | |
| ORDER BY COALESCE(date_from,'0001-01-01')""", (object_id,)) | |
| events = cur.fetchall() | |
| items = build_timeline_from_events_and_sentences(events, sents) | |
| return jsonify({"ok": True, "items": items}) | |
| def keyword_search(): | |
| q = (request.args.get("q") or "").strip() | |
| limit = max(1, min(int(request.args.get("limit", 50)), 200)) | |
| if not q: | |
| return jsonify({"ok": False, "error": "q required"}), 400 | |
| like = "%" + q.replace("%","").replace("_","") + "%" | |
| with cursor() as cur: | |
| cur.execute( | |
| """SELECT ps.object_id, ps.seq, ps.sentence, o.source, o.title, o.creator | |
| FROM provenance_sentences ps | |
| JOIN objects o ON o.object_id = ps.object_id | |
| WHERE ps.sentence LIKE %s | |
| LIMIT %s""", (like, limit) | |
| ) | |
| rows = cur.fetchall() | |
| return jsonify({"ok": True, "query": q, "data": rows}) | |
| def similar_search(): | |
| payload = request.get_json(force=True) or {} | |
| text = (payload.get("text") or "").strip() | |
| limit = max(1, min(int(payload.get("limit", 20)), 100)) | |
| candidates = int(payload.get("candidates", max(200, limit * 10))) # pre-topK by sentences | |
| source_filter = (payload.get("source") or "").strip().upper() # e.g., "AIC" | |
| if not text: | |
| return jsonify({"ok": False, "error": "text required"}), 400 | |
| # Embed (existing logic) | |
| try: | |
| import torch | |
| vec_t = _load_model().encode([text], batch_size=1, show_progress_bar=False, convert_to_tensor=True) | |
| vec = (vec_t[0].detach().cpu().tolist() if isinstance(vec_t, torch.Tensor) else list(vec_t[0])) | |
| except Exception as e: | |
| return jsonify({"ok": False, "error": f"embedding_unavailable: {e}"}), 503 | |
| vec_json = json.dumps(_pad(vec, VEC_DIM)) | |
| where_src = "WHERE o.source = %s" if source_filter else "" | |
| # --- IMPORTANT: dedupe by object_id using window function ----------------- | |
| # We pull top 'candidates' sentences, join to objects (apply optional source), | |
| # then keep only ROW_NUMBER() = 1 per object_id (best/closest sentence). | |
| sql = f""" | |
| WITH nn AS ( | |
| SELECT /*+ USE_INDEX(ps, hnsw_vec) */ | |
| ps.sent_id, ps.object_id, ps.seq, ps.sentence, | |
| VEC_COSINE_DISTANCE(ps.embedding, CAST(%s AS VECTOR({VEC_DIM}))) AS distance | |
| FROM provenance_sentences ps | |
| ORDER BY distance | |
| LIMIT %s | |
| ), | |
| ranked AS ( | |
| SELECT | |
| nn.object_id, | |
| nn.seq, | |
| nn.sentence, | |
| nn.distance, | |
| o.source, | |
| o.title, | |
| o.creator, | |
| ROW_NUMBER() OVER (PARTITION BY nn.object_id ORDER BY nn.distance ASC) AS rk | |
| FROM nn | |
| JOIN objects o ON o.object_id = nn.object_id | |
| {where_src} | |
| ) | |
| SELECT object_id, seq, sentence, source, title, creator, distance | |
| FROM ranked | |
| WHERE rk = 1 | |
| ORDER BY distance | |
| LIMIT %s | |
| """ | |
| params = [vec_json, candidates] | |
| if source_filter: | |
| params.append(source_filter) | |
| params.append(limit) | |
| try: | |
| with cursor() as cur: | |
| cur.execute(sql, params) | |
| rows = cur.fetchall() | |
| return jsonify({ | |
| "ok": True, | |
| "device": _DEVICE_INFO, | |
| "query": text, | |
| "data": rows, | |
| "meta": {"limit": limit, "candidates": candidates, "source": source_filter or None} | |
| }) | |
| except OperationalError as e: | |
| # TiDB OOM (1105) β retry with smaller candidate set | |
| if e.args and e.args[0] == 1105 and candidates > max(100, limit * 4): | |
| smaller = max(100, limit * 4) | |
| params2 = [vec_json, smaller] | |
| if source_filter: | |
| params2.append(source_filter) | |
| params2.append(limit) | |
| try: | |
| with cursor() as cur: | |
| cur.execute(sql, params2) | |
| rows = cur.fetchall() | |
| return jsonify({ | |
| "ok": True, | |
| "device": _DEVICE_INFO, | |
| "query": text, | |
| "data": rows, | |
| "meta": {"limit": limit, "candidates": smaller, "source": source_filter or None, | |
| "note": "retried with smaller candidate set"} | |
| }) | |
| except Exception as e2: | |
| return jsonify({"ok": False, "error": f"oom_retry_failed: {e2}"}), 500 | |
| # Not OOM or still failed β fall back to Python-side dedupe below | |
| # (This keeps you resilient if window functions act up.) | |
| try: | |
| # Simple fallback: same as your original query, dedupe in Python. | |
| where_src2 = "WHERE o.source = %s" if source_filter else "" | |
| sql2 = f""" | |
| WITH nn AS ( | |
| SELECT ps.sent_id, ps.object_id, ps.seq, ps.sentence, | |
| VEC_COSINE_DISTANCE(ps.embedding, CAST(%s AS VECTOR({VEC_DIM}))) AS distance | |
| FROM provenance_sentences ps | |
| ORDER BY distance | |
| LIMIT %s | |
| ) | |
| SELECT nn.object_id, nn.seq, nn.sentence, o.source, o.title, o.creator, nn.distance | |
| FROM nn | |
| JOIN objects o ON o.object_id = nn.object_id | |
| {where_src2} | |
| ORDER BY nn.distance | |
| LIMIT %s | |
| """ | |
| params2 = [vec_json, candidates] | |
| if source_filter: | |
| params2.append(source_filter) | |
| params2.append(limit * 5) # grab extra to allow dedupe | |
| with cursor() as cur: | |
| cur.execute(sql2, params2) | |
| many = cur.fetchall() | |
| # Python dedupe: keep first (closest) row per object_id | |
| seen = set() | |
| out = [] | |
| for r in many: | |
| oid = r.get("object_id") | |
| if oid in seen: | |
| continue | |
| seen.add(oid) | |
| out.append(r) | |
| if len(out) >= limit: | |
| break | |
| return jsonify({ | |
| "ok": True, | |
| "device": _DEVICE_INFO, | |
| "query": text, | |
| "data": out, | |
| "meta": {"limit": limit, "candidates": candidates, "source": source_filter or None, | |
| "note": "python-dedup fallback"} | |
| }) | |
| except Exception as e3: | |
| return jsonify({"ok": False, "error": f"query_failed: {e} (fallback: {e3})"}), 500 | |
| def vocab(): | |
| field = (request.args.get("field") or "").strip().lower() | |
| limit = max(1, min(int(request.args.get("limit", 100)), 500)) | |
| if field not in {"actor", "place", "source", "culture"}: | |
| return jsonify({"ok": False, "error": "field must be one of actor|place|source|culture"}), 400 | |
| if field in {"actor", "place"}: | |
| sql = f"SELECT {field} AS v, COUNT(*) AS n FROM provenance_events WHERE {field} IS NOT NULL AND {field}<>'' GROUP BY {field} ORDER BY n DESC LIMIT %s" | |
| elif field == "source": | |
| sql = "SELECT source AS v, COUNT(*) AS n FROM objects GROUP BY source ORDER BY n DESC LIMIT %s" | |
| else: # culture | |
| sql = "SELECT culture AS v, COUNT(*) AS n FROM objects WHERE culture IS NOT NULL AND culture<>'' GROUP BY culture ORDER BY n DESC LIMIT %s" | |
| with cursor() as cur: | |
| cur.execute(sql, (limit,)) | |
| rows = cur.fetchall() | |
| return jsonify({"ok": True, "field": field, "data": rows}) | |
| # ββ Gemini-powered explanations ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def explain_object(object_id: int): | |
| """Generate a concise, policy-aware research note for an object.""" | |
| with cursor() as cur: | |
| cur.execute("SELECT object_id, source, title, creator, date_display, risk_score FROM objects WHERE object_id=%s", (object_id,)) | |
| obj = cur.fetchone() | |
| if not obj: | |
| return jsonify({"ok": False, "error": "not_found"}), 404 | |
| cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,)) | |
| sents = cur.fetchall() | |
| cur.execute("SELECT event_type, date_from, date_to, place, actor, source_ref FROM provenance_events WHERE object_id=%s ORDER BY COALESCE(date_from,'0001-01-01')", (object_id,)) | |
| events = cur.fetchall() | |
| # Build a compact prompt (few sentences) to keep latency low | |
| bullets = [] | |
| for s in sents[:8]: # keep prompt small | |
| bullets.append(f"- {s['sentence']}") | |
| evsumm = [] | |
| for e in events[:8]: | |
| evsumm.append(f"{e.get('event_type')} @ {e.get('place') or 'β'} on {e.get('date_from') or 'β'} (actor: {e.get('actor') or 'β'})") | |
| sys = ("You are assisting provenance researchers. Write a neutral, concise brief (120β180 words) that:\n" | |
| "1) summarizes the chain of custody in plain language; 2) clearly marks any timeline gaps; " | |
| "3) calls out potential red flags (e.g., confiscated/looted, sales during 1933β45, exports post-1970) " | |
| "without making legal conclusions; 4) ends with a short 'Next leads' list (max 3).") | |
| prompt = ( | |
| f"Object: {obj.get('title') or 'Untitled'} β {obj.get('creator') or ''} (source {obj['source']}). " | |
| f"Display date: {obj.get('date_display') or 'n/a'}. Current risk_score={obj.get('risk_score', 0)}.\n\n" | |
| f"Provenance sentences:\n" + "\n".join(bullets) + "\n\n" | |
| f"Structured events (first 8):\n- " + "\n- ".join(evsumm) + "\n\n" | |
| f"Policy windows to consider: Nazi era 1933β1945; UNESCO 1970 onwards." | |
| ) | |
| text = gemini_explain(prompt, sys=sys) | |
| return jsonify({"ok": True, "model": EXPLAIN_MODEL, "note": text}) | |
| def explain_text(): | |
| """Explain a specific provenance sentence or user query with policy context.""" | |
| payload = request.get_json(force=True) or {} | |
| sentence = (payload.get("text") or "").strip() | |
| if not sentence: | |
| return jsonify({"ok": False, "error": "text required"}), 400 | |
| sys = ("Explain this text as a provenance note for curators. " | |
| "Be precise and cautious; highlight possible red flags tied to 1933β1945 and post-1970 export rules.") | |
| prompt = f"""Explain and contextualize this provenance fragment:\n\n{sentence}.""" | |
| text = gemini_explain(prompt, sys=sys) | |
| return jsonify({"ok": True, "model": EXPLAIN_MODEL, "explanation": text}) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MAIN (Spaces expects 7860) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| port = int(os.environ.get("PORT", "7860")) | |
| app.run(host="0.0.0.0", port=port, debug=False) | |