rairo committed on
Commit
b24840f
Β·
verified Β·
1 Parent(s): 96ba7ab

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +972 -208
main.py CHANGED
@@ -1,251 +1,1015 @@
1
- import os
2
- import io
3
- import logging
4
- import re
5
- import pandas as pd
6
- import pdfplumber
7
  from flask import Flask, request, jsonify
8
  from flask_cors import CORS
9
- from flask_sqlalchemy import SQLAlchemy
10
- from sqlalchemy.exc import IntegrityError
11
- from thefuzz import process, fuzz
12
- from werkzeug.utils import secure_filename
 
 
 
 
13
 
 
 
 
14
  # ───────────────────────────────────────────────────────────────────────────────
15
- # CONFIGURATION
16
  # ───────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
 
 
 
18
  logging.basicConfig(level=logging.INFO)
19
- log = logging.getLogger("product-pipeline-api")
20
 
21
  app = Flask(__name__)
22
  CORS(app)
23
 
24
- # --- App Configuration ---
25
- app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite:///:memory:'
26
- app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
27
- app.config['UPLOAD_FOLDER'] = 'uploads'
28
- os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
29
-
30
- # --- File Upload Configuration ---
31
- ALLOWED_EXTENSIONS = {'csv', 'xls', 'xlsx'}
32
-
33
- db = SQLAlchemy(app)
34
-
35
  # ───────────────────────────────────────────────────────────────────────────────
36
- # DATABASE MODEL (Based on products-20.sql)
37
  # ───────────────────────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
- class Product(db.Model):
40
- """Represents the 'products' table."""
41
- __tablename__ = 'products'
42
- id = db.Column(db.Integer, primary_key=True)
43
- name = db.Column(db.String(255), nullable=False, unique=True)
44
- category_id = db.Column(db.Integer, nullable=False, default=1)
45
- primary_category = db.Column(db.String(255), nullable=False, default='N/A')
46
- hs_code = db.Column(db.String(255), nullable=True)
47
-
48
- def to_dict(self):
49
- """Serializes the Product object to a dictionary."""
50
- return {
51
- 'id': self.id,
52
- 'name': self.name,
53
- 'category_id': self.category_id,
54
- 'primary_category': self.primary_category,
55
- 'hs_code': self.hs_code
56
- }
57
-
58
- def __repr__(self):
59
- return f'<Product {self.id}: {self.name}>'
60
 
61
  # ───────────────────────────────────────────────────────────────────────────────
62
- # DATA LOADING & PRE-PROCESSING
63
  # ───────────────────────────────────────────────────────────────────────────────
 
 
64
 
65
- FUZZY_MATCH_THRESHOLD = 85
66
- HS_CODES_DATA = []
67
- EXISTING_PRODUCT_NAMES = []
68
- HS_CODE_DESCRIPTIONS = {}
69
 
70
- def parse_hs_codes_pdf(filepath='HS Codes for use under FDMS.pdf'):
71
- log.info(f"Parsing HS Codes from '{filepath}'...")
72
- if not os.path.exists(filepath):
73
- log.error(f"HS Code PDF not found at '{filepath}'. Categorization will fail.")
74
- return []
75
- codes = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  try:
77
- with pdfplumber.open(filepath) as pdf:
78
- for page in pdf.pages:
79
- text = page.extract_text()
80
- matches = re.findall(r'\"(\d+)\n\"\,?\"(.*?)\n\"', text, re.DOTALL)
81
- for code, desc in matches:
82
- clean_desc = desc.replace('\n', ' ').strip()
83
- if code and clean_desc:
84
- codes.append({'code': code, 'description': clean_desc})
85
- HS_CODE_DESCRIPTIONS[clean_desc] = code
86
- except Exception as e:
87
- log.error(f"Failed to parse PDF: {e}")
88
- log.info(f"Successfully parsed {len(codes)} HS codes.")
89
- return codes
90
-
91
- def load_existing_products(filepath='Product List.csv'):
92
- log.info(f"Loading master product list from '{filepath}'...")
93
- if not os.path.exists(filepath):
94
- log.error(f"Master product list not found at '{filepath}'. Validation may be inaccurate.")
95
- return []
96
  try:
97
- df = pd.read_csv(filepath)
98
- product_names = df['name'].dropna().unique().tolist()
99
- log.info(f"Loaded {len(product_names)} unique existing products.")
100
- return product_names
101
  except Exception as e:
102
- log.error(f"Failed to load master product list: {e}")
103
- return []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
 
105
  # ───────────────────────────────────────────────────────────────────────────────
106
- # CORE PROCESSING PIPELINE
107
  # ───────────────────────────────────────────────────────────────────────────────
108
-
109
- def process_uploaded_file(filepath, filename):
110
- """The main pipeline to validate, clean, categorize, and store product data."""
111
- log.info(f"Starting processing for file: {filepath}")
112
- results = {
113
- "processed": 0, "added": 0, "updated": 0, "skipped_duplicates": 0,
114
- "errors": [], "processed_data": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  }
116
- df = None
117
 
118
- try:
119
- # --- Read file based on extension ---
120
- file_ext = filename.rsplit('.', 1)[1].lower()
121
- if file_ext == 'csv':
122
- df = pd.read_csv(filepath, header=None)
123
- elif file_ext in ['xls', 'xlsx']:
124
- # engine='openpyxl' is needed for .xlsx files
125
- df = pd.read_excel(filepath, header=None, engine='openpyxl')
126
- except Exception as e:
127
- log.error(f"Could not read the uploaded file: {e}")
128
- results['errors'].append(f"Invalid file format or corrupt file: {e}")
129
- return results
130
-
131
- if df.empty:
132
- results['errors'].append("The uploaded file is empty.")
133
- return results
134
-
135
- # Heuristically find the column with product names
136
- product_name_col = None
137
- for col in df.columns:
138
- if df[col].dtype == 'object' and df[col].astype(str).str.contains('[a-zA-Z]').any():
139
- product_name_col = col
140
- break
141
-
142
- if product_name_col is None:
143
- results['errors'].append("Could not find a column with product names in the uploaded file.")
144
- return results
145
-
146
- for index, row in df.iterrows():
147
- raw_name = row[product_name_col]
148
- results['processed'] += 1
149
-
150
- if not isinstance(raw_name, str) or not raw_name.strip():
151
  continue
152
-
153
- # --- 1. Validation & Cleaning ---
154
- best_match, score = process.extractOne(
155
- raw_name, EXISTING_PRODUCT_NAMES, scorer=fuzz.token_sort_ratio
156
- ) if EXISTING_PRODUCT_NAMES else (raw_name, 100)
157
- cleaned_name = best_match if score >= FUZZY_MATCH_THRESHOLD else raw_name
158
-
159
- # --- 2. HS Code Categorization ---
160
- best_hs_desc, _ = process.extractOne(
161
- cleaned_name, HS_CODE_DESCRIPTIONS.keys()
162
- ) if HS_CODE_DESCRIPTIONS else (None, 0)
163
- hs_code = HS_CODE_DESCRIPTIONS.get(best_hs_desc)
164
-
165
- # --- 3. Database Operation ---
166
- processed_entry = {
167
- "raw_name": raw_name, "cleaned_name": cleaned_name, "hs_code": hs_code,
168
- "primary_category": best_hs_desc or "N/A", "status": ""
169
- }
170
- try:
171
- existing_product = Product.query.filter_by(name=cleaned_name).first()
172
- if existing_product:
173
- if hs_code and existing_product.hs_code != hs_code:
174
- existing_product.hs_code = hs_code
175
- existing_product.primary_category = best_hs_desc
176
- db.session.commit()
177
- results['updated'] += 1
178
- processed_entry['status'] = 'Updated'
179
- else:
180
- results['skipped_duplicates'] += 1
181
- processed_entry['status'] = 'Skipped (Duplicate)'
182
- else:
183
- new_product = Product(name=cleaned_name, hs_code=hs_code, primary_category=best_hs_desc or 'N/A')
184
- db.session.add(new_product)
185
- db.session.commit()
186
- results['added'] += 1
187
- processed_entry['status'] = 'Added'
188
- results['processed_data'].append(processed_entry)
189
- except Exception as e:
190
- db.session.rollback()
191
- log.error(f"Database error for '{cleaned_name}': {e}")
192
- results['errors'].append(f"DB Error on '{cleaned_name}': {e}")
193
-
194
- return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
195
 
