Spaces:
Running
Running
import json
import os
import secrets
import shutil
import traceback
import typing

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# ============================================================
# Try core pipeline first
# ============================================================
try:
    # Preferred path: the project's hybrid retriever pipeline.
    from core.hybrid_retriever import summarize_combined as core_summarize_combined
    CORE_AVAILABLE = True
except Exception:
    # Core package missing or broken: the local TF-IDF/FAISS fallback
    # defined later in this file is used instead.
    core_summarize_combined = None
    CORE_AVAILABLE = False
# ------------------------------------------------------------
# Admin functions (safe fallback)
# ------------------------------------------------------------
try:
    from core.admin_tasks import rebuild_index, rebuild_glossary, reset_faiss_cache, clear_index
except Exception:
    # fallbacks: no-op stubs so the admin endpoints stay callable and
    # return a readable message instead of crashing when core is absent.
    def rebuild_index(): return "rebuild_index not available"
    def rebuild_glossary(): return "rebuild_glossary not available"
    def reset_faiss_cache(): return "reset_faiss_cache not available"
    def clear_index(): return "clear_index not available"
# ------------------------------------------------------------
# Optional FAISS + SentenceTransformer
# ------------------------------------------------------------
try:
    import faiss
    from sentence_transformers import SentenceTransformer
    # Model loads eagerly at import time; if any of these lines fail,
    # semantic search is disabled for the whole process (FAISS_OK stays False).
    EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")
    FAISS_OK = True
except Exception:
    EMBEDDER = None
    FAISS_OK = False
# TF-IDF fallback
# NOTE(review): these sklearn imports are unconditional — if scikit-learn is
# not installed the module fails to import even when FAISS is available.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel
# ------------------------------------------------------------
# Paths
# ------------------------------------------------------------
# All locations and credentials are overridable via environment variables.
GLOSSARY_PATH = os.environ.get("GLOSSARY_PATH", "./data/glossary.json")  # {term: definition} JSON
FAISS_INDEX_DIR = os.environ.get("FAISS_INDEX_DIR", "./data/faiss_index")
DOCS_FOLDER = os.environ.get("DOCS_FOLDER", "./data/docs")
# NOTE(review): weak default password — ensure ADMIN_PASS is set in deployment.
ADMIN_PASS = os.environ.get("ADMIN_PASS", "changeme")
# Used-disk threshold (GB) above which admin_safe_rebuild refuses unless forced.
DISK_USAGE_THRESHOLD_GB = float(os.environ.get("DISK_USAGE_THRESHOLD_GB", "45.0"))
# ============================================================
# Disk Utilities
# ============================================================
def get_folder_size_bytes(path: str) -> int:
    """Return the total size in bytes of all regular files under *path*.

    A missing *path* yields 0; unreadable entries and files that disappear
    mid-walk are silently skipped.
    """
    total = 0
    if not os.path.exists(path):
        return 0
    # onerror swallows listing errors so a partially readable tree still sums.
    for root, _dirs, files in os.walk(path, onerror=lambda e: None):
        for name in files:
            fp = os.path.join(root, name)
            try:
                # EAFP: stat once instead of the original exists()+getsize()
                # pair, which was racy and hid all errors behind a bare except.
                total += os.path.getsize(fp)
            except OSError:
                pass
    return total
def bytes_to_human(n: int) -> str:
    """Format a byte count as a short human-readable string, e.g. '1.5KB'."""
    size = float(n)
    for unit in ("B", "KB", "MB", "GB", "TB"):
        if size < 1024:
            return f"{size:.1f}{unit}"
        size /= 1024
    # Anything past TB is reported in petabytes.
    return f"{size:.1f}PB"
def get_disk_usage(path="/"):
    """Return {"total", "used", "free"} byte counts for the filesystem at *path*.

    Falls back to os.statvfs where shutil fails, and to all-zeros when neither
    source works, so callers never see an exception.
    """
    try:
        usage = shutil.disk_usage(path)
        return {"total": usage.total, "used": usage.used, "free": usage.free}
    except OSError:
        pass
    try:
        # POSIX-only fallback; os.statvfs does not exist on Windows, hence
        # AttributeError is also caught (the original hid this behind bare excepts).
        st = os.statvfs(path)
        total = st.f_frsize * st.f_blocks
        free = st.f_frsize * st.f_bfree
        return {"total": total, "used": total - free, "free": free}
    except (OSError, AttributeError):
        return {"total": 0, "used": 0, "free": 0}
