# CT-Chat-V2 / api.py
# essprasad's picture
# Upload 3 files
# 12db3c3 verified
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
import json
import traceback
import shutil
import typing
# ============================================================
# Try core pipeline first
# ============================================================
# Prefer the project's core retrieval pipeline. If importing it fails for any
# reason, record that fact so the rest of the module can use its local
# fallback instead.
try:
    from core.hybrid_retriever import summarize_combined as core_summarize_combined
except Exception:
    CORE_AVAILABLE = False
    core_summarize_combined = None
else:
    CORE_AVAILABLE = True
# ------------------------------------------------------------
# Admin functions (safe fallback)
# ------------------------------------------------------------
try:
    from core.admin_tasks import (
        rebuild_index,
        rebuild_glossary,
        reset_faiss_cache,
        clear_index,
    )
except Exception:
    # core.admin_tasks could not be imported: define stand-ins with the same
    # names that only report the task as unavailable.
    def rebuild_index():
        """Stand-in that reports the index rebuild task as unavailable."""
        return "rebuild_index not available"

    def rebuild_glossary():
        """Stand-in that reports the glossary rebuild task as unavailable."""
        return "rebuild_glossary not available"

    def reset_faiss_cache():
        """Stand-in that reports the FAISS cache reset task as unavailable."""
        return "reset_faiss_cache not available"

    def clear_index():
        """Stand-in that reports the index clearing task as unavailable."""
        return "clear_index not available"
# ------------------------------------------------------------
# Optional FAISS + SentenceTransformer
# ------------------------------------------------------------
# Semantic search is optional: it needs both faiss and sentence-transformers,
# and the embedding model must load. Any failure disables it.
try:
    import faiss
    from sentence_transformers import SentenceTransformer
    EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")
except Exception:
    FAISS_OK = False
    EMBEDDER = None
else:
    FAISS_OK = True
# TF-IDF fallback
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel
# ------------------------------------------------------------
# Paths
# ------------------------------------------------------------
# Each setting can be overridden through the environment variable of the same
# name; the second argument is the value used when the variable is unset.
GLOSSARY_PATH = os.getenv("GLOSSARY_PATH", "./data/glossary.json")
FAISS_INDEX_DIR = os.getenv("FAISS_INDEX_DIR", "./data/faiss_index")
DOCS_FOLDER = os.getenv("DOCS_FOLDER", "./data/docs")
# NOTE(review): the built-in default password is a well-known placeholder;
# deployments should always set ADMIN_PASS explicitly.
ADMIN_PASS = os.getenv("ADMIN_PASS", "changeme")
# Disk usage (in GB) at or above which the safe index rebuild refuses to run
# unless forced.
DISK_USAGE_THRESHOLD_GB = float(os.getenv("DISK_USAGE_THRESHOLD_GB", "45.0"))
# ============================================================
# Disk Utilities
# ============================================================
def get_folder_size_bytes(path: str) -> int:
    """Return the combined size in bytes of all files beneath *path*.

    The directory tree is walked recursively. Directories that cannot be
    listed and files that cannot be stat'ed (for example removed while the
    walk is running, or dangling symlinks) are skipped, so the result is a
    best-effort total.

    Args:
        path: Directory to measure.

    Returns:
        Total size in bytes; 0 when *path* does not exist or holds no files.
    """
    total = 0
    # os.walk ignores listing errors by default and yields nothing for a
    # missing path, so no separate existence check is needed.
    for root, _dirs, files in os.walk(path):
        for name in files:
            try:
                total += os.path.getsize(os.path.join(root, name))
            except OSError:
                # File vanished or is unreadable: leave it out of the total.
                continue
    return total