196
  # ───────────────────────────────────────────────────────────────────────────────
197
  # ROUTES
198
  # ───────────────────────────────────────────────────────────────────────────────
199
 
200
- def allowed_file(filename):
201
- """Checks if the file's extension is allowed."""
202
- return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
203
-
204
  @app.get("/")
205
  def root():
206
- return jsonify({"ok": True, "message": "The Product Validation server is running."})
207
-
208
- @app.post("/api/upload")
209
- def upload_products():
210
- """Endpoint to upload and process a product CSV or Excel file."""
211
- if 'file' not in request.files:
212
- return jsonify({"ok": False, "error": "No file part in the request"}), 400
213
- file = request.files['file']
214
- if file.filename == '':
215
- return jsonify({"ok": False, "error": "No file selected"}), 400
216
-
217
- if file and allowed_file(file.filename):
218
- filename = secure_filename(file.filename)
219
- filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
220
- file.save(filepath)
221
- results = process_uploaded_file(filepath, filename)
222
- return jsonify({"ok": True, "message": "File processed successfully", "results": results})
223
-
224
- return jsonify({"ok": False, "error": f"Invalid file type. Allowed types are: {', '.join(ALLOWED_EXTENSIONS)}"}), 400
225
-
226
- @app.get("/api/products")
227
- def get_products():
228
- """Endpoint to retrieve all processed products from the database."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  try:
230
- all_products = Product.query.all()
231
- products_list = [product.to_dict() for product in all_products]
232
- return jsonify({"ok": True, "count": len(products_list), "products": products_list})
233
  except Exception as e:
234
- log.error(f"Could not retrieve products from database: {e}")
235
- return jsonify({"ok": False, "error": "Failed to retrieve products from the database."}), 500
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
  # ───────────────────────────────────────────────────────────────────────────────
238
- # MAIN (Server Initialization)
239
  # ───────────────────────────────────────────────────────────────────────────────
240
-
241
  if __name__ == "__main__":
242
- with app.app_context():
243
- log.info("Initializing server...")
244
- db.create_all()
245
- HS_CODES_DATA = parse_hs_codes_pdf()
246
- EXISTING_PRODUCT_NAMES = load_existing_products()
247
- log.info("Server is ready and validation data is loaded.")
248
-
249
  port = int(os.environ.get("PORT", "7860"))
250
  app.run(host="0.0.0.0", port=port, debug=False)
251
-
 
1
+ import os, json, logging, warnings, time, certifi, pymysql, requests
2
+ from contextlib import contextmanager
3
+ from datetime import date
 
 
 
4
  from flask import Flask, request, jsonify
5
  from flask_cors import CORS
6
+ from datetime import date, datetime
7
+ # ---- Optional Google GenAI (Gemini) ----
8
+ from google import genai
9
+ from google.genai import types
10
+
11
+ from pymysql.err import OperationalError
12
+ import threading
13
+ warnings.filterwarnings("ignore")
14
 
15
+ # ── NEW: lightweight event inference from sentences ───────────────────────────
16
+ import re
17
+ from typing import List, Dict, Any, Optional
18
  # ───────────────────────────────────────────────────────────────────────────────
19
+ # CONFIG
20
  # ───────────────────────────────────────────────────────────────────────────────
21
+ DB_NAME = os.getenv("TIDB_DB")
22
+ TIDB_HOST = os.getenv("TIDB_HOST")
23
+ TIDB_PORT = int(os.getenv("TIDB_PORT"))
24
+ TIDB_USER = os.getenv("TIDB_USER")
25
+ TIDB_PASS = os.getenv("TIDB_PASS")
26
+
27
+ VEC_DIM = int(os.getenv("VEC_DIM", "1536"))
28
+ EMBED_MODEL = os.getenv("EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
29
+ USE_GPU = os.getenv("USE_GPU", "0") == "1" # Spaces are usually CPU; works either way
30
+
31
+ # Policy windows (server is single source of truth for the client)
32
+ POLICY_WINDOWS = [
33
+ {
34
+ "code": "NAZI_ERA",
35
+ "label": "Washington Conference Principles (1933–1945)",
36
+ "from": "1933-01-01",
37
+ "to": "1945-12-31",
38
+ "ref": "https://www.state.gov/washington-conference-principles-on-nazi-confiscated-art"
39
+ },
40
+ {
41
+ "code": "UNESCO_1970",
42
+ "label": "UNESCO 1970 Convention",
43
+ "from": "1970-11-14",
44
+ "to": None,
45
+ "ref": "https://www.unesco.org/en/legal-affairs/convention-means-prohibiting-and-preventing-illicit-import-export-and-transfer-ownership-cultural"
46
+ }
47
+ ]
48
 
49
+ # ───────────────────────────────────────────────────────────────────────────────
50
+ # APP + LOGGING
51
+ # ───────────────────────────────────────────────────────────────────────────────
52
  logging.basicConfig(level=logging.INFO)
53
+ log = logging.getLogger("provenance-api")
54
 
55
  app = Flask(__name__)
56
  CORS(app)
57
 
 
 
 
 
 
 
 
 
 
 
 
58
  # ───────────────────────────────────────────────────────────────────────────────
59
+ # DB CONNECTION (refactored for better connection management)
60
  # ───────────────────────────────────────────────────────────────────────────────
61
+ _connection_lock = threading.Lock()
62
+
63
def _create_connection():
    """Create a new database connection with optimized settings"""
    # One fresh, TLS-verified PyMySQL connection to TiDB per call (no pooling
    # here — pooling/retry is handled by the cursor() context manager).
    # autocommit=True makes every statement its own transaction.
    return pymysql.connect(
        host=TIDB_HOST,
        port=TIDB_PORT,
        user=TIDB_USER,
        password=TIDB_PASS,
        database=DB_NAME,
        ssl={"ca": certifi.where()},             # certifi CA bundle for TLS verification
        ssl_verify_cert=True,
        ssl_verify_identity=True,
        autocommit=True,
        charset="utf8mb4",
        cursorclass=pymysql.cursors.DictCursor,  # rows come back as dicts
        connect_timeout=10,
        read_timeout=60,   # Increased for vector operations
        write_timeout=30,
        # TiDB-specific optimizations:
        init_command="SET SESSION sql_mode='STRICT_TRANS_TABLES,NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO'",
        client_flag=pymysql.constants.CLIENT.MULTI_STATEMENTS,
    )
84
+
85
@contextmanager
def cursor():
    """Yield a DictCursor on a fresh connection, retrying transient failures.

    A new connection is opened per context (no pooling). Only connection
    *establishment* is retried (up to 3 attempts with linear backoff).

    Fix: the previous version retried around the ``yield`` itself, so a
    connection error raised while the caller was using the cursor made the
    generator ``yield`` a second time — ``contextlib.contextmanager``
    forbids that and raises ``RuntimeError("generator didn't stop")``.
    Errors raised during use now propagate to the caller (or the
    ``with_db_retry`` decorator), and the connection is always closed.
    """
    max_retries = 3
    conn = None
    for attempt in range(max_retries):
        try:
            conn = _create_connection()
            break
        except (OperationalError, pymysql.err.InternalError) as e:
            if attempt == max_retries - 1:
                log.error(f"Database connection failed after {max_retries} attempts: {e}")
                raise
            log.warning(f"Database connection failed (attempt {attempt + 1}): {e}")
            time.sleep(0.5 * (attempt + 1))  # linear backoff
    try:
        with conn.cursor() as cur:
            yield cur
    finally:
        try:
            conn.close()
        except Exception:
            pass
125
+
126
def with_db_retry(func):
    """Decorator: re-run *func* up to three times when the DB connection drops."""
    import functools

    @functools.wraps(func)  # This preserves the original function's metadata
    def wrapper(*args, **kwargs):
        attempts = 3
        for n in range(1, attempts + 1):
            try:
                return func(*args, **kwargs)
            except (OperationalError, pymysql.err.InternalError) as exc:
                if n == attempts:
                    log.error(f"Database operation failed after {attempts} attempts: {exc}")
                    raise
                log.warning(f"Database operation failed (attempt {n}): {exc}")
                time.sleep(0.5 * n)
    return wrapper
142
 
143
+ # ───────────────────────────────────────────────────────────────────────────────
144
+ # ERROR HANDLERS
145
+ # ───────────────────────────────────────────────────────────────────────────────
146
@app.errorhandler(OperationalError)
def handle_db_error(e):
    """Translate a lost/failed DB connection into a 503 JSON response."""
    log.error(f"Database error: {e}")
    payload = {
        "ok": False,
        "error": "database_unavailable",
        "message": "Database connection issue. Please try again.",
    }
    return jsonify(payload), 503
154
+
155
@app.errorhandler(pymysql.err.InternalError)
def handle_internal_error(e):
    """Translate a DB-internal failure into a 500 JSON response."""
    log.error(f"Database internal error: {e}")
    payload = {
        "ok": False,
        "error": "database_error",
        "message": "Database operation failed. Please try again.",
    }
    return jsonify(payload), 500
 
163
 
164
  # ───────────────────────────────────────────────────────────────────────────────
165
+ # EMBEDDINGS (lazy-load; same model as ingest; pad to 1536)
166
  # ───────────────────────────────────────────────────────────────────────────────
167
+ _MODEL = None
168
+ _DEVICE_INFO = "cpu"
169
 
170
def _pad(vec, dim=VEC_DIM):
    """Truncate or zero-pad *vec* (list of floats) to exactly *dim* entries."""
    head = vec[:dim]
    return head + [0.0] * (dim - len(head))
 
 
172
 
173
def _load_model():
    """Load the sentence-transformers model once and cache it in _MODEL.

    Selects CUDA only when USE_GPU is set AND torch reports an available
    device; any failure while probing torch falls back to CPU.  The chosen
    device name is recorded in the module-global _DEVICE_INFO (also shown
    by the / and /api/health endpoints).
    """
    global _MODEL, _DEVICE_INFO
    if _MODEL is not None:
        return _MODEL
    if USE_GPU:
        try:
            import torch
            if torch.cuda.is_available():
                _DEVICE_INFO = "cuda"
        except Exception:
            _DEVICE_INFO = "cpu"
    # Imported lazily so the app starts fast and works without the package
    # until the first embedding request.
    from sentence_transformers import SentenceTransformer
    _MODEL = SentenceTransformer(EMBED_MODEL, device=_DEVICE_INFO)
    log.info(f"Loaded embedding model on '{_DEVICE_INFO}': {EMBED_MODEL}")
    return _MODEL
188
+
189
+
190
def embed_text_to_vec1536(text: str):
    """Embed *text* and zero-pad/truncate the vector to VEC_DIM floats.

    Must match the embedding used at ingest time (same model, same padding)
    for vector search to be meaningful.
    """
    model = _load_model()
    # Use Torch tensors to avoid NumPy code path entirely
    import torch
    t = model.encode([text], batch_size=1, show_progress_bar=False, convert_to_tensor=True)
    if isinstance(t, torch.Tensor):
        vec = t[0].detach().cpu().tolist()
    else:
        # very defensive fallback
        vec = list(t[0])
    return _pad(vec, VEC_DIM)
201
+
202
+
203
def to_iso(d):
    """Normalise a date/datetime/str (or any str()-able value) to 'YYYY-MM-DD'.

    Returns None for None and for the empty string; all other values are
    reduced to their first ten characters of ISO/text form.
    """
    if d is None:
        return None
    if isinstance(d, (date, datetime)):
        return d.isoformat()[:10]
    if isinstance(d, str):
        return d[:10] or None
    try:  # last-resort coercion for unexpected types
        return str(d)[:10]
    except Exception:
        return None
216
+ # ───────────────────────────────────────────────────────────────────────────────
217
+ # GEMINI (explanations / descriptions)
218
+ # ───────────────────────────────────────────────────────────────────────────────
219
+ GEMINI_KEY = os.environ.get("Gemini")
220
+ _gclient = None
221
+
222
def _gemini():
    """Return a cached Gemini client, or None when unconfigured/unavailable.

    Lazily constructs genai.Client from the 'Gemini' env var (GEMINI_KEY)
    and memoizes it in the module-global _gclient.  Initialization failure
    is logged and treated the same as a missing key: callers get None and
    must degrade gracefully.
    """
    global _gclient
    if _gclient is not None:
        return _gclient
    if not GEMINI_KEY:
        return None
    try:
        _gclient = genai.Client(api_key=GEMINI_KEY)
        log.info("Gemini client initialized.")
        return _gclient
    except Exception as e:
        log.warning(f"Gemini init failed: {e}")
        return None
235
+
236
+ EXPLAIN_MODEL = "gemini-2.0-flash"
237
+
238
def gemini_explain(prompt: str, sys: str = None, model: str = EXPLAIN_MODEL) -> str:
    """Ask Gemini for a short explanation; degrade gracefully without a key.

    Args:
        prompt: the user prompt text.
        sys:    optional system-style preamble, sent as a first chat message.
        model:  Gemini model name (defaults to EXPLAIN_MODEL).

    Returns:
        The model's text, a stub string when Gemini is not configured,
        or "" when the model returns no text.
    """
    g = _gemini()
    if g is None:
        # Graceful fallback so the API still works without a key
        return "(Gemini not configured) " + prompt[:180]
    # chat-style to mirror your original pattern
    chat = g.chats.create(model=model)
    # Add a light system preamble for style/constraints
    if sys:
        chat.send_message(f"[SYSTEM]\n{sys}")
    resp = chat.send_message(prompt)
    # Fix: resp.text can be None (e.g. blocked/empty response), which made
    # the old `.strip()` raise AttributeError.  Coerce to "" first.
    text = getattr(resp, "text", "") or ""
    return text.strip()
250
 
251
  # ───────────────────────────────────────────────────────────────────────────────
252
+ # UTIL: Build risk scores, graph & timeline from events (+ risk overlays)
253
  # ───────────────────────────────────────────────────────────────────────────────
254
+ #
255
+ # Targets:
256
+ # raw 100 -> ~55
257
+ # raw 200 -> ~80
258
+ # raw 2000 -> ~99 (slow approach to 99 beyond this)
259
+ # BLOCK 1 β€” Helpers (drop-in)
260
+ # - Piecewise normalize_risk() curve
261
+ # - _to_float() coercion
262
+ # - _apply_normalized_risk_inplace(): overwrites 'risk_score' and keeps 'risk_score_raw'
263
+
264
+ import math
265
+ from decimal import Decimal
266
+
267
+ def _to_float(x):
268
+ if x is None: return None
269
+ if isinstance(x, (int, float)): return float(x)
270
+ if isinstance(x, Decimal): return float(x)
271
+ if isinstance(x, str):
272
+ try: return float(x.strip().replace("%",""))
273
+ except Exception: return None
274
+ try: return float(x)
275
+ except Exception: return None
276
+
277
+ def _piecewise_0_99_from_percent(pct: float) -> float:
278
+ """Piecewise curve on a 0–99 scale using 'percent' inputs (100, 200, ...)."""
279
+ x = max(float(pct), 0.0)
280
+ if x <= 100.0:
281
+ out = 55.0 * ((x / 100.0) ** 0.7) # ~55 at 100
282
+ elif x <= 200.0:
283
+ out = 55.0 + 25.0 * (((x - 100.0) / 100.0) ** 0.8) # 55β†’80 between 100–200
284
+ else:
285
+ k = math.log(100.0) / 1800.0 # ~98.8 at 2000
286
+ out = 99.0 - 19.0 * math.exp(-k * (x - 200.0))
287
+ return max(0.0, min(out, 99.0))
288
+
289
def normalize_risk(score_ratio: float) -> float:
    """Map a raw risk ratio (1.0 == 100%, 2.0 == 200%) onto the 0–1 UI scale.

    Returns None when the input cannot be coerced to a number.
    """
    ratio = _to_float(score_ratio)
    if ratio is None:
        return None
    # Work in the percent domain for the piecewise mapping, then hand the
    # result back to the UI as a 0–1 fraction.
    mapped_pct = _piecewise_0_99_from_percent(ratio * 100.0)
    return round(mapped_pct / 100.0, 6)
299
+
300
+
301
def _apply_normalized_risk_inplace(row: dict):
    """Overwrite row['risk_score'] with its normalized 0–1 value, in place.

    Preserves the raw ratio in 'risk_score_raw', publishes a 0–99 reference
    in 'risk_score_norm_0_99', and mirrors the normalized value under the
    'risk_score_normalized' alias.  No-op for non-dict rows or rows whose
    score cannot be parsed as a number.
    """
    if not isinstance(row, dict):
        return
    raw_ratio = _to_float(row.get("risk_score"))
    if raw_ratio is None:
        return
    norm_ratio = normalize_risk(raw_ratio)  # 0–1
    norm_0_99 = None if norm_ratio is None else round(norm_ratio * 100.0, 2)

    row["risk_score_raw"] = raw_ratio           # raw ratio (e.g., 2.0)
    row["risk_score_norm_0_99"] = norm_0_99     # 0–99 reference (e.g., 80.0)
    row["risk_score"] = norm_ratio              # **what client already uses** (0–1)
    row["risk_score_normalized"] = norm_ratio   # alias if client checks this too
314
+
315
+
316
+ EVENT_VERBS = {
317
+ "sold": "SOLD",
318
+ "purchased": "PURCHASED",
319
+ "bought": "PURCHASED",
320
+ "acquired": "ACQUIRED",
321
+ "donated": "DONATED",
322
+ "gifted": "DONATED",
323
+ "bequeathed": "BEQUEATHED",
324
+ "consigned": "CONSIGNED",
325
+ "exhibited": "EXHIBITED",
326
+ "exported": "EXPORTED",
327
+ "imported": "IMPORTED",
328
+ }
329
+
330
+ YEAR_RE = re.compile(r"\b(1[6-9]\d{2}|20\d{2})\b") # 1600–2099
331
+
332
+ def _clean(s: Optional[str]) -> Optional[str]:
333
+ if not s: return None
334
+ s = re.sub(r"\s+", " ", s).strip(" ,.;:-–—")
335
+ return s or None
336
+
337
def _infer_from_sentence(txt: str) -> Optional[Dict[str, Any]]:
    """
    Very pragmatic patterns that cover most catalogue phrasing:
      - 'sold to X, <place>, 2000'
      - 'sold to X, by 2000'
      - 'purchased from Y in 1965'
      - 'donated by X, <place>, 1971'
    Returns a dict compatible with provenance_events rows, or None when
    the sentence is empty or contains no known event verb.
    """
    if not txt:
        return None
    low = txt.lower()

    # find verb: first EVENT_VERBS key contained anywhere in the text
    verb = next((EVENT_VERBS[v] for v in EVENT_VERBS if v in low), None)
    if not verb:
        return None

    # pull a year (prefers the last year in the string)
    years = YEAR_RE.findall(txt)
    year = years[-1] if years else None

    actor = None
    place = None

    # Common pattern: 'sold to X, place, 2000'
    m = re.search(r"\b(sold|purchased|bought|acquired|donated|gifted|bequeathed|consigned)\s+(to|by|from)\s+(.*)$", low)
    if m:
        # Take the fragment after 'to/by/from'.
        # NOTE(review): match indices come from `low` but slice `txt`;
        # assumes str.lower() preserves length — true for ASCII, can differ
        # for a few Unicode code points. TODO confirm inputs are ASCII-ish.
        frag = txt[m.end(2)+1:].strip()
        # Trim trailing year or 'by 2000'
        frag = re.sub(r"(,\s*)?(by\s*)?\b(1[6-9]\d{2}|20\d{2})\b.*$", "", frag, flags=re.IGNORECASE).strip(" ,.;")
        # Split on commas: first token is actor; the rest (if any) is place.
        # The lookahead skips commas inside parentheses.
        parts = [p.strip() for p in re.split(r",(?![^()]*\))", frag) if p.strip()]
        if parts:
            actor = parts[0]
            if len(parts) > 1:
                place = ", ".join(parts[1:])

    # Fallback simple 'sold to X' without commas
    if not actor:
        m2 = re.search(r"\bsold\s+to\s+([^,.;]+)", low)
        if m2:
            actor = _clean(txt[m2.start(1):m2.end(1)])

    return {
        "event_type": verb,
        "date_from": f"{year}-01-01" if year else None,  # year-only precision, pinned to Jan 1
        "date_to": None,
        "place": _clean(place),
        "actor": _clean(actor),
        "method": None,
        "source_ref": "inferred:sentence"
    }
 
391
 
392
def infer_events_from_sentences(sentences: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Infer provenance events from sentence rows, deduplicated.

    Keeps only events carrying at least an actor or a place, tags each
    with its originating sentence's seq, then drops repeats of the same
    (actor, place, event_type, date_from) combination, preserving order.
    """
    inferred: List[Dict[str, Any]] = []
    for row in sentences:
        ev = _infer_from_sentence(row.get("sentence", ""))
        if ev is None or not (ev.get("actor") or ev.get("place")):
            continue
        ev["seq"] = row.get("seq")
        inferred.append(ev)

    deduped: List[Dict[str, Any]] = []
    seen = set()
    for ev in inferred:
        key = (ev.get("actor"), ev.get("place"), ev.get("event_type"), ev.get("date_from"))
        if key not in seen:
            seen.add(key)
            deduped.append(ev)
    return deduped