# ============================================================
# Glossary / Docs
# ============================================================
def load_glossary(path: typing.Optional[str] = None) -> dict:
    """Load the glossary JSON ({term: definition}) from *path*.

    Defaults to GLOSSARY_PATH (backward compatible: existing callers pass no
    argument). Returns {} when the file is missing, unreadable, or invalid JSON
    — the bare except of the original is narrowed to those specific failures.
    """
    if path is None:
        path = GLOSSARY_PATH
    try:
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, json.JSONDecodeError):
        return {}
def load_docs(folder: typing.Optional[str] = None) -> list:
    """Read every regular file in *folder* (default DOCS_FOLDER) as UTF-8 text.

    Returns a list of {"id": filename, "text": contents}. A missing folder
    yields []; unreadable or non-UTF-8 files are skipped.
    """
    if folder is None:
        folder = DOCS_FOLDER
    docs = []
    if not os.path.exists(folder):
        return docs
    for name in os.listdir(folder):
        full = os.path.join(folder, name)
        if not os.path.isfile(full):
            continue
        try:
            # Context manager: the original `open(...).read()` leaked the
            # file handle, and its bare except hid every failure mode.
            with open(full, "r", encoding="utf-8") as f:
                docs.append({"id": name, "text": f.read()})
        except (OSError, UnicodeDecodeError):
            pass
    return docs
# ============================================================
# TF-IDF Retriever
# ============================================================
class SimpleRetriever:
    """In-memory TF-IDF retriever used when no FAISS index is available.

    Documents are dicts carrying "id" and "text" keys.
    """

    def __init__(self, docs):
        self.docs = docs
        corpus = [doc["text"] for doc in docs]
        if not corpus:
            # Nothing to index; query() short-circuits to [].
            self.vectorizer = None
            return
        self.vectorizer = TfidfVectorizer(stop_words="english", max_features=4000)
        self.mat = self.vectorizer.fit_transform(corpus)

    def query(self, q, k=3):
        """Return up to *k* documents with positive cosine similarity to *q*."""
        if not self.vectorizer:
            return []
        scores = linear_kernel(self.vectorizer.transform([q]), self.mat).flatten()
        results = []
        for idx in scores.argsort()[::-1][:k]:
            if scores[idx] > 0:
                doc = self.docs[idx]
                snippet = doc["text"][:300].replace("\n", " ")
                results.append({"id": doc["id"], "excerpt": snippet, "score": float(scores[idx])})
        return results
# ============================================================
# FAISS Searcher
# ============================================================
def load_faiss():
    """Return a search(q, k=3) closure over the on-disk FAISS index, or None.

    None is returned when FAISS/the embedder are unavailable, the index or
    mapping files are missing, or loading fails — callers then fall back to
    TF-IDF retrieval.
    """
    if not FAISS_OK:
        return None
    idx_file = os.path.join(FAISS_INDEX_DIR, "index.faiss")
    map_file = os.path.join(FAISS_INDEX_DIR, "mapping.json")
    if not os.path.exists(idx_file) or not os.path.exists(map_file):
        return None
    try:
        idx = faiss.read_index(idx_file)
        # Context manager: the original `json.load(open(...))` leaked the handle.
        with open(map_file, "r", encoding="utf-8") as f:
            mapping = json.load(f)

        def search(q, k=3):
            # Embed the query and look up the top-k nearest index entries.
            emb = EMBEDDER.encode([q])
            D, I = idx.search(emb, k)
            res = []
            for score, i_id in zip(D[0], I[0]):
                meta = mapping.get(str(int(i_id)), {})
                txt = (meta.get("text", "")[:300]).replace("\n", " ")
                res.append({
                    "id": meta.get("id", i_id),
                    "excerpt": txt,
                    "score": float(score)
                })
            return res

        return search
    except Exception:
        # Deliberate best-effort: log instead of silently swallowing (the
        # original bare except hid all load failures), then degrade to TF-IDF.
        traceback.print_exc()
        return None
