Spaces:
Sleeping
Sleeping
| """ | |
| app.py | |
| ------ | |
| Flask API for the DDI Clinical Decision Support System. | |
| Endpoints: | |
| GET /health -- liveness check | |
| POST /api/check -- single drug pair DDI check (RAG) | |
| POST /api/check/batch -- multiple pairs in one request | |
| POST /api/check/compare -- side-by-side: Exact / TF-IDF / RAG | |
| GET /api/drug/search?q=<name> -- drug name lookup / autocomplete | |
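| API and health responses are JSON; the page routes render HTML templates. The DDI lookup dictionary and drug caches are loaded once, by a background thread started at import time. | |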
| """ | |
| import os, sys, time, json, threading, collections as _col, datetime as _dt, re | |
| import pandas as pd | |
| from flask import Flask, request, jsonify, render_template | |
| # --------------------------------------------------------------------------- | |
| # Bootstrap: add pipeline/ to path and load .env | |
| # --------------------------------------------------------------------------- | |
# Absolute directory of this file; data/, pipeline/ and .env are resolved
# relative to it.
_THIS_FILE = os.path.abspath(__file__)
BASE_DIR = os.path.dirname(_THIS_FILE)
# pipeline/ goes first on sys.path so the pipeline modules imported later in
# this file (ddi_query, gnn_predictor) can be imported by bare name.
sys.path.insert(0, os.path.join(BASE_DIR, "pipeline"))
def load_env(env_path=None):
    """Load ``KEY=VALUE`` pairs from a .env file into ``os.environ``.

    Args:
        env_path: Path of the file to read. Defaults to ``<BASE_DIR>/.env``.

    Blank lines, lines starting with ``#`` and lines without ``=`` are
    ignored. Only the first ``=`` splits the line, so values may contain
    ``=``. One pair of matching surrounding quotes (single or double) is
    removed from the value, so ``KEY="abc"`` yields ``abc``. Values from the
    file overwrite variables already present in the environment. A missing
    file is silently ignored.
    """
    if env_path is None:
        env_path = os.path.join(BASE_DIR, ".env")
    if not os.path.exists(env_path):
        return
    with open(env_path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#") or "=" not in line:
                continue
            key, value = line.split("=", 1)
            key, value = key.strip(), value.strip()
            # Common .env style quotes the value; keep only the content.
            if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
                value = value[1:-1]
            # An empty variable name cannot be set in the environment.
            if key:
                os.environ[key] = value
# Read .env first, then force the two flags below so they win over any value
# the file may have set.
load_env()
os.environ.update({
    # presumably tells Hugging Face transformers not to load TensorFlow — confirm
    "TRANSFORMERS_NO_TF": "1",
    # presumably silences TensorFlow C++ logging — confirm
    "TF_CPP_MIN_LOG_LEVEL": "3",
})
| # --------------------------------------------------------------------------- | |
| # App | |
| # --------------------------------------------------------------------------- | |
app = Flask(__name__)
# Keep JSON keys in insertion order so responses read in the documented order.
app.config["JSON_SORT_KEYS"] = False
# Newer Flask releases no longer read the JSON_SORT_KEYS config key; the
# setting lives on the app's JSON provider instead. Set both so key order is
# preserved regardless of the installed Flask version.
_json_provider = getattr(app, "json", None)
if _json_provider is not None:
    _json_provider.sort_keys = False
app.config["MAX_CONTENT_LENGTH"] = 512 * 1024  # 512 KB max request body
| # --------------------------------------------------------------------------- | |
| # Security helpers | |
| # --------------------------------------------------------------------------- | |
# ASCII control characters (0x00-0x1f and DEL).
_CTRL = re.compile(r'[\x00-\x1f\x7f]')
# Control characters that act as whitespace (tab, LF, CR, VT, FF).
_CTRL_SPACE = re.compile(r'[\t\n\r\x0b\x0c]+')


def _sanitize(s: str, max_len: int = 120) -> str:
    """Return a cleaned, length-limited copy of user-supplied text.

    Args:
        s: Raw value taken from the request. Anything that is not a ``str``
            yields ``""``.
        max_len: Maximum number of characters kept.

    Returns:
        The text with each run of whitespace-like control characters replaced
        by a single space (so words separated by a newline or tab are not
        fused together), all remaining control characters removed, surrounding
        whitespace stripped, and the result truncated to ``max_len``.
    """
    if not isinstance(s, str):
        return ""
    spaced = _CTRL_SPACE.sub(' ', s)
    # Strip only after control characters are gone, otherwise whitespace that
    # was hidden behind e.g. a leading NUL would survive.
    return _CTRL.sub('', spaced).strip()[:max_len]
def _security_headers(response):
    """Attach the app's security headers to ``response`` and return it.

    NOTE(review): no hook registration (e.g. ``app.after_request``) is visible
    next to this function in this view — confirm it is attached to the app.
    """
    # Inline scripts/styles needed for our vanilla templates.
    # frame-ancestors: allow HF Spaces embedding, block all other third parties.
    csp_directives = [
        "default-src 'self'",
        "script-src 'self' 'unsafe-inline'",
        "style-src 'self' 'unsafe-inline'",
        "img-src 'self' data:",
        "connect-src 'self'",
        "frame-ancestors 'self' https://huggingface.co https://*.hf.space",
    ]
    # X-Frame-Options intentionally omitted: HF Spaces embeds the app in an
    # iframe from huggingface.co — any DENY/SAMEORIGIN value breaks that.
    # Framing is instead governed by the CSP frame-ancestors directive.
    header_values = (
        ('X-Content-Type-Options', 'nosniff'),
        ('X-XSS-Protection', '1; mode=block'),
        ('Referrer-Policy', 'strict-origin-when-cross-origin'),
        ('Content-Security-Policy', "; ".join(csp_directives) + ";"),
    )
    for header_name, header_value in header_values:
        response.headers[header_name] = header_value
    return response
| # Lazy-loaded components | |
# State filled in by _init_dict(); everything stays empty until it has run.
_dict_ready = False           # flipped to True once _init_dict() has finished
_drug_names = None            # list of approved drug names for search (sorted by popularity)
_drug_pop = None              # dict drugbank_id -> interaction count
_ddi_lookup = None            # primary: frozenset({id_a,id_b}) -> description — O(1) dict lookup
_dict_lock = threading.Lock() # held by _init_dict() while it loads
| # --------------------------------------------------------------------------- | |
| # Live stats (in-memory — resets on server restart by design) | |
| # --------------------------------------------------------------------------- | |
# Counters keyed by outcome, plus the running total of handling time.
_stats = {
    "total": 0,
    "documented": 0,
    "gnn_predicted": 0,
    "not_found": 0,
    "drug_not_found": 0,
    "total_ms": 0.0,
}
_recent_queries = _col.deque(maxlen=200)   # newest first, capped at 200 entries
_pair_counts = _col.Counter()              # "a × b" (lower-cased, sorted) -> count
_server_start = _dt.datetime.utcnow()      # naive UTC timestamp of module import


def _record_query(source: str, drug_a: str, drug_b: str, elapsed_ms: float):
    """Fold one handled query into the in-memory statistics.

    Args:
        source: Outcome label; it is counted only when it is a key of
            ``_stats``.
        drug_a: First drug name (truncated to 35 chars in the recent list).
        drug_b: Second drug name (truncated to 35 chars in the recent list).
        elapsed_ms: Handling time in milliseconds.
    """
    _stats["total"] += 1
    _stats["total_ms"] += elapsed_ms
    if source in _stats:
        _stats[source] += 1

    entry = {
        "ts": _dt.datetime.utcnow().strftime("%H:%M:%S"),
        "drug_a": drug_a[:35],
        "drug_b": drug_b[:35],
        "source": source,
        "ms": round(elapsed_ms),
    }
    _recent_queries.appendleft(entry)

    # Only pairs whose drugs both resolved take part in the top-pairs ranking.
    if source not in ("documented", "gnn_predicted", "not_found"):
        return
    ordered_names = sorted((drug_a.lower(), drug_b.lower()))
    _pair_counts[" × ".join(ordered_names)] += 1
def _init_dict():
    """
    Load everything needed for /api/check:
    - DDI lookup dict frozenset({id_a, id_b}) -> description (O(1) DrugBank lookup)
    - drug name list for autocomplete (sorted by interaction count)
    - per-drug interaction counts (``_drug_pop``) used to rank search results
    - drug name + synonym caches for resolve_drug()

    The ready flag is checked before and again after taking ``_dict_lock`` so
    that concurrent callers wait for the first load instead of repeating it.
    If loading raises, the flag stays False and the next call retries.
    """
    global _dict_ready, _drug_names, _drug_pop, _ddi_lookup
    if _dict_ready:
        return
    with _dict_lock:
        if _dict_ready:
            return
        print("[DDI] Loading pipeline ...", flush=True)
        t0 = time.time()
        approved = os.path.join(BASE_DIR, "data", "step3_approved")
        ddi_path = os.path.join(approved, "drug_interactions_dedup.csv")
        drug_path = os.path.join(approved, "drugs.csv")

        # ── Build primary DDI lookup dict ──────────────────────────────────
        # frozenset({id_a, id_b}) → description string
        # Order-independent, O(1) lookup, guaranteed to find any documented pair.
        # Built in a local dict and published in one assignment so readers
        # never observe a half-filled mapping.
        lookup = {}
        if os.path.exists(ddi_path):
            ddi = pd.read_csv(ddi_path)
            # Zipping the three columns avoids building a Series per row,
            # which is what made iterrows() slow on a table this size.
            rows = zip(ddi["drugbank_id_a"], ddi["drugbank_id_b"], ddi["description"])
            for id_a, id_b, desc in rows:
                lookup[frozenset((id_a, id_b))] = (
                    str(desc).strip() if pd.notna(desc) else ""
                )
            print(f"[DDI] Lookup dict: {len(lookup):,} pairs", flush=True)
        _ddi_lookup = lookup

        # ── Interaction count per drug (popularity) ────────────────────────
        counts = {}
        for key in lookup:
            for did in key:
                counts[did] = counts.get(did, 0) + 1
        # Published so drug_search() can rank matches by popularity.
        _drug_pop = counts

        # ── Drug name list for autocomplete ────────────────────────────────
        if os.path.exists(drug_path):
            df = pd.read_csv(drug_path, usecols=["drugbank_id", "name"])
            if counts:
                df["_pop"] = df["drugbank_id"].map(counts).fillna(0)
                df = df.sort_values("_pop", ascending=False).drop(columns="_pop")
            _drug_names = df.to_dict("records")
        else:
            _drug_names = []

        # ── Warm up drug-resolution caches ─────────────────────────────────
        import ddi_query as rag
        rag.get_drugs_df()      # drugs.csv (name <-> id table)
        rag.get_synonym_map()   # drug_attributes.csv + products.csv synonym/brand map

        # ── Pre-warm GNN model so first query latency is not inflated ──────
        # Best effort: the GNN is optional, so a failure here is only logged.
        try:
            import gnn_predictor as _gnn_warm
            _gnn_warm.is_available()  # triggers model load in background thread
            print("[DDI] GNN predictor warmed up.", flush=True)
        except Exception as _we:
            print(f"[DDI] GNN warmup skipped: {_we}", flush=True)

        _dict_ready = True
        print(f"[DDI] Pipeline ready in {time.time()-t0:.1f}s", flush=True)
| # --------------------------------------------------------------------------- | |
| # Groq LLM helpers (chatbot — NER + explanation) | |
| # --------------------------------------------------------------------------- | |
_groq_client = None                 # cached client, created on first successful use
_groq_lock = threading.Lock()       # serialises client creation
GROQ_NER_MODEL = "llama-3.1-8b-instant"  # fast, free — structured NER task
GROQ_CHAT_MODEL = "llama-3.3-70b-versatile"  # better quality for explanation


def _get_groq():
    """Return the shared Groq client, creating it on first use.

    Returns None when GROQ_API_KEY is unset or blank, and also when the
    client cannot be constructed (the error is printed, not raised).
    """
    global _groq_client
    if _groq_client is None:
        with _groq_lock:
            # Re-check under the lock: another thread may have created it.
            if _groq_client is None:
                key = os.environ.get("GROQ_API_KEY", "").strip()
                if not key:
                    return None
                try:
                    from groq import Groq
                    _groq_client = Groq(api_key=key)
                    print("[DDI] Groq client initialised.", flush=True)
                except Exception as e:
                    print(f"[DDI] Groq init error: {e}", flush=True)
    return _groq_client
def extract_drugs_nlp(text: str) -> list:
    """
    Ask the Groq NER model for the drug / medication names in ``text``.

    Returns the names as a list of non-empty, stripped strings. Returns []
    when no Groq client is available, when the reply is not a JSON list, or
    when the request or parsing fails (the error is printed).
    """
    groq = _get_groq()
    if not groq:
        return []

    instructions = (
        "You are a medical named-entity recognition system. "
        "Extract every drug, medication, or supplement name mentioned in the text below. "
        "Return ONLY a valid JSON array of strings — no explanation, no markdown fences.\n"
        "Example output: [\"Warfarin\", \"Aspirin\"]\n"
        "If no drugs are found, return: []\n\n"
        f"Text: {text}"
    )
    try:
        completion = groq.chat.completions.create(
            model=GROQ_NER_MODEL,
            messages=[{"role": "user", "content": instructions}],
            temperature=0.0,
            max_tokens=200,
        )
        payload = completion.choices[0].message.content.strip()
        # Strip markdown fences if present: keep what sits between the first
        # fence and the next one (or the end of the reply).
        if "```" in payload:
            payload = payload.partition("```")[2].partition("```")[0].strip()
            if payload.lower().startswith("json"):
                payload = payload[4:].strip()
        parsed = json.loads(payload)
        if not isinstance(parsed, list):
            return []
        names = []
        for item in parsed:
            cleaned = str(item).strip()
            if cleaned:
                names.append(cleaned)
        return names
    except Exception as e:
        print(f"[DDI] NER error: {e}", flush=True)
        return []
def _fallback_reply(pair_results: list, drugs_not_found: list) -> str:
    """Build the template-based chat reply used when Groq is unavailable.

    Produces one paragraph for the unresolved names (if any) followed by one
    paragraph per pair result, joined by blank lines. When there is nothing
    to say, a hint asking for specific medication names is returned.
    """
    paragraphs = []
    if drugs_not_found:
        paragraphs.append(
            f"⚠ I couldn't find these in my database: **{', '.join(drugs_not_found)}**. "
            "Please check the spelling or use the Checker page for autocomplete suggestions."
        )

    for pair in pair_results:
        kind = pair.get("source", "")
        first = pair.get("drug_a", {}).get("resolved", "Drug A")
        second = pair.get("drug_b", {}).get("resolved", "Drug B")
        detail = pair.get("interaction_description", "")
        if kind == "documented":
            paragraph = (
                f"⚠ **{first}** and **{second}** have a **documented interaction** in DrugBank. "
                "Please consult a pharmacist or clinician before combining these medications. "
                f"{detail}"
            )
        elif kind == "gnn_predicted":
            # Missing or null GNN data is shown as 0%.
            likelihood = (pair.get("gnn") or {}).get("probability") or 0
            paragraph = (
                f"🔶 The GNN model predicts a **possible interaction** between **{first}** "
                f"and **{second}** (probability {likelihood:.0%}). "
                "This is not confirmed in DrugBank — consult a pharmacist before use."
            )
        else:
            paragraph = (
                f"✓ No interaction between **{first}** and **{second}** was found in DrugBank "
                "or predicted by the GNN. Absence of data does not guarantee safety — always "
                "consult a healthcare professional when in doubt."
            )
        paragraphs.append(paragraph)

    if paragraphs:
        return "\n\n".join(paragraphs)
    return (
        "I couldn't identify any drug names in your message. "
        "Try mentioning specific medication names, e.g. \"Is warfarin safe with aspirin?\""
    )
def generate_chat_reply(
    user_message: str,
    drugs_resolved: list,
    pair_results: list,
    drugs_not_found: list,
) -> str:
    """
    Produce the conversational reply for the chatbot.

    The pair results and unresolved names are summarised as bullet lines and
    sent to the Groq chat model together with the user's message. When no
    Groq client is available, or the call fails, the template reply from
    _fallback_reply() is returned instead. ``drugs_resolved`` is accepted but
    not used here.
    """
    groq = _get_groq()
    if not groq:
        return _fallback_reply(pair_results, drugs_not_found)

    def _summarise(pair):
        """Return the one-line context bullet for a single pair result."""
        first = pair.get("drug_a", {}).get("resolved", "")
        second = pair.get("drug_b", {}).get("resolved", "")
        kind = pair.get("source", "")
        if kind == "documented":
            detail = pair.get("interaction_description", "")
            return f"- {first} + {second}: DOCUMENTED in DrugBank. {detail}"
        if kind == "gnn_predicted":
            likelihood = (pair.get("gnn") or {}).get("probability") or 0
            return (
                f"- {first} + {second}: PREDICTED by GNN (probability {likelihood:.0%}). "
                "Not confirmed in DrugBank."
            )
        return f"- {first} + {second}: No interaction found."

    # Build structured context for the LLM
    bullets = [_summarise(pair) for pair in pair_results]
    if drugs_not_found:
        bullets.append(
            f"- Drugs not found in database: {', '.join(drugs_not_found)}"
        )
    context = "\n".join(bullets) if bullets else "No pairs analysed."

    system_prompt = (
        "You are a friendly medication safety assistant helping everyday people understand "
        "whether their medications might interact.\n"
        "Rules:\n"
        "1. Be concise (3-5 sentences), warm, and use plain everyday language. No technical terms.\n"
        "2. Never mention 'GNN', 'Graph Neural Network', 'model', 'algorithm', 'DrugBank', "
        "'probability', or any technical system details. Just say 'our system flagged' or "
        "'this combination may carry a risk'.\n"
        "3. For documented interactions: explain what the interaction actually means in plain terms "
        "(e.g. 'this can increase bleeding risk'), then recommend speaking to a pharmacist.\n"
        "4. For predicted interactions: say something like 'we found a possible concern' — "
        "do not explain how or why it was detected.\n"
        "5. Never diagnose or prescribe. Frame everything as 'worth checking with your pharmacist'.\n"
        "6. Mention specific drug names. Be reassuring but honest."
    )
    user_prompt = (
        f"User message: \"{user_message}\"\n\n"
        f"DDI analysis results:\n{context}\n\n"
        "Write a conversational reply explaining these results to the user."
    )
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    try:
        completion = groq.chat.completions.create(
            model=GROQ_CHAT_MODEL,
            messages=conversation,
            temperature=0.3,
            max_tokens=450,
        )
        return completion.choices[0].message.content.strip()
    except Exception as e:
        print(f"[DDI] Groq chat error: {e}", flush=True)
        return _fallback_reply(pair_results, drugs_not_found)
| # --------------------------------------------------------------------------- | |
| # Routes | |
| # --------------------------------------------------------------------------- | |
def index():
    """Render the landing page template.

    NOTE(review): no route decorator is visible on the view functions in this
    copy of the file — confirm how they are registered with the app.
    """
    page = render_template("landing.html")
    return page
def checker_page():
    """Render the drug-pair checker page template."""
    page = render_template("checker.html")
    return page
def chat_page():
    """Render the chatbot page template."""
    page = render_template("chat.html")
    return page
def results_page():
    """Model performance comparison page (warm + cold-start evaluation).

    Reads data/evaluation/baselines_results.json and passes its ``results``,
    ``cold_results`` and ``config`` sections to the template. A missing,
    unreadable or malformed file is logged and rendered as empty sections
    rather than failing the page.
    """
    import json as _json
    results_path = os.path.join(BASE_DIR, "data", "evaluation", "baselines_results.json")
    results_data = {}
    try:
        with open(results_path, encoding="utf-8") as f:
            loaded = _json.load(f)
    except (OSError, ValueError) as exc:
        # OSError: file missing/unreadable; ValueError: invalid JSON or encoding.
        print(f"[DDI] Could not load {results_path}: {exc}", flush=True)
    else:
        # Anything other than a JSON object has no sections to read.
        if isinstance(loaded, dict):
            results_data = loaded
    return render_template("results.html",
                           results=results_data.get("results", {}),
                           cold_results=results_data.get("cold_results", {}),
                           config=results_data.get("config", {}))
def responsible_page():
    """Responsible ML page — explainability, fairness, privacy, robustness.

    Loads three JSON reports from data/evaluation/ and hands them to the
    template as ``bias``, ``robust`` and ``gnn_auc``. Each report is loaded
    independently; one that is missing, unreadable or malformed is logged and
    passed as an empty dict.
    """
    import json as _json
    eval_dir = os.path.join(BASE_DIR, "data", "evaluation")

    def _load(filename):
        """Return the parsed JSON file from the evaluation dir, or {} on failure."""
        path = os.path.join(eval_dir, filename)
        try:
            with open(path, encoding="utf-8") as f:
                return _json.load(f)
        except (OSError, ValueError) as exc:
            # OSError: file missing/unreadable; ValueError: invalid JSON or encoding.
            print(f"[DDI] Could not load {path}: {exc}", flush=True)
            return {}

    return render_template("responsible.html",
                           bias=_load("responsible_ml_bias.json"),
                           robust=_load("responsible_ml_robust.json"),
                           gnn_auc=_load("responsible_ml_gnn_auc.json"))
def health():
    """Liveness / readiness check.

    Always reports ``status: ok``; ``dict_loaded`` tells whether the DDI
    lookup dict and drug caches have finished loading.
    """
    payload = {"status": "ok"}
    payload["dict_loaded"] = _dict_ready
    return jsonify(payload)
def check_pair():
    """
    Two-stage DDI check:
      Stage 1 — Dict lookup (O(1) against 824K documented DrugBank pairs).
      Stage 2 — GNN (only if dict lookup finds nothing).
      If both say no → return not_found.

    When either query resolves to a combination-product brand (more than one
    component), every component pair is checked and reported in
    ``all_checks``; the overall ``source`` is the most severe one found.

    Every outcome, including an unresolvable drug name, is recorded in the
    live statistics.

    Request body (JSON):
      { "drug_a": "Warfarin", "drug_b": "Aspirin" }

    Response:
      {
        "drug_a": { "query": "...", "resolved": "...", "id": "..." },
        "drug_b": { ... },
        "source": "documented" | "gnn_predicted" | "not_found",
        "found": true | false,
        "interaction_type": "...",
        "interaction_description": "...",
        "gnn": { "probability": 0.87, "mock": true } | null,
        "disclaimer": "..." | null,
        "error": null
      }
    If a name cannot be resolved the response is
      { "source": "drug_not_found", "drug_not_found": "<query>",
        "found": false, "error": null }.
    Missing or empty names yield HTTP 400 with an "error" message.
    """
    _init_dict()
    _t0 = time.time()
    # NOTE(review): the chat endpoint uses 0.5 for the same decision — confirm
    # which threshold is intended.
    GNN_THRESHOLD = 0.43

    def _elapsed_ms():
        """Milliseconds spent on this request so far."""
        return (time.time() - _t0) * 1000

    def _safe_explain(a_id, b_id):
        """Return the GNN explanation for a pair, or None if it cannot be produced."""
        try:
            import gnn_predictor
            return gnn_predictor.explain(a_id, b_id)
        except Exception as exc:
            print(f"[DDI] GNN explain error: {exc}", flush=True)
            return None

    def _brand_notice(primary_name, comps):
        """Describe a multi-component brand, or None for a single drug."""
        if len(comps) <= 1:
            return None
        return {"primary": primary_name,
                "others": [c[1] for c in comps if c[1] != primary_name],
                "all": [c[1] for c in comps]}

    # A JSON body that is not an object (e.g. a list) is treated as empty.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    drug_a = _sanitize(body.get("drug_a") or "")
    drug_b = _sanitize(body.get("drug_b") or "")
    if not drug_a or not drug_b:
        return jsonify({"error": "Both 'drug_a' and 'drug_b' are required."}), 400

    import ddi_query as rag

    resolved = []
    for query in (drug_a, drug_b):
        try:
            resolved.append(rag.resolve_drug(query))
        except ValueError:
            # Count the miss so the "drug_not_found" statistic is populated.
            _record_query("drug_not_found", drug_a, drug_b, _elapsed_ms())
            return jsonify({
                "source": "drug_not_found",
                "drug_not_found": query,
                "found": False,
                "error": None,
            })
    (id_a, name_a), (id_b, name_b) = resolved

    # ── Resolve all components for combination-product brands ─────────────
    def _get_components(query, primary_id, primary_name):
        """Return [(id, name), ...] for all drugs in a brand (or just the primary)."""
        comps = rag.get_brand_components(query)
        if len(comps) > 1:
            return comps
        return [(primary_id, primary_name)]

    comps_a = _get_components(drug_a, id_a, name_a)
    comps_b = _get_components(drug_b, id_b, name_b)
    brand_notice_a = _brand_notice(name_a, comps_a)
    brand_notice_b = _brand_notice(name_b, comps_b)
    is_combo = len(comps_a) > 1 or len(comps_b) > 1

    # ── Combination-product path: all pairs shown equally ─────────────────
    if is_combo:
        all_checks = []
        for (a_id, a_name) in comps_a:
            for (b_id, b_name) in comps_b:
                pair_desc = _ddi_lookup.get(frozenset([a_id, b_id]))
                if pair_desc is not None:
                    all_checks.append({
                        "drug_a": {"resolved": a_name, "id": a_id},
                        "drug_b": {"resolved": b_name, "id": b_id},
                        "source": "documented",
                        "found": True,
                        "interaction_description": pair_desc,
                        "gnn": None,
                        "explanation": None,
                    })
                    continue

                # Try GNN for undocumented combo pairs (best effort).
                pair_gnn = None
                pair_prob = None
                pair_explanation = None
                pair_source = "not_found"
                pair_found = False
                pair_desc_text = f"No documented interaction found for {a_name} and {b_name}."
                try:
                    import gnn_predictor
                    pair_gnn = gnn_predictor.predict(a_id, b_id)
                    pair_prob = pair_gnn.get("probability") if pair_gnn else None
                except Exception as exc:
                    print(f"[DDI] GNN predict error: {exc}", flush=True)
                if pair_prob is not None and pair_prob >= GNN_THRESHOLD:
                    pair_source = "gnn_predicted"
                    pair_found = True
                    # Set the description before explaining so a failed
                    # explanation cannot leave a mismatched message behind.
                    pair_desc_text = (
                        f"No documented interaction found in DrugBank for {a_name} and {b_name}. "
                        f"The GNN predicts a possible interaction (probability {pair_prob:.0%})."
                    )
                    pair_explanation = _safe_explain(a_id, b_id)
                all_checks.append({
                    "drug_a": {"resolved": a_name, "id": a_id},
                    "drug_b": {"resolved": b_name, "id": b_id},
                    "source": pair_source,
                    "found": pair_found,
                    "interaction_description": pair_desc_text,
                    "gnn": pair_gnn,
                    "explanation": pair_explanation,
                })

        # Overall status = worst case across all pairs
        has_documented = any(c["source"] == "documented" for c in all_checks)
        has_gnn = any(c["source"] == "gnn_predicted" for c in all_checks)
        if has_documented:
            overall_source = "documented"
        elif has_gnn:
            overall_source = "gnn_predicted"
        else:
            overall_source = "not_found"
        overall_found = has_documented or has_gnn
        result = {
            "drug_a": {"query": drug_a, "resolved": name_a, "id": id_a},
            "drug_b": {"query": drug_b, "resolved": name_b, "id": id_b},
            "source": overall_source,
            "found": overall_found,
            "all_checks": all_checks,
            "brand_notice_a": brand_notice_a,
            "brand_notice_b": brand_notice_b,
            "interaction_description": None,
            "gnn": None,
            "disclaimer": None,
            "error": None,
        }
        _record_query(overall_source, name_a, name_b, _elapsed_ms())
        return jsonify(result)

    # ── Standard single × single path ─────────────────────────────────────
    # Stage 1: dict lookup (O(1))
    dict_desc = _ddi_lookup.get(frozenset([id_a, id_b]))
    dict_hit = dict_desc is not None

    # Stage 2: GNN (only when dict has no result)
    gnn_result = None
    gnn_found = False
    gnn_prob = None
    if not dict_hit:
        try:
            import gnn_predictor
            gnn_result = gnn_predictor.predict(id_a, id_b)
        except Exception as e:
            # Surface the failure to the client instead of failing the request.
            gnn_result = {"found": None, "probability": None, "note": str(e), "mock": True}
        gnn_prob = gnn_result.get("probability") if gnn_result else None
        gnn_found = (gnn_prob is not None) and (gnn_prob >= GNN_THRESHOLD)

    if dict_hit:
        source = "documented"
        result = {
            "drug_a": {"query": drug_a, "resolved": name_a, "id": id_a},
            "drug_b": {"query": drug_b, "resolved": name_b, "id": id_b},
            "source": "documented",
            "found": True,
            "interaction_type": None,
            # An empty stored description falls back to a generic sentence.
            "interaction_description": dict_desc or (
                f"Documented interaction between {name_a} and {name_b} in DrugBank."
            ),
            "gnn": None,
            "disclaimer": None,
            "brand_notice_a": brand_notice_a,
            "brand_notice_b": brand_notice_b,
            "error": None,
        }
    elif gnn_found:
        source = "gnn_predicted"
        result = {
            "drug_a": {"query": drug_a, "resolved": name_a, "id": id_a},
            "drug_b": {"query": drug_b, "resolved": name_b, "id": id_b},
            "source": "gnn_predicted",
            "found": True,
            "interaction_type": "predicted",
            "interaction_description": (
                f"No documented interaction found in DrugBank for {name_a} and {name_b}. "
                f"The Graph Neural Network predicts a possible interaction "
                f"(probability {gnn_prob:.0%})."
            ),
            "gnn": gnn_result,
            "explanation": _safe_explain(id_a, id_b),
            "disclaimer": (
                "This interaction is not documented in DrugBank but is predicted "
                "by the Graph Neural Network based on pharmacological graph patterns. "
                "Consult a pharmacist or clinician before use. "
                "This is a clinical decision support signal, not a confirmed finding."
            ),
            "brand_notice_a": brand_notice_a,
            "brand_notice_b": brand_notice_b,
            "error": None,
        }
    else:
        source = "not_found"
        result = {
            "drug_a": {"query": drug_a, "resolved": name_a, "id": id_a},
            "drug_b": {"query": drug_b, "resolved": name_b, "id": id_b},
            "source": "not_found",
            "found": False,
            "interaction_type": None,
            "interaction_description": (
                f"No interaction between {name_a} and {name_b} was found in DrugBank, "
                "and the GNN model did not predict one above the confidence threshold. "
                "Absence of a recorded interaction does not guarantee safety "
                "— always consult a pharmacist or clinician when in doubt."
            ),
            "gnn": gnn_result,
            "disclaimer": None,
            "brand_notice_a": brand_notice_a,
            "brand_notice_b": brand_notice_b,
            "error": None,
        }

    _record_query(source, name_a, name_b, _elapsed_ms())
    return jsonify(result)
def chat_api():
    """
    Chatbot endpoint.

    Flow:
      1. Groq NER        — extract drug names from free-text message
      2. resolve_drug()  — map raw names to DrugBank IDs
      3. Dict lookup + GNN — check every unique pair
      4. Groq LLM        — generate plain-language explanation

    Request body (JSON):
      { "message": "I'm taking warfarin and want to start aspirin, is that safe?" }
    Messages longer than 2000 characters (after sanitising) are rejected
    with HTTP 400, as is a missing or empty message.

    Response:
      {
        "reply":           "...",   # LLM-generated explanation
        "drugs":           [...],   # resolved drug objects
        "drugs_not_found": [...],   # names not in DrugBank
        "pairs":           [...],   # one entry per unique pair
        "llm_available":   true,
        "error":           null
      }
    """
    _init_dict()
    MAX_MESSAGE_CHARS = 2000

    # A JSON body that is not an object (e.g. a list) is treated as empty.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    # Keep one character beyond the limit so an over-long message can be
    # detected; truncating to exactly the limit would hide it.
    message = _sanitize(body.get("message") or "", max_len=MAX_MESSAGE_CHARS + 1)
    if not message:
        return jsonify({"error": "message is required"}), 400
    if len(message) > MAX_MESSAGE_CHARS:
        return jsonify({"error": "message too long (max 2000 chars)"}), 400

    import ddi_query as rag
    from itertools import combinations

    llm_available = _get_groq() is not None

    def _safe_explain(a_id, b_id):
        """Return the GNN explanation for a pair, or None if it cannot be produced."""
        try:
            import gnn_predictor
            return gnn_predictor.explain(a_id, b_id)
        except Exception as exc:
            print(f"[DDI] GNN explain error: {exc}", flush=True)
            return None

    # ── Step 1: NER ───────────────────────────────────────────────────────
    drug_names_raw = extract_drugs_nlp(message)

    # ── Step 2: Resolve each name ─────────────────────────────────────────
    drugs_resolved = []     # [{"query": str, "name": str, "id": str}]
    drugs_not_found = []
    for raw in drug_names_raw:
        try:
            did, dname = rag.resolve_drug(raw)
        except ValueError:
            if raw not in drugs_not_found:
                drugs_not_found.append(raw)
            continue
        # Deduplicate by DrugBank ID
        if not any(d["id"] == did for d in drugs_resolved):
            drugs_resolved.append({"query": raw, "name": dname, "id": did})

    # ── Step 3: Check every unique pair ───────────────────────────────────
    # NOTE(review): the single-pair endpoint uses 0.43 for the same decision —
    # confirm which threshold is intended.
    GNN_THRESHOLD = 0.5
    pair_results = []
    # IDs are unique in drugs_resolved, so combinations() yields each
    # unordered pair exactly once.
    for da, db in combinations(drugs_resolved, 2):
        id_a, name_a = da["id"], da["name"]
        id_b, name_b = db["id"], db["name"]
        pair_head = {
            "drug_a": {"query": da["query"], "resolved": name_a, "id": id_a},
            "drug_b": {"query": db["query"], "resolved": name_b, "id": id_b},
        }

        # Stage 1 — dict lookup
        dict_desc = _ddi_lookup.get(frozenset([id_a, id_b]))
        if dict_desc is not None:
            pair_results.append({
                **pair_head,
                "source": "documented",
                "found": True,
                # An empty stored description falls back to a generic sentence.
                "interaction_description": dict_desc or (
                    f"Documented interaction between {name_a} and {name_b} in DrugBank."
                ),
                "gnn": None,
                "disclaimer": None,
            })
            continue

        # Stage 2 — GNN
        try:
            import gnn_predictor
            gnn_result = gnn_predictor.predict(id_a, id_b)
        except Exception as e:
            # Surface the failure to the client instead of failing the request.
            gnn_result = {"found": None, "probability": None,
                          "note": str(e), "mock": True}
        gnn_prob = (gnn_result or {}).get("probability")
        gnn_found = gnn_prob is not None and gnn_prob >= GNN_THRESHOLD

        if gnn_found:
            pair_results.append({
                **pair_head,
                "source": "gnn_predicted",
                "explanation": _safe_explain(id_a, id_b),
                "found": True,
                "interaction_description": (
                    f"No documented interaction found in DrugBank for {name_a} and {name_b}. "
                    f"The GNN predicts a possible interaction (probability {gnn_prob:.0%})."
                ),
                "gnn": gnn_result,
                "disclaimer": (
                    "⚠ GNN prediction — not a confirmed DrugBank finding. "
                    "Consult a pharmacist or clinician before use."
                ),
            })
        else:
            pair_results.append({
                **pair_head,
                "source": "not_found",
                "found": False,
                "interaction_description": (
                    f"No interaction between {name_a} and {name_b} was found "
                    "in DrugBank or predicted by the GNN."
                ),
                "gnn": gnn_result,
                "disclaimer": None,
            })

    # ── Step 4: LLM explanation ───────────────────────────────────────────
    reply = generate_chat_reply(message, drugs_resolved, pair_results, drugs_not_found)

    return jsonify({
        "reply": reply,
        "drugs": drugs_resolved,
        "drugs_not_found": drugs_not_found,
        "pairs": pair_results,
        "llm_available": llm_available,
        "error": None,
    })
def check_batch():
    """
    Check DDI for multiple pairs in one request (documented-pair lookup only;
    the GNN stage of the single-pair endpoint is not run here).

    Request body (JSON):
      {
        "pairs": [
          {"drug_a": "Warfarin", "drug_b": "Aspirin"},
          {"drug_a": "Metformin", "drug_b": "Lisinopril"}
        ],
        "top_k": 3
      }
    ``pairs`` must be a non-empty list of at most 50 entries. ``top_k`` is
    accepted but not used by this handler.

    Response:
      { "results": [ ... ] }   # one entry per input pair, in order
    Each entry has "drug_a", "drug_b", "found", "interaction_type",
    "interaction_description" and "error". A pair that is malformed or whose
    drugs cannot be resolved yields an entry carrying only an "error" message
    (plus the queried names when available); the other pairs are still
    processed.
    """
    _init_dict()
    # A JSON body that is not an object (e.g. a list) is treated as empty.
    body = request.get_json(silent=True)
    if not isinstance(body, dict):
        body = {}
    pairs = body.get("pairs", [])
    if not pairs or not isinstance(pairs, list):
        return jsonify({"error": "'pairs' list is required and must not be empty."}), 400
    if len(pairs) > 50:
        return jsonify({"error": "Maximum 50 pairs per batch request."}), 400

    import ddi_query as rag

    results = []
    for pair in pairs:
        # Entries that are not JSON objects cannot carry drug names.
        if not isinstance(pair, dict):
            results.append({"error": "Both 'drug_a' and 'drug_b' are required."})
            continue
        drug_a = _sanitize(pair.get("drug_a") or "")
        drug_b = _sanitize(pair.get("drug_b") or "")
        if not drug_a or not drug_b:
            results.append({"error": "Both 'drug_a' and 'drug_b' are required."})
            continue
        try:
            id_a, name_a = rag.resolve_drug(drug_a)
            id_b, name_b = rag.resolve_drug(drug_b)
        except ValueError as e:
            results.append({"drug_a": {"query": drug_a}, "drug_b": {"query": drug_b},
                            "error": str(e)})
            continue
        dict_desc = _ddi_lookup.get(frozenset([id_a, id_b]))
        dict_hit = dict_desc is not None
        results.append({
            "drug_a": {"query": drug_a, "resolved": name_a, "id": id_a},
            "drug_b": {"query": drug_b, "resolved": name_b, "id": id_b},
            "found": dict_hit,
            "interaction_type": None,
            "interaction_description": dict_desc if dict_hit else None,
            "error": None,
        })
    return jsonify({"results": results})
def drug_search():
    """
    Drug name autocomplete / lookup.

    Query params:
      q      -- search string (min 2 chars)
      limit  -- max results (default 10, max 50); a non-numeric value falls
                back to the default and a negative value is treated as 0

    Response:
      { "results": [ {"drugbank_id": "DB00682", "name": "Warfarin",
                      "is_brand": false, "brand_name": null}, ... ] }
    Canonical-name matches come first, then brand/product matches. Within
    each group, names starting with the query rank ahead of other matches,
    then higher interaction counts first.
    """
    _init_dict()
    q = (request.args.get("q") or "").strip().lower()
    # type=int makes an unparsable value fall back to the default instead of
    # raising; the clamp keeps the slice below within [0, 50].
    limit = request.args.get("limit", 10, type=int)
    limit = max(0, min(limit, 50))
    if len(q) < 2:
        return jsonify({"results": []})
    pop = _drug_pop or {}

    # 1. Canonical drug name matches
    canon = [
        {"drugbank_id": d["drugbank_id"], "name": d["name"],
         "is_brand": False, "brand_name": None}
        for d in (_drug_names or [])
        if q in d["name"].lower()
    ]
    canon.sort(key=lambda d: (not d["name"].lower().startswith(q),
                              -pop.get(d["drugbank_id"], 0)))

    # 2. Brand / product name matches from products.csv
    import ddi_query as _raq
    _COMPANY_TOKENS = (' llc', ' inc', ' inc.', ' ltd', ' corp', ' gmbh', ' ag ')
    bc = getattr(_raq, "_brand_components", None) or {}
    bd = getattr(_raq, "_brand_display", None) or {}
    brand = []
    # bc is keyed by lower-cased brand name, so each brand appears once.
    for brand_lower, entries in bc.items():
        if q not in brand_lower:
            continue
        # Skip over-long names and names that look like company names.
        if len(brand_lower) > 45 or any(t in brand_lower for t in _COMPANY_TOKENS):
            continue
        if not entries:
            continue
        brand_display = bd.get(brand_lower, brand_lower.title())
        # Use the primary (first) component's ID to represent the brand
        primary_did, primary_canonical = entries[0]
        brand.append({"drugbank_id": primary_did, "name": primary_canonical,
                      "is_brand": True, "brand_name": brand_display})
    brand.sort(key=lambda d: (not d["brand_name"].lower().startswith(q),
                              -pop.get(d["drugbank_id"], 0)))

    return jsonify({"results": (canon + brand)[:limit]})
def landing_page():
    """Redirect to the site root ("/")."""
    from flask import redirect
    target = "/"
    return redirect(target)
def about_page():
    """Render the about page template."""
    page = render_template("about.html")
    return page
def dashboard_page():
    """Render the dashboard page template."""
    page = render_template("dashboard.html")
    return page
def api_stats():
    """Return the live, in-memory usage statistics as JSON.

    Response:
      {
        "stats":     copy of the outcome counters and total handling time,
        "recent":    the 25 most recent queries (newest first),
        "top_pairs": the 5 most frequently checked pairs,
        "system":    uptime, average latency (ms), documented-hit rate (%),
                     dictionary load state and size, GNN availability
      }
    """
    uptime_s = int((_dt.datetime.utcnow() - _server_start).total_seconds())
    h, rem = divmod(uptime_s, 3600)
    m, s = divmod(rem, 60)
    # max(..., 1) avoids dividing by zero before the first query arrives.
    total_queries = max(_stats["total"], 1)
    avg_ms = round(_stats["total_ms"] / total_queries)
    hit_rate = round(_stats["documented"] / total_queries * 100, 1)
    # The GNN is optional: any failure to import or probe it counts as
    # "unavailable".
    try:
        import gnn_predictor
        gnn_ok = gnn_predictor.is_available()
    except Exception:
        gnn_ok = False
    return jsonify({
        "stats": dict(_stats),
        "recent": list(_recent_queries)[:25],
        "top_pairs": [{"pair": p, "count": n} for p, n in _pair_counts.most_common(5)],
        "system": {
            "uptime": f"{h}h {m:02d}m {s:02d}s",
            "avg_ms": avg_ms,
            "hit_rate": hit_rate,
            "dict_loaded": _dict_ready,
            "gnn_available": gnn_ok,
            "dict_size": len(_ddi_lookup) if _ddi_lookup else 0,
        },
    })
| # --------------------------------------------------------------------------- | |
| # Eager startup — kick off dict/cache loading as soon as the module is imported. | |
| # Works for both `python app.py` and Gunicorn worker processes. | |
| # The thread is daemonised so it never blocks the process from exiting. | |
| # --------------------------------------------------------------------------- | |
# Start loading in the background at import time; request handlers that call
# _init_dict() themselves will wait on its lock until this load finishes.
_startup_loader = threading.Thread(target=_init_dict, daemon=True)
_startup_loader.start()
| # --------------------------------------------------------------------------- | |
| # Entry point | |
| # --------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # Direct `python app.py` launch: listen on all interfaces, port 5000,
    # with the debugger off and request threading disabled.
    app.run(
        host="0.0.0.0",
        port=5000,
        debug=False,
        threaded=False,
    )