409
+
410
+ # ── OPTIONAL: simple geocode cache for map pins ───────────────────────────────
411
def geocode_place_cached(place: str):
    """Cache in DB: places_cache(place TEXT PRIMARY KEY, lat DOUBLE, lon DOUBLE, updated_at TIMESTAMP)

    Checks the DB cache first; on a miss, best-effort geocodes via the
    public Nominatim API and upserts the result.  Failed lookups are also
    cached (as NULL lat/lon) — NOTE(review): those NULL rows still trigger
    a fresh Nominatim call on the next request, since the cache hit requires
    non-NULL coordinates.  Returns {"lat": ..., "lon": ...} or None.
    """
    if not place:
        return None
    with cursor() as cur:
        # Idempotent table bootstrap on every call
        cur.execute("CREATE TABLE IF NOT EXISTS places_cache (place VARCHAR(255) PRIMARY KEY, lat DOUBLE, lon DOUBLE, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP)")
        cur.execute("SELECT lat, lon FROM places_cache WHERE place=%s", (place,))
        row = cur.fetchone()
        if row and row.get("lat") is not None and row.get("lon") is not None:
            return row

    # Try Nominatim (best effort). If outbound HTTP is blocked, just skip.
    try:
        r = requests.get(
            "https://nominatim.openstreetmap.org/search",
            params={"q": place, "format": "json", "limit": 1},
            headers={"User-Agent": "provenance-radar/1.0"},
            timeout=6,
        )
        j = r.json()
        if j:
            lat, lon = float(j[0]["lat"]), float(j[0]["lon"])
        else:
            lat, lon = None, None
    except Exception:
        lat, lon = None, None

    with cursor() as cur:
        # Upsert so repeated lookups refresh updated_at
        cur.execute(
            "INSERT INTO places_cache (place, lat, lon) VALUES (%s,%s,%s) ON DUPLICATE KEY UPDATE lat=VALUES(lat), lon=VALUES(lon), updated_at=CURRENT_TIMESTAMP",
            (place, lat, lon),
        )
    if lat is None or lon is None:
        return None
    return {"lat": lat, "lon": lon}