# ============================================================
# Summarize Wrapper
# ============================================================
def fallback_summarize(question):
    """Answer *question* from glossary term matches plus document retrieval.

    Documents come from the FAISS index when available, otherwise from a
    TF-IDF retriever built on the fly. Returns {"answer": str, "citations": list}.
    """
    question_lower = question.lower()
    glossary_hits = [
        {"source": f"glossary:{term}", "excerpt": definition[:300]}
        for term, definition in load_glossary().items()
        if term.lower() in question_lower
    ]
    docs = load_docs()
    searcher = load_faiss()
    doc_hits = searcher(question) if searcher else SimpleRetriever(docs).query(question)

    sections = []
    if glossary_hits:
        listing = "\n".join(f"- {h['source']}: {h['excerpt']}" for h in glossary_hits)
        sections.append("Glossary matches:\n" + listing)
    if doc_hits:
        listing = "\n".join(f"- ({d['id']}) {d['excerpt']}" for d in doc_hits)
        sections.append("Top documents:\n" + listing)
    if not sections:
        return {"answer": f"No sources found for: {question}", "citations": []}
    return {
        "answer": "\n\n".join(sections),
        "citations": glossary_hits + doc_hits
    }
def summarize_combined_wrapper(q):
    """Route a question to the core pipeline, degrading to local retrieval.

    Always returns {"answer": str, "citations": list}. Any core-pipeline
    failure is logged and falls through to fallback_summarize().
    """
    if CORE_AVAILABLE and core_summarize_combined:
        try:
            res = core_summarize_combined(q)
            if isinstance(res, dict):
                # Normalize so callers always get both keys.
                return {"answer": res.get("answer", ""), "citations": res.get("citations", [])}
            return {"answer": str(res), "citations": []}
        except Exception:
            # Narrowed from a bare except so KeyboardInterrupt/SystemExit propagate.
            traceback.print_exc()
            return fallback_summarize(q)
    return fallback_summarize(q)
# ============================================================
# FastAPI - Inner App (CT-Chat API)
# ============================================================
# Inner application; it is mounted under /api on root_app at the bottom of
# this file, so these routes are ultimately served at /api/*.
app = FastAPI(title="CT-Chat API", description="API endpoint for Clinical Trial Chatbot")
app.add_middleware(
    CORSMiddleware,
    # NOTE(review): wildcard origins combined with allow_credentials=True is
    # rejected by browsers for credentialed requests — confirm this is intended.
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"]
)
class Query(BaseModel):
    """Request body for the chat endpoint: a single free-text question."""
    question: str
class AdminPayload(BaseModel):
    """Request body for admin endpoints: shared-secret auth plus options."""
    password: str
    # When True, admin_safe_rebuild proceeds even above the disk-usage threshold.
    force: typing.Optional[bool] = False
# ---------------- Chat Endpoint ----------------
# NOTE(review): no route decorator (e.g. @app.post("/chat")) is visible here;
# it may have been lost in formatting — confirm this handler is registered.
async def chat(q: Query):
    """Answer a chat question; failures are reported in-band, never raised."""
    try:
        result = summarize_combined_wrapper(q.question)
        return {
            "answer": result["answer"],
            "citations": result.get("citations", []),
            "status": "success",
        }
    except Exception as exc:
        return {"answer": str(exc), "citations": [], "status": "error"}
# ============================================================
# Disk Usage
# ============================================================
# NOTE(review): no route decorator visible; confirm this is registered on the app.
def api_disk_usage():
    """Report overall disk usage plus the FAISS index footprint, human-readable."""
    usage = get_disk_usage("/")
    index_bytes = get_folder_size_bytes(FAISS_INDEX_DIR)
    report = {
        "disk_total_human": bytes_to_human(usage["total"]),
        "disk_used_human": bytes_to_human(usage["used"]),
        "disk_free_human": bytes_to_human(usage["free"]),
        "faiss_index_size": bytes_to_human(index_bytes),
        "faiss_index_dir": FAISS_INDEX_DIR,
        "threshold_gb": DISK_USAGE_THRESHOLD_GB,
    }
    return report