def bytes_to_human(n: int) -> str:
    """Format a byte count as a short string such as ``"1.5KB"``.

    The value is divided by 1024 once per unit step (B, KB, MB, GB, TB) until
    it is below 1024; anything that is still not below 1024 after the TB step
    is reported in PB. One decimal place is always shown.
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = n
    step = 0
    while step < len(units) and not value < 1024:
        value = value / 1024
        step += 1
    suffix = units[step] if step < len(units) else "PB"
    return f"{value:.1f}{suffix}"
def get_disk_usage(path="/"):
    """Return total, used and free bytes for the filesystem holding *path*.

    ``shutil.disk_usage`` is tried first and ``os.statvfs`` second. If both
    fail, all three figures are reported as 0 instead of raising.

    Args:
        path: Any path on the filesystem to inspect.

    Returns:
        A dict with integer ``"total"``, ``"used"`` and ``"free"`` entries.
    """
    # ``except Exception`` keeps the best-effort behaviour while letting
    # KeyboardInterrupt / SystemExit propagate, which a bare ``except`` hid.
    try:
        usage = shutil.disk_usage(path)
        return {"total": usage.total, "used": usage.used, "free": usage.free}
    except Exception:
        try:
            st = os.statvfs(path)
            total = st.f_frsize * st.f_blocks
            free = st.f_frsize * st.f_bfree
            return {"total": total, "used": total - free, "free": free}
        except Exception:
            return {"total": 0, "used": 0, "free": 0}
# ============================================================
# Glossary / Docs
# ============================================================
def load_glossary():
    """Load the glossary from the JSON file at ``GLOSSARY_PATH``.

    Returns:
        The parsed JSON object as a dict, or an empty dict when the file is
        missing, unreadable, not valid JSON, or its top-level value is not a
        JSON object.
    """
    try:
        with open(GLOSSARY_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing or unreadable file. ValueError: malformed JSON
        # (json.JSONDecodeError) or undecodable bytes (UnicodeDecodeError).
        return {}
    # The failure value is a dict, so keep the return type uniform even when
    # the file holds some other JSON value such as a list.
    return data if isinstance(data, dict) else {}
def load_docs():
    """Read every regular file directly inside ``DOCS_FOLDER`` as UTF-8 text.

    Sub-directories are not descended into. Files that cannot be opened or
    decoded are skipped.

    Returns:
        A list of ``{"id": <file name>, "text": <file contents>}`` dicts; an
        empty list when the folder is missing or cannot be listed.
    """
    docs = []
    try:
        names = os.listdir(DOCS_FOLDER)
    except OSError:
        # Folder missing, not a directory, or not readable.
        return docs
    for name in names:
        full = os.path.join(DOCS_FOLDER, name)
        if not os.path.isfile(full):
            continue
        try:
            # The context manager closes the handle even if read() fails.
            with open(full, "r", encoding="utf-8") as fh:
                text = fh.read()
        except (OSError, UnicodeDecodeError):
            continue
        docs.append({"id": name, "text": text})
    return docs
# ============================================================
# TF-IDF Retriever
# ============================================================
class SimpleRetriever:
    """TF-IDF retriever over an in-memory list of documents.

    Each document is a dict providing ``"id"`` and ``"text"`` entries. When
    there is nothing to index, the retriever stays usable and ``query``
    simply returns no hits.
    """

    def __init__(self, docs):
        """Fit a TF-IDF model on the ``"text"`` of every document in *docs*."""
        self.docs = docs
        self.vectorizer = None
        self.mat = None
        texts = [d["text"] for d in docs]
        if not texts:
            return
        vectorizer = TfidfVectorizer(stop_words="english", max_features=4000)
        try:
            self.mat = vectorizer.fit_transform(texts)
        except ValueError:
            # scikit-learn raises ValueError ("empty vocabulary") when no
            # document contributes a term, e.g. all texts are empty or hold
            # only stop words. Treat that the same as having no documents.
            return
        self.vectorizer = vectorizer

    def query(self, q, k=3):
        """Return up to *k* documents most similar to the query string *q*.

        Only documents with a strictly positive similarity are returned,
        best match first. Each hit is a dict with the document ``"id"``, an
        ``"excerpt"`` (first 300 characters, newlines replaced by spaces) and
        the similarity ``"score"`` as a float.
        """
        if self.vectorizer is None:
            return []
        qv = self.vectorizer.transform([q])
        sims = linear_kernel(qv, self.mat).flatten()
        out = []
        # argsort is ascending, so reverse it to visit the best matches first.
        for i in sims.argsort()[::-1][:k]:
            if sims[i] > 0:
                doc = self.docs[i]
                text = doc["text"][:300].replace("\n", " ")
                out.append({"id": doc["id"], "excerpt": text, "score": float(sims[i])})
        return out
# ============================================================
# FAISS Searcher
# ============================================================
def load_faiss():
    """Build a search function backed by the on-disk FAISS index.

    Reads ``index.faiss`` and ``mapping.json`` from ``FAISS_INDEX_DIR``.

    Returns:
        A callable ``search(q, k=3)`` returning a list of
        ``{"id", "excerpt", "score"}`` dicts, or ``None`` when FAISS support
        is disabled, either file is missing, or loading fails.
    """
    if not FAISS_OK:
        return None
    idx_file = os.path.join(FAISS_INDEX_DIR, "index.faiss")
    map_file = os.path.join(FAISS_INDEX_DIR, "mapping.json")
    if not (os.path.exists(idx_file) and os.path.exists(map_file)):
        return None
    try:
        idx = faiss.read_index(idx_file)
        # Context manager so the mapping file handle is always closed.
        with open(map_file, "r", encoding="utf-8") as fh:
            mapping = json.load(fh)
    except Exception:
        # Loading is best-effort: a failure here means "no FAISS search".
        return None

    def search(q, k=3):
        """Return up to *k* index hits for the query string *q*."""
        emb = EMBEDDER.encode([q])
        scores, labels = idx.search(emb, k)
        res = []
        for score, label in zip(scores[0], labels[0]):
            # Convert the numpy integer so it is a plain int in the output.
            label = int(label)
            if label < 0:
                # FAISS pads with -1 when fewer than k results exist.
                continue
            # mapping.json is looked up by the stringified vector position.
            meta = mapping.get(str(label), {})
            txt = meta.get("text", "")[:300].replace("\n", " ")
            res.append({
                "id": meta.get("id", label),
                "excerpt": txt,
                "score": float(score),
            })
        return res

    return search
# ============================================================
# Summarize Wrapper
# ============================================================
def fallback_summarize(question):
    """Answer *question* from the glossary and local documents.

    Glossary terms that occur in the question (case-insensitively) are
    listed first, followed by the top document hits from FAISS when it is
    available, or from the TF-IDF retriever otherwise.

    Returns:
        A dict with ``"answer"`` (plain-text listing of the matches) and
        ``"citations"`` (glossary hits followed by document hits). When
        nothing matches, the answer says so and the citation list is empty.
    """
    # Lower-case the question once instead of once per glossary term.
    question_lower = question.lower()
    g_hits = []
    for term, definition in load_glossary().items():
        # An empty term is a substring of every question; never match it.
        if not term or term.lower() not in question_lower:
            continue
        # Definitions are sliced below, so coerce non-string values first.
        text = definition if isinstance(definition, str) else str(definition)
        g_hits.append({"source": f"glossary:{term}", "excerpt": text[:300]})

    faiss_srch = load_faiss()
    if faiss_srch:
        doc_hits = faiss_srch(question)
    else:
        # Documents are only read from disk when the TF-IDF path needs them.
        doc_hits = SimpleRetriever(load_docs()).query(question)

    parts = []
    if g_hits:
        parts.append("Glossary matches:\n" + "\n".join(f"- {h['source']}: {h['excerpt']}" for h in g_hits))
    if doc_hits:
        parts.append("Top documents:\n" + "\n".join(f"- ({d['id']}) {d['excerpt']}" for d in doc_hits))
    if not parts:
        return {"answer": f"No sources found for: {question}", "citations": []}
    return {
        "answer": "\n\n".join(parts),
        "citations": g_hits + doc_hits,
    }
def summarize_combined_wrapper(q):
    """Answer *q* with the core pipeline, falling back to the local one.

    The core summarizer is used when it was imported successfully. A dict
    result is reduced to its ``"answer"`` and ``"citations"`` entries; any
    other result is stringified into the answer. If the core call raises, the
    traceback is printed and ``fallback_summarize`` is used instead.

    Returns:
        A dict with ``"answer"`` and ``"citations"`` entries.
    """
    if CORE_AVAILABLE and core_summarize_combined:
        try:
            res = core_summarize_combined(q)
            if isinstance(res, dict):
                return {"answer": res.get("answer", ""), "citations": res.get("citations", [])}
            return {"answer": str(res), "citations": []}
        except Exception:
            # Log and fall through; ``except Exception`` (not a bare except)
            # lets KeyboardInterrupt / SystemExit propagate.
            traceback.print_exc()
    return fallback_summarize(q)
# ============================================================
# FastAPI - Inner App (CT-Chat API)
# ============================================================
# Inner FastAPI application: the chat and admin routes in this file are
# registered on this object, which is later mounted under /api.
app = FastAPI(title="CT-Chat API", description="API endpoint for Clinical Trial Chatbot")
# CORS is left fully open: any origin, method and header is accepted, and
# credentials are allowed.
# NOTE(review): FastAPI's CORS documentation advises listing origins
# explicitly when allow_credentials=True rather than using "*" - confirm that
# this wildcard configuration is intended for production.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"]
)
# Request body of the chat endpoint. Documented with comments, not a
# docstring, so the generated schema description is unchanged.
class Query(BaseModel):
    # Free-text question to answer.
    question: str
# Request body shared by the admin endpoints. Documented with comments, not a
# docstring, so the generated schema description is unchanged.
class AdminPayload(BaseModel):
    # Admin password supplied by the caller.
    password: str
    # Optional override flag; read by the safe index rebuild endpoint.
    force: typing.Optional[bool] = False
# ---------------- Chat Endpoint ----------------
# Chat endpoint: answers the submitted question and always replies with a
# JSON body carrying "answer", "citations" and "status" ("success" or
# "error"). Failures are reported in the body, not as an HTTP error.
# NOTE(review): summarize_combined_wrapper is a synchronous call made directly
# inside this coroutine, so it runs on the event loop - confirm that is
# acceptable for the expected load.
@app.post("/chat")
async def chat(q: Query):
    try:
        r = summarize_combined_wrapper(q.question)
        return {"answer": r["answer"], "citations": r.get("citations", []), "status": "success"}
    except Exception as e:
        # Keep a server-side trace; previously only the message reached the
        # client and the traceback was lost.
        traceback.print_exc()
        return {"answer": str(e), "citations": [], "status": "error"}
# ============================================================
# Disk Usage
# ============================================================
# Reports human-readable usage of the root filesystem, the size of the FAISS
# index folder, and the configured safety threshold.
@app.get("/admin/disk_usage")
def api_disk_usage():
    disk = get_disk_usage("/")
    index_bytes = get_folder_size_bytes(FAISS_INDEX_DIR)
    total_h, used_h, free_h = (
        bytes_to_human(disk[key]) for key in ("total", "used", "free")
    )
    return dict(
        disk_total_human=total_h,
        disk_used_human=used_h,
        disk_free_human=free_h,
        faiss_index_size=bytes_to_human(index_bytes),
        faiss_index_dir=FAISS_INDEX_DIR,
        threshold_gb=DISK_USAGE_THRESHOLD_GB,
    )
# ============================================================
# Safe Rebuild Index
# ============================================================
def _check(p: AdminPayload):
    """Reject the request unless *p* carries the admin password.

    Args:
        p: Admin payload whose ``password`` is compared with ``ADMIN_PASS``.

    Raises:
        HTTPException: 401 "Unauthorized" when the password does not match.
    """
    # Function-scope import: hmac is only needed for this comparison.
    import hmac

    # compare_digest takes time independent of where the inputs differ, so
    # response timing does not leak how much of a guess was correct. Encoding
    # to bytes is required because it rejects non-ASCII str arguments.
    supplied = p.password.encode("utf-8", "surrogatepass")
    expected = ADMIN_PASS.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")
# Guarded index rebuild: checks the admin password, refuses to run when root
# disk usage is at or above the threshold (unless force is set), empties the
# FAISS index folder, then rebuilds. Failures after the password check are
# reported in the JSON body as status "error".
@app.post("/admin/safe_rebuild_index")
def admin_safe_rebuild(p: AdminPayload):
    _check(p)

    used_gb = get_disk_usage("/")["used"] / (1024 ** 3)
    if used_gb >= DISK_USAGE_THRESHOLD_GB and not p.force:
        return {
            "status": "error",
            "reason": f"Disk usage {used_gb:.2f}GB is above safety threshold {DISK_USAGE_THRESHOLD_GB}GB. Use force:true to override."
        }

    # Step 1: empty the index folder, creating it when absent.
    try:
        if not os.path.exists(FAISS_INDEX_DIR):
            os.makedirs(FAISS_INDEX_DIR, exist_ok=True)
        else:
            for entry in os.listdir(FAISS_INDEX_DIR):
                fp = os.path.join(FAISS_INDEX_DIR, entry)
                try:
                    remove = shutil.rmtree if os.path.isdir(fp) else os.remove
                    remove(fp)
                except Exception as e:
                    # One stubborn entry only produces a warning.
                    print(f"Warning: could not delete {fp}: {e}")
    except Exception as e:
        return {"status": "error", "reason": f"Failed to clear FAISS index folder: {e}"}

    # Step 2: rebuild and report the outcome.
    try:
        outcome = rebuild_index()
    except Exception as e:
        traceback.print_exc()
        return {"status": "error", "reason": str(e)}
    return {"status": "ok", "result": outcome}
# ============================================================
# Password Validation
# ============================================================
# Tells the caller whether the supplied password equals the admin password;
# replies {"valid": true} or {"valid": false} and never raises on a mismatch.
@app.post("/admin/validate_password")
def api_validate_password(p: AdminPayload):
    # Function-scope import: hmac is only needed for this comparison.
    import hmac

    # Constant-time comparison so response timing does not reveal how much of
    # a guess matched; bytes are used because compare_digest rejects
    # non-ASCII str arguments.
    supplied = p.password.encode("utf-8", "surrogatepass")
    expected = ADMIN_PASS.encode("utf-8", "surrogatepass")
    return {"valid": hmac.compare_digest(supplied, expected)}
# ============================================================
# Existing Admin Endpoints
# ============================================================
# Password-protected trigger for rebuild_index(); its return value is passed
# back under "result".
@app.post("/admin/rebuild_index")
def api_rebuild_index(p: AdminPayload):
    _check(p)
    outcome = rebuild_index()
    return {"status": "ok", "result": outcome}
# Password-protected trigger for rebuild_glossary(); its return value is
# passed back under "result".
@app.post("/admin/rebuild_glossary")
def api_rebuild_glossary(p: AdminPayload):
    _check(p)
    outcome = rebuild_glossary()
    return {"status": "ok", "result": outcome}
# Password-protected trigger for reset_faiss_cache(); its return value is
# passed back under "result".
@app.post("/admin/reset_faiss")
def api_reset_faiss(p: AdminPayload):
    _check(p)
    outcome = reset_faiss_cache()
    return {"status": "ok", "result": outcome}
# Password-protected trigger for clear_index(). On success the return value
# is passed back under "result" with status "ok"; on failure the status is
# "error" and the exception message is reported.
@app.post("/admin/clear_index")
def api_clear_index(p: AdminPayload):
    _check(p)
    try:
        return {"status": "ok", "result": clear_index()}
    except Exception as e:
        traceback.print_exc()
        message = str(e)
        # A failure used to be reported as status "ok". "result" is kept for
        # clients that already read it; "reason" matches the error shape used
        # by the safe rebuild endpoint.
        return {"status": "error", "result": message, "reason": message}
# ============================================================
# ✔✔ MOUNT API UNDER /api (Fix Android 404)
# ============================================================
# Wrap the CT-Chat application in an outer FastAPI application so that every
# route registered on it is served under the /api prefix.
from fastapi import FastAPI as _FastAPI
root_app = _FastAPI(title="Root Server", description="API root router")
root_app.mount("/api", app)
# Rebind the module-level name "app" to the outer application, so whatever
# imports "app" from this module (including the "api:app" string passed to
# uvicorn below, assuming the module is importable as "api") gets root_app.
app = root_app
# ============================================================
# Local Run (now serves root_app correctly)
# ============================================================
if __name__ == "__main__":
    # uvicorn is imported here so it is only required when this file is run
    # directly as a script.
    import uvicorn

    # Listen on every interface, port 7861; the application is referenced by
    # its import string.
    uvicorn.run(
        "api:app",
        host="0.0.0.0",
        port=7861,
    )