446
+
447
def _policy_hits_for_date(d: str):
    """Return the POLICY_WINDOWS codes an ISO date string falls inside."""
    if not d:
        return []
    hits = []
    for win in POLICY_WINDOWS:
        # ISO 'YYYY-MM-DD' strings compare correctly lexicographically;
        # an open-ended bound (falsy 'from'/'to') always matches.
        after_start = not win["from"] or d >= win["from"]
        before_end = not win["to"] or d <= win["to"]
        if after_start and before_end:
            hits.append(win["code"])
    return hits
458
+
459
def build_graph_from_events(obj_row, events):
    """Build a Cytoscape.js-style {nodes, edges} graph for one object.

    The object is the hub node.  Every event contributes an actor→object
    edge labelled with the event type, plus an optional object→place edge
    ("LOCATED") when the event names a place.  Each edge carries the ISO
    date, a default weight, the source ref, and matching policy codes.
    """
    obj_node_id = f"obj:{obj_row['object_id']}"
    nodes_map = {
        obj_node_id: {
            "id": obj_node_id,
            "label": f"{obj_row.get('title') or 'Untitled'} ({obj_row.get('source')})",
            "type": "object",
        }
    }
    edges = []

    def ensure_node(kind, label):
        # Register the node on first sight; return its id (None for no label).
        if not label:
            return None
        node_id = f"{kind}:{label}"
        nodes_map.setdefault(node_id, {"id": node_id, "label": label, "type": kind})
        return node_id

    for ev in events:
        when = to_iso(ev.get("date_from"))
        ref = ev.get("source_ref")

        # Edge semantics: actor -> object; place is context (not endpoint)
        actor_id = ensure_node("actor", ev.get("actor"))
        if actor_id:
            edges.append({
                "source": actor_id,
                "target": obj_node_id,
                "label": ev.get("event_type") or "UNKNOWN",
                "date": when,
                "weight": 1.0,  # client may recompute with risk overlays
                "source_ref": ref,
                "policy": _policy_hits_for_date(when),
            })

        # Optional: object -> place (to visualize locations)
        place_id = ensure_node("place", ev.get("place"))
        if place_id:
            edges.append({
                "source": obj_node_id,
                "target": place_id,
                "label": "LOCATED",
                "date": when,
                "weight": 0.5,
                "source_ref": ref,
                "policy": _policy_hits_for_date(when),
            })

    return {"nodes": list(nodes_map.values()), "edges": edges}