| # ============================================================ | |
| # Safe Rebuild Index | |
| # ============================================================ | |
def _check(p: AdminPayload):
    """Raise HTTP 401 unless the payload carries the admin password.

    Uses a constant-time comparison to avoid leaking password length/prefix
    via response timing; operands are encoded because compare_digest only
    accepts ASCII str or bytes.
    """
    if not secrets.compare_digest(p.password.encode("utf-8"), ADMIN_PASS.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized")
def admin_safe_rebuild(p: AdminPayload):
    """Empty the FAISS index folder and rebuild it, guarded by a disk check.

    Refuses to run when used disk space exceeds DISK_USAGE_THRESHOLD_GB unless
    the payload sets force=True. Returns a status dict; apart from the 401
    raised by _check, failures are reported in-band rather than raised.
    """
    _check(p)
    usage = get_disk_usage("/")
    used_gb = usage["used"] / (1024 ** 3)
    # Safety gate: a rebuild temporarily needs extra space, so bail out early
    # on a nearly-full disk unless explicitly forced.
    if used_gb >= DISK_USAGE_THRESHOLD_GB and not p.force:
        return {
            "status": "error",
            "reason": f"Disk usage {used_gb:.2f}GB is above safety threshold {DISK_USAGE_THRESHOLD_GB}GB. Use force:true to override."
        }
    try:
        if os.path.exists(FAISS_INDEX_DIR):
            # Empty the folder in place (the directory itself is kept).
            for f in os.listdir(FAISS_INDEX_DIR):
                fp = os.path.join(FAISS_INDEX_DIR, f)
                try:
                    if os.path.isdir(fp):
                        shutil.rmtree(fp)
                    else:
                        os.remove(fp)
                except Exception as e:
                    # Best-effort delete: one stuck entry should not abort the rebuild.
                    print(f"Warning: could not delete {fp}: {e}")
        else:
            os.makedirs(FAISS_INDEX_DIR, exist_ok=True)
    except Exception as e:
        return {"status": "error", "reason": f"Failed to clear FAISS index folder: {e}"}
    try:
        res = rebuild_index()
        return {"status": "ok", "result": res}
    except Exception as e:
        traceback.print_exc()
        return {"status": "error", "reason": str(e)}
# ============================================================
# Password Validation
# ============================================================
def api_validate_password(p: AdminPayload):
    """Check the admin password without raising; returns {"valid": bool}.

    Uses a constant-time comparison (timing-attack hardening); operands are
    encoded so non-ASCII passwords are accepted by compare_digest.
    """
    valid = secrets.compare_digest(p.password.encode("utf-8"), ADMIN_PASS.encode("utf-8"))
    return {"valid": bool(valid)}
# ============================================================
# Existing Admin Endpoints
# ============================================================
def api_rebuild_index(p: AdminPayload):
    """Admin: rebuild the document index (password-protected)."""
    _check(p)
    result = rebuild_index()
    return {"status": "ok", "result": result}
def api_rebuild_glossary(p: AdminPayload):
    """Admin: rebuild the glossary (password-protected)."""
    _check(p)
    result = rebuild_glossary()
    return {"status": "ok", "result": result}
def api_reset_faiss(p: AdminPayload):
    """Admin: reset the FAISS cache (password-protected)."""
    _check(p)
    result = reset_faiss_cache()
    return {"status": "ok", "result": result}
def api_clear_index(p: AdminPayload):
    """Admin: clear the index; failures are reported in-band, never raised."""
    _check(p)
    try:
        return {"status": "ok", "result": clear_index()}
    except Exception as e:
        # Bug fix: the original reported status "ok" even when clear_index()
        # raised, so callers could not tell success from failure.
        return {"status": "error", "result": str(e)}
# ============================================================
# Mount API under /api (fixes Android client 404s)
# ============================================================
# The inner CT-Chat app is exposed under /api; root_app becomes the ASGI
# entry point. (Removed the redundant `from fastapi import FastAPI as _FastAPI`
# — FastAPI is already imported at the top of this file.)
root_app = FastAPI(title="Root Server", description="API root router")
root_app.mount("/api", app)
# Rebind so `api:app` (and any other `app` reference) resolves to the root server.
app = root_app
# ============================================================
# Local Run (now serves root_app correctly)
# ============================================================
if __name__ == "__main__":
    # Dev entry point. NOTE(review): "api:app" assumes this file is named
    # api.py — confirm the module name matches.
    import uvicorn
    uvicorn.run("api:app", host="0.0.0.0", port=7861)