514
+
515
def build_timeline_from_events_and_sentences(events, sentences):
    """Build simple timeline items (title/start/end/text) from events.

    Fix: the old lookup ignored the event's own 'seq' and always attached
    the first stored sentence (seq 0–3) to *every* event.  Each event is
    now paired with the sentence it was inferred from via its 'seq',
    falling back to the earliest stored sentence when the event carries
    no usable seq (ingest stored seq starting at 0).
    """
    items = []
    s_by_seq = {s["seq"]: s["sentence"] for s in sentences}
    for ev in events:
        start = to_iso(ev.get("date_from"))
        end = to_iso(ev.get("date_to"))
        title = ev.get("event_type") or "Event"
        # Prefer the sentence this event came from...
        txt = s_by_seq.get(ev.get("seq"))
        if txt is None:
            # ...otherwise fall back to the first early sentence
            for k in (0, 1, 2, 3):
                if k in s_by_seq:
                    txt = s_by_seq[k]
                    break
        items.append({
            "title": title,
            "start_date": start,
            "end_date": end,
            "text": txt or "",
            "source_ref": ev.get("source_ref")
        })
    return items
537
 
538
  # ───────────────────────────────────────────────────────────────────────────────
539
  # ROUTES
540
  # ───────────────────────────────────────────────────────────────────────────────
541
 
 
 
 
 
542
  @app.get("/")
543
  def root():
544
+ return jsonify({"ok": True, "service": "provenance-radar-api", "device": _DEVICE_INFO})
545
+
546
+ @app.get("/api/health")
547
+ @with_db_retry
548
+ def health():
549
+ try:
550
+ start_time = time.time()
551
+ with cursor() as cur:
552
+ cur.execute("SELECT COUNT(*) AS c FROM objects"); objects = cur.fetchone()["c"]
553
+ cur.execute("SELECT COUNT(*) AS c FROM provenance_sentences"); sentences = cur.fetchone()["c"]
554
+ cur.execute("SELECT COUNT(*) AS c FROM risk_signals"); risks = cur.fetchone()["c"]
555
+
556
+ db_latency = round((time.time() - start_time) * 1000, 2)
557
+
558
+ return jsonify({
559
+ "ok": True,
560
+ "device": _DEVICE_INFO,
561
+ "db_latency_ms": db_latency,
562
+ "counts": {
563
+ "objects": objects,
564
+ "sentences": sentences,
565
+ "risk_signals": risks
566
+ }
567
+ })
568
+ except Exception as e:
569
+ log.exception("health failed")
570
+ return jsonify({
571
+ "ok": False,
572
+ "error": str(e),
573
+ "db_status": "unavailable"
574
+ }), 503
575
+
576
+ @app.get("/api/policy/windows")
577
+ def policy_windows():
578
+ return jsonify({"ok": True, "windows": POLICY_WINDOWS})
579
+
580
+
581
@app.get("/api/leads")
@with_db_retry
def get_leads():
    """List flagged leads with risk scores normalized in place to 0..1.

    Query params:
        limit     - max rows to return, clamped to 1..200 (default 50)
        min_score - minimum raw risk score filter (default 0)
        source    - optional exact-match source filter (e.g. "AIC")

    Returns 400 when limit/min_score are not numeric (previously this raised
    ValueError and surfaced as a 500).
    """
    try:
        limit = max(1, min(int(request.args.get("limit", 50)), 200))
        min_score = float(request.args.get("min_score", 0))
    except (TypeError, ValueError):
        return jsonify({"ok": False, "error": "limit and min_score must be numeric"}), 400
    source = request.args.get("source")

    sql = (
        "SELECT object_id, source, title, creator, risk_score, top_signals "
        "FROM flagged_leads WHERE risk_score >= %s "
    )
    args = [min_score]
    if source:
        sql += " AND source = %s "
        args.append(source)
    sql += " LIMIT %s"
    args.append(limit)

    with cursor() as cur:
        cur.execute(sql, args)
        rows = cur.fetchall()

    # Overwrite risk_score in each row with the normalized 0..1 value the UI reads.
    for r in rows:
        _apply_normalized_risk_inplace(r)

    log.info("[RISK] /api/leads called | fetched=%s limit=%s min_score=%s source=%s",
             len(rows), limit, min_score, source or "ALL")

    # Log up to five leads so raw vs normalized scores are visible in the console.
    for i, r in enumerate(rows[:5], start=1):
        raw_ratio = _to_float(r.get("risk_score_raw"))
        raw_pct = None if raw_ratio is None else round(raw_ratio * 100.0, 2)
        norm_ratio = _to_float(r.get("risk_score"))  # 0..1
        norm_pct = None if norm_ratio is None else round(norm_ratio * 100.0)  # shown by UI

        log.info(
            "[RISK] lead %d/%d | object_id=%s | title=%s | raw_ratio=%.3f | raw_pct=%s | norm_ratio=%.3f | norm_pctβ‰ˆ%s%%",
            i, min(5, len(rows)),
            r.get("object_id"),
            (r.get("title") or "")[:80],
            raw_ratio if raw_ratio is not None else -1.0,
            f"{raw_pct:.0f}" if raw_pct is not None else "NA",
            norm_ratio if norm_ratio is not None else -1.0,
            f"{norm_pct:.0f}" if norm_pct is not None else "NA",
        )

    resp = jsonify({"ok": True, "data": rows})
    # Scores are recomputed per request; never let intermediaries cache them.
    resp.headers["Cache-Control"] = "no-store, max-age=0"
    return resp
629
+
630
+
631
@app.get("/api/object/<int:object_id>")
@with_db_retry
def object_detail(object_id: int):
    """Return one object plus its provenance sentences, events and risk signals.

    The stored risk_score (a raw ratio that can exceed 1.0) is normalized to
    0..1 and written back onto the payload under several keys, so the existing
    UI field keeps working while raw/normalized values stay inspectable.
    """
    with cursor() as cur:
        # NOTE(review): image_url is listed explicitly although SELECT * should
        # already include it — presumably defensive; confirm against the schema.
        cur.execute("SELECT *, image_url FROM objects WHERE object_id=%s", (object_id,))
        obj = cur.fetchone()
        if not obj:
            return jsonify({"ok": False, "error": "not_found"}), 404

        # --- Normalize + overwrite the field the client reads (0..1) -----------
        raw_ratio = _to_float(obj.get("risk_score"))  # e.g., 2.0 = 200%
        norm_ratio = normalize_risk(raw_ratio) if raw_ratio is not None else None  # 0..1
        norm_0_99 = None if norm_ratio is None else round(norm_ratio * 100.0, 2)  # reference

        obj["risk_score_raw"] = raw_ratio
        obj["risk_score_norm_0_99"] = norm_0_99
        obj["risk_score"] = norm_ratio  # what the UI already reads
        obj["risk_score_normalized"]= norm_ratio  # alias

        # --- Log one line per object fetch (visible on HF console) -------------
        log.info(
            "[RISK] /api/object | object_id=%s | raw_ratio=%s | raw_pct=%s | norm_ratio=%s | norm_pctβ‰ˆ%s%%",
            object_id,
            f"{raw_ratio:.3f}" if raw_ratio is not None else "NA",
            f"{raw_ratio*100:.0f}" if raw_ratio is not None else "NA",
            f"{norm_ratio:.3f}" if norm_ratio is not None else "NA",
            f"{norm_ratio*100:.0f}" if norm_ratio is not None else "NA",
        )

        # -----------------------------------------------------------------------
        # Companion data: sentences in original order, events in date order,
        # and risk signals heaviest-first.
        cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
        sents = cur.fetchall()

        cur.execute("""SELECT event_type, date_from, date_to, place, actor, method, source_ref
                       FROM provenance_events WHERE object_id=%s
                       ORDER BY COALESCE(date_from,'0001-01-01')""", (object_id,))
        events = cur.fetchall()

        cur.execute("SELECT code, detail, weight FROM risk_signals WHERE object_id=%s ORDER BY weight DESC", (object_id,))
        risks = cur.fetchall()

    resp = jsonify({"ok": True, "object": obj, "sentences": sents, "events": events, "risks": risks})
    # Never cache: the normalized score must reflect the current computation.
    resp.headers["Cache-Control"] = "no-store, max-age=0"
    return resp
675
+
676
+
677
+
678
+
679
@app.get("/api/graph/<int:object_id>")
@with_db_retry
def graph(object_id: int):
    """Build a provenance graph (nodes/edges) for one object.

    Stored provenance_events are preferred; events inferred from the raw
    sentences fill in when the stored rows carry no actors/places. A final
    pass links successive actors with TRANSFER edges so the chain of custody
    is visible even when no explicit transfer event exists.
    """
    with cursor() as cur:
        cur.execute("SELECT object_id, source, title FROM objects WHERE object_id=%s", (object_id,))
        obj = cur.fetchone()
        if not obj:
            return jsonify({"ok": False, "error": "not_found"}), 404

        cur.execute("""SELECT event_type, date_from, date_to, place, actor, source_ref
                       FROM provenance_events WHERE object_id=%s
                       ORDER BY COALESCE(date_from,'0001-01-01')""", (object_id,))
        events = cur.fetchall()

        cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
        sents = cur.fetchall()

    inferred = infer_events_from_sentences(sents)

    # Prefer stored events; fill with inferred where stored is thin
    merged = list(events)
    if not merged or all((not e.get("actor") and not e.get("place")) for e in merged):
        merged = inferred
    else:
        # add inferred items that add missing actor/place for the same year
        have = {(e.get("actor"), e.get("place"), e.get("event_type"), to_iso(e.get("date_from"))): True for e in merged}
        for e in inferred:
            key = (e.get("actor"), e.get("place"), e.get("event_type"), to_iso(e.get("date_from")))
            if key not in have:
                merged.append(e)

    g = build_graph_from_events(obj, merged)

    # NEW: link successive actors to show chain of custody.
    # Undated events sort first via the '0001-01-01' sentinel.
    actors_in_time = [ (to_iso(e.get("date_from")) or "0001-01-01", e.get("actor")) for e in merged if e.get("actor") ]
    actors_in_time.sort(key=lambda x: x[0])
    for i in range(len(actors_in_time) - 1):
        a1 = actors_in_time[i][1]; a2 = actors_in_time[i+1][1]
        if a1 and a2 and a1 != a2:
            g["edges"].append({
                "source": f"actor:{a1}",
                "target": f"actor:{a2}",
                "label": "TRANSFER",
                "date": actors_in_time[i+1][0],  # transfer dated by the later event
                "weight": 0.8,
                "policy": _policy_hits_for_date(actors_in_time[i+1][0]),
                "source_ref": "link:sequence"
            })

    return jsonify({"ok": True, **g})
729
+
730
@app.get("/api/places/<int:object_id>")
@with_db_retry
def places(object_id: int):
    """Return geocoded places for an object, one entry per place, date-ordered."""
    SENTINEL = "9999-12-31"  # undated entries sort last

    with cursor() as cur:
        cur.execute("""SELECT place, date_from FROM provenance_events WHERE object_id=%s""", (object_id,))
        stored = cur.fetchall()
        cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
        sentence_rows = cur.fetchall()

    # Stored events first, then events inferred from the raw sentences.
    candidates = [
        {"place": name, "date": to_iso(rec.get("date_from"))}
        for rec in stored + infer_events_from_sentences(sentence_rows)
        if (name := _clean(rec.get("place")))
    ]

    # Deduplicate by place name, keeping the earliest known date for each.
    earliest = {}
    for row in candidates:
        when = row["date"] or SENTINEL
        best = earliest.get(row["place"])
        if best is None or when < (best.get("date") or SENTINEL):
            earliest[row["place"]] = row

    results = []
    for name, row in earliest.items():
        geo = geocode_place_cached(name)  # may be None if geocoding blocked
        results.append({
            "place": name,
            "date": row.get("date"),
            "lat": (geo or {}).get("lat"),
            "lon": (geo or {}).get("lon"),
        })

    # Chronological order so the client can draw a travel path.
    results.sort(key=lambda item: item.get("date") or SENTINEL)
    return jsonify({"ok": True, "places": results})
761
+
762
@app.get("/api/timeline/<int:object_id>")
@with_db_retry
def timeline(object_id: int):
    """Build timeline items for an object from its stored events and sentences."""
    sentence_sql = "SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq"
    event_sql = """SELECT event_type, date_from, date_to, place, actor, source_ref
                   FROM provenance_events WHERE object_id=%s
                   ORDER BY COALESCE(date_from,'0001-01-01')"""

    with cursor() as cur:
        cur.execute(sentence_sql, (object_id,))
        sentence_rows = cur.fetchall()
        cur.execute(event_sql, (object_id,))
        event_rows = cur.fetchall()

    return jsonify({
        "ok": True,
        "items": build_timeline_from_events_and_sentences(event_rows, sentence_rows),
    })
774
+
775
@app.get("/api/keyword")
@with_db_retry
def keyword_search():
    """Substring search over provenance sentences, joined to their objects.

    Query params:
        q     - required search term; SQL LIKE wildcards (%/_) are stripped
        limit - max rows, clamped to 1..200 (default 50)

    Returns 400 when q is missing, when q contains only wildcard characters
    (previously "q=%" degenerated to LIKE '%%' and matched every row), or
    when limit is not an integer (previously a 500 via ValueError).
    """
    q = (request.args.get("q") or "").strip()
    try:
        limit = max(1, min(int(request.args.get("limit", 50)), 200))
    except (TypeError, ValueError):
        return jsonify({"ok": False, "error": "limit must be an integer"}), 400
    if not q:
        return jsonify({"ok": False, "error": "q required"}), 400

    # Strip LIKE wildcards so user input cannot widen the match pattern.
    sanitized = q.replace("%", "").replace("_", "")
    if not sanitized:
        return jsonify({"ok": False, "error": "q required"}), 400
    like = "%" + sanitized + "%"

    with cursor() as cur:
        cur.execute(
            """SELECT ps.object_id, ps.seq, ps.sentence, o.source, o.title, o.creator
               FROM provenance_sentences ps
               JOIN objects o ON o.object_id = ps.object_id
               WHERE ps.sentence LIKE %s
               LIMIT %s""", (like, limit)
        )
        rows = cur.fetchall()
    return jsonify({"ok": True, "query": q, "data": rows})
793
+
794
+
795
@app.post("/api/similar")
@with_db_retry
def similar_search():
    """Vector-similarity search over provenance sentences, one hit per object.

    JSON body: {"text": str (required), "limit": int, "candidates": int,
    "source": str}. The query text is embedded locally, then TiDB's
    VEC_COSINE_DISTANCE ranks the nearest sentence embeddings; a window
    function keeps only the closest sentence per object. Two fallbacks keep
    this resilient: a retry with a smaller candidate set on TiDB OOM
    (error 1105), and a Python-side dedupe if the windowed query fails.
    """
    payload = request.get_json(force=True) or {}
    text = (payload.get("text") or "").strip()
    limit = max(1, min(int(payload.get("limit", 20)), 100))
    candidates = int(payload.get("candidates", max(200, limit * 10)))  # pre-topK by sentences
    source_filter = (payload.get("source") or "").strip().upper()  # e.g., "AIC"

    if not text:
        return jsonify({"ok": False, "error": "text required"}), 400

    # Embed (existing logic)
    try:
        import torch
        vec_t = _load_model().encode([text], batch_size=1, show_progress_bar=False, convert_to_tensor=True)
        vec = (vec_t[0].detach().cpu().tolist() if isinstance(vec_t, torch.Tensor) else list(vec_t[0]))
    except Exception as e:
        return jsonify({"ok": False, "error": f"embedding_unavailable: {e}"}), 503

    # Pad to the fixed column width so CAST(... AS VECTOR(dim)) succeeds.
    vec_json = json.dumps(_pad(vec, VEC_DIM))
    where_src = "WHERE o.source = %s" if source_filter else ""

    # --- IMPORTANT: dedupe by object_id using window function -----------------
    # We pull top 'candidates' sentences, join to objects (apply optional source),
    # then keep only ROW_NUMBER() = 1 per object_id (best/closest sentence).
    sql = f"""
    WITH nn AS (
      SELECT /*+ USE_INDEX(ps, hnsw_vec) */
             ps.sent_id, ps.object_id, ps.seq, ps.sentence,
             VEC_COSINE_DISTANCE(ps.embedding, CAST(%s AS VECTOR({VEC_DIM}))) AS distance
      FROM provenance_sentences ps
      ORDER BY distance
      LIMIT %s
    ),
    ranked AS (
      SELECT
        nn.object_id,
        nn.seq,
        nn.sentence,
        nn.distance,
        o.source,
        o.title,
        o.creator,
        ROW_NUMBER() OVER (PARTITION BY nn.object_id ORDER BY nn.distance ASC) AS rk
      FROM nn
      JOIN objects o ON o.object_id = nn.object_id
      {where_src}
    )
    SELECT object_id, seq, sentence, source, title, creator, distance
    FROM ranked
    WHERE rk = 1
    ORDER BY distance
    LIMIT %s
    """
    params = [vec_json, candidates]
    if source_filter:
        params.append(source_filter)
    params.append(limit)

    try:
        with cursor() as cur:
            cur.execute(sql, params)
            rows = cur.fetchall()
        return jsonify({
            "ok": True,
            "device": _DEVICE_INFO,
            "query": text,
            "data": rows,
            "meta": {"limit": limit, "candidates": candidates, "source": source_filter or None}
        })
    except OperationalError as e:
        # TiDB OOM (1105) -> retry with smaller candidate set
        if e.args and e.args[0] == 1105 and candidates > max(100, limit * 4):
            smaller = max(100, limit * 4)
            params2 = [vec_json, smaller]
            if source_filter:
                params2.append(source_filter)
            params2.append(limit)
            try:
                with cursor() as cur:
                    cur.execute(sql, params2)
                    rows = cur.fetchall()
                return jsonify({
                    "ok": True,
                    "device": _DEVICE_INFO,
                    "query": text,
                    "data": rows,
                    "meta": {"limit": limit, "candidates": smaller, "source": source_filter or None,
                             "note": "retried with smaller candidate set"}
                })
            except Exception as e2:
                return jsonify({"ok": False, "error": f"oom_retry_failed: {e2}"}), 500
        # Not OOM or still failed -> fall back to Python-side dedupe below
        # (This keeps you resilient if window functions act up.)
        try:
            # Simple fallback: same as the original query, dedupe in Python.
            where_src2 = "WHERE o.source = %s" if source_filter else ""
            sql2 = f"""
            WITH nn AS (
              SELECT ps.sent_id, ps.object_id, ps.seq, ps.sentence,
                     VEC_COSINE_DISTANCE(ps.embedding, CAST(%s AS VECTOR({VEC_DIM}))) AS distance
              FROM provenance_sentences ps
              ORDER BY distance
              LIMIT %s
            )
            SELECT nn.object_id, nn.seq, nn.sentence, o.source, o.title, o.creator, nn.distance
            FROM nn
            JOIN objects o ON o.object_id = nn.object_id
            {where_src2}
            ORDER BY nn.distance
            LIMIT %s
            """
            params2 = [vec_json, candidates]
            if source_filter:
                params2.append(source_filter)
            params2.append(limit * 5)  # grab extra to allow dedupe
            with cursor() as cur:
                cur.execute(sql2, params2)
                many = cur.fetchall()

            # Python dedupe: keep first (closest) row per object_id
            seen = set()
            out = []
            for r in many:
                oid = r.get("object_id")
                if oid in seen:
                    continue
                seen.add(oid)
                out.append(r)
                if len(out) >= limit:
                    break

            return jsonify({
                "ok": True,
                "device": _DEVICE_INFO,
                "query": text,
                "data": out,
                "meta": {"limit": limit, "candidates": candidates, "source": source_filter or None,
                         "note": "python-dedup fallback"}
            })
        except Exception as e3:
            return jsonify({"ok": False, "error": f"query_failed: {e} (fallback: {e3})"}), 500
938
+
939
+
940
+
941
@app.get("/api/vocab")
@with_db_retry
def vocab():
    """Return the most frequent values for a vocabulary field with their counts."""
    field = (request.args.get("field") or "").strip().lower()
    limit = max(1, min(int(request.args.get("limit", 100)), 500))

    if field not in {"actor", "place", "source", "culture"}:
        return jsonify({"ok": False, "error": "field must be one of actor|place|source|culture"}), 400

    # The field name is whitelisted above, so interpolating it into SQL is safe.
    if field == "source":
        sql = "SELECT source AS v, COUNT(*) AS n FROM objects GROUP BY source ORDER BY n DESC LIMIT %s"
    elif field == "culture":
        sql = "SELECT culture AS v, COUNT(*) AS n FROM objects WHERE culture IS NOT NULL AND culture<>'' GROUP BY culture ORDER BY n DESC LIMIT %s"
    else:
        # actor / place live on provenance_events rather than objects.
        sql = f"SELECT {field} AS v, COUNT(*) AS n FROM provenance_events WHERE {field} IS NOT NULL AND {field}<>'' GROUP BY {field} ORDER BY n DESC LIMIT %s"

    with cursor() as cur:
        cur.execute(sql, (limit,))
        rows = cur.fetchall()
    return jsonify({"ok": True, "field": field, "data": rows})
958
+
959
+ # ── Gemini-powered explanations ────────────────────────────────────────────────
960
+
961
@app.get("/api/explain/object/<int:object_id>")
@with_db_retry
def explain_object(object_id: int):
    """Generate a concise, policy-aware research note for an object.

    Loads the object row plus its sentences and events, builds a compact
    prompt (first 8 of each to keep latency low), and asks Gemini for a short
    brief with explicit red-flag and next-leads sections.
    """
    with cursor() as cur:
        cur.execute("SELECT object_id, source, title, creator, date_display, risk_score FROM objects WHERE object_id=%s", (object_id,))
        obj = cur.fetchone()
        if not obj:
            return jsonify({"ok": False, "error": "not_found"}), 404
        cur.execute("SELECT seq, sentence FROM provenance_sentences WHERE object_id=%s ORDER BY seq", (object_id,))
        sents = cur.fetchall()
        cur.execute("SELECT event_type, date_from, date_to, place, actor, source_ref FROM provenance_events WHERE object_id=%s ORDER BY COALESCE(date_from,'0001-01-01')", (object_id,))
        events = cur.fetchall()

    # Build a compact prompt (few sentences) to keep latency low
    bullets = []
    for s in sents[:8]:  # keep prompt small
        bullets.append(f"- {s['sentence']}")
    evsumm = []
    for e in events[:8]:
        evsumm.append(f"{e.get('event_type')} @ {e.get('place') or 'β€”'} on {e.get('date_from') or 'β€”'} (actor: {e.get('actor') or 'β€”'})")

    # System instruction constrains tone and structure; the prompt carries data.
    sys = ("You are assisting provenance researchers. Write a neutral, concise brief (120–180 words) that:\n"
           "1) summarizes the chain of custody in plain language; 2) clearly marks any timeline gaps; "
           "3) calls out potential red flags (e.g., confiscated/looted, sales during 1933–45, exports post-1970) "
           "without making legal conclusions; 4) ends with a short 'Next leads' list (max 3).")
    prompt = (
        f"Object: {obj.get('title') or 'Untitled'} β€” {obj.get('creator') or ''} (source {obj['source']}). "
        f"Display date: {obj.get('date_display') or 'n/a'}. Current risk_score={obj.get('risk_score', 0)}.\n\n"
        f"Provenance sentences:\n" + "\n".join(bullets) + "\n\n"
        f"Structured events (first 8):\n- " + "\n- ".join(evsumm) + "\n\n"
        f"Policy windows to consider: Nazi era 1933–1945; UNESCO 1970 onwards."
    )
    text = gemini_explain(prompt, sys=sys)
    return jsonify({"ok": True, "model": EXPLAIN_MODEL, "note": text})
996
+
997
@app.post("/api/explain/text")
def explain_text():
    """Explain a provenance sentence or user query with policy context via Gemini."""
    body = request.get_json(force=True) or {}
    fragment = (body.get("text") or "").strip()
    if not fragment:
        return jsonify({"ok": False, "error": "text required"}), 400

    system_msg = ("Explain this text as a provenance note for curators. "
                  "Be precise and cautious; highlight possible red flags tied to 1933–1945 and post-1970 export rules.")
    user_prompt = f"""Explain and contextualize this provenance fragment:\n\n{fragment}."""
    answer = gemini_explain(user_prompt, sys=system_msg)
    return jsonify({"ok": True, "model": EXPLAIN_MODEL, "explanation": answer})
1009
 
1010
  # ───────────────────────────────────────────────────────────────────────────────
1011
+ # MAIN (Spaces expects 7860)
1012
  # ───────────────────────────────────────────────────────────────────────────────
 
1013
if __name__ == "__main__":
    # Hugging Face Spaces routes traffic to port 7860 by default;
    # PORT overrides it for local runs.
    port = int(os.environ.get("PORT", "7860"))
    app.run(host="0.0.0.0", port=port, debug=False)