File size: 12,321 Bytes
12db3c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
import json
import traceback
import shutil
import typing

# ============================================================
# Try core pipeline first
# ============================================================
# The project's own retrieval/summarization pipeline is used when it can be
# imported. Any exception raised during the import (not only ImportError)
# switches the service to the local fallback implementation defined below.
try:
    from core.hybrid_retriever import summarize_combined as core_summarize_combined
    CORE_AVAILABLE = True
except Exception:
    # Keep the name bound so later code can test it without a NameError.
    core_summarize_combined = None
    CORE_AVAILABLE = False

# ------------------------------------------------------------
# Admin functions (safe fallback)
# ------------------------------------------------------------
try:
    from core.admin_tasks import (
        rebuild_index,
        rebuild_glossary,
        reset_faiss_cache,
        clear_index,
    )
except Exception:
    # The real admin tasks could not be imported: bind inert stand-ins with
    # the same names so the admin endpoints keep working and simply report
    # that the operation is unavailable.

    def rebuild_index():
        """Stand-in for core.admin_tasks.rebuild_index; performs no work."""
        return "rebuild_index not available"

    def rebuild_glossary():
        """Stand-in for core.admin_tasks.rebuild_glossary; performs no work."""
        return "rebuild_glossary not available"

    def reset_faiss_cache():
        """Stand-in for core.admin_tasks.reset_faiss_cache; performs no work."""
        return "reset_faiss_cache not available"

    def clear_index():
        """Stand-in for core.admin_tasks.clear_index; performs no work."""
        return "clear_index not available"

# ------------------------------------------------------------
# Optional FAISS + SentenceTransformer
# ------------------------------------------------------------
# Dense retrieval is optional and needs both faiss and sentence-transformers.
# NOTE: the embedding model is instantiated right here, at import time. If
# either import or the model construction raises, FAISS search is disabled
# through the two flags below instead of failing the whole module import.
try:
    import faiss
    from sentence_transformers import SentenceTransformer
    EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")
    FAISS_OK = True
except Exception:
    EMBEDDER = None
    FAISS_OK = False

# TF-IDF fallback
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import linear_kernel

# ------------------------------------------------------------
# Paths
# ------------------------------------------------------------
# Each setting is read from the environment variable of the same name and
# falls back to the literal default shown when that variable is unset.
GLOSSARY_PATH = os.environ.get("GLOSSARY_PATH", "./data/glossary.json")
FAISS_INDEX_DIR = os.environ.get("FAISS_INDEX_DIR", "./data/faiss_index")
DOCS_FOLDER = os.environ.get("DOCS_FOLDER", "./data/docs")
# NOTE(review): when ADMIN_PASS is unset the admin password is the well-known
# default "changeme"; deployments should always set this variable.
ADMIN_PASS = os.environ.get("ADMIN_PASS", "changeme")

# Used-disk level, in GB, at or above which /admin/safe_rebuild_index refuses
# to run unless the request sets force=true.
DISK_USAGE_THRESHOLD_GB = float(os.environ.get("DISK_USAGE_THRESHOLD_GB", "45.0"))

# ============================================================
# Disk Utilities
# ============================================================
def get_folder_size_bytes(path: str) -> int:
    """Return the combined size, in bytes, of every file below *path*.

    The walk is best-effort: directories that cannot be listed and files
    that cannot be stat'ed (removed mid-walk, broken symlinks, permission
    errors) are skipped rather than aborting the measurement.

    Args:
        path: Directory to measure.

    Returns:
        Total size in bytes, or 0 when *path* does not exist.
    """
    if not os.path.exists(path):
        return 0

    total = 0
    # onerror swallows directory-listing failures so the walk continues.
    for root, _dirs, files in os.walk(path, onerror=lambda e: None):
        for name in files:
            try:
                # Stat directly instead of checking existence first: a
                # separate exists() call is redundant and racy.
                total += os.path.getsize(os.path.join(root, name))
            except OSError:
                continue
    return total

def bytes_to_human(n: int) -> str:
    """Format a byte count as a one-decimal number followed by a unit.

    The value is divided by 1024 for each step up from "B" through "TB".
    Anything still at or above 1024 after the "TB" step is reported in "PB".
    """
    size = n
    suffixes = ("B", "KB", "MB", "GB", "TB")
    step = 0
    while step < len(suffixes):
        if size < 1024:
            return f"{size:.1f}{suffixes[step]}"
        size = size / 1024
        step += 1
    return f"{size:.1f}PB"

def get_disk_usage(path: str = "/") -> dict:
    """Report total, used and free space for the filesystem holding *path*.

    ``shutil.disk_usage`` is tried first and ``os.statvfs`` second. The
    function never raises: if both probes fail, all three figures are 0.

    Args:
        path: Any path on the filesystem to inspect.

    Returns:
        A dict with integer byte counts under "total", "used" and "free".
    """
    try:
        usage = shutil.disk_usage(path)
        return {"total": usage.total, "used": usage.used, "free": usage.free}
    except Exception:
        # Deliberately broad (best-effort probe), but no longer a bare
        # except, so KeyboardInterrupt/SystemExit are not swallowed.
        pass

    try:
        st = os.statvfs(path)
    except Exception:
        # Also covers AttributeError on platforms without os.statvfs.
        return {"total": 0, "used": 0, "free": 0}

    total = st.f_frsize * st.f_blocks
    free = st.f_frsize * st.f_bfree
    return {"total": total, "used": total - free, "free": free}

# ============================================================
# Glossary / Docs
# ============================================================
def load_glossary():
    """Load the glossary JSON document stored at GLOSSARY_PATH.

    Returns:
        The parsed JSON content, or an empty dict when the file is missing,
        unreadable, not valid UTF-8 or not valid JSON.
    """
    try:
        # Open directly instead of checking existence first: a missing file
        # surfaces as FileNotFoundError, which is an OSError.
        with open(GLOSSARY_PATH, "r", encoding="utf-8") as f:
            return json.load(f)
    except (OSError, ValueError):
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
        return {}

def load_docs():
    """Read every regular file directly inside DOCS_FOLDER as UTF-8 text.

    Sub-directories are not descended into. Files that cannot be opened or
    decoded are skipped.

    Returns:
        A list of ``{"id": <file name>, "text": <file content>}`` dicts;
        empty when the folder is missing or cannot be listed.
    """
    docs = []
    if not os.path.exists(DOCS_FOLDER):
        return docs

    try:
        names = os.listdir(DOCS_FOLDER)
    except OSError:
        # DOCS_FOLDER is not a directory or is not listable.
        return docs

    for name in names:
        full = os.path.join(DOCS_FOLDER, name)
        if not os.path.isfile(full):
            continue
        try:
            # Context manager closes the handle promptly instead of leaving
            # it to garbage collection.
            with open(full, "r", encoding="utf-8") as fh:
                text = fh.read()
        except (OSError, UnicodeDecodeError):
            continue
        docs.append({"id": name, "text": text})
    return docs

# ============================================================
# TF-IDF Retriever
# ============================================================
class SimpleRetriever:
    """TF-IDF based lexical retriever over an in-memory list of documents.

    Each document is a dict with at least the keys "id" and "text". When no
    usable index can be built (no documents, or no document yields a term),
    the retriever stays usable and every query returns an empty list.
    """

    def __init__(self, docs):
        """Fit a TF-IDF matrix over the "text" field of *docs*."""
        self.docs = docs
        # Both attributes are always defined, so an unindexed retriever
        # never raises AttributeError.
        self.vectorizer = None
        self.mat = None

        texts = [d["text"] for d in docs]
        if not texts:
            return

        vectorizer = TfidfVectorizer(stop_words="english", max_features=4000)
        try:
            self.mat = vectorizer.fit_transform(texts)
        except ValueError:
            # scikit-learn raises ValueError ("empty vocabulary") when the
            # documents are empty or contain only stop words.
            return
        self.vectorizer = vectorizer

    def query(self, q, k=3):
        """Return up to *k* documents ranked by similarity to *q*.

        Args:
            q: Query text.
            k: Maximum number of results.

        Returns:
            A list of ``{"id", "excerpt", "score"}`` dicts in descending
            score order. Documents with a score of 0 or less are omitted,
            and "excerpt" is the first 300 characters of the text with
            newlines replaced by spaces.
        """
        if self.vectorizer is None:
            return []

        qv = self.vectorizer.transform([q])
        sims = linear_kernel(qv, self.mat).flatten()

        out = []
        for i in sims.argsort()[::-1][:k]:
            score = float(sims[i])
            if score <= 0:
                continue
            doc = self.docs[i]
            out.append({
                "id": doc["id"],
                "excerpt": doc["text"][:300].replace("\n", " "),
                "score": score,
            })
        return out

# ============================================================
# FAISS Searcher
# ============================================================
def load_faiss():
    """Build a search function backed by the on-disk FAISS index.

    The index is read from ``index.faiss`` and the id-to-metadata mapping
    from ``mapping.json``, both inside FAISS_INDEX_DIR.

    Returns:
        A callable ``search(q, k=3)`` returning a list of
        ``{"id", "excerpt", "score"}`` dicts, or None when FAISS support is
        disabled, either file is missing, or loading fails.
    """
    if not FAISS_OK:
        return None

    idx_file = os.path.join(FAISS_INDEX_DIR, "index.faiss")
    map_file = os.path.join(FAISS_INDEX_DIR, "mapping.json")
    if not (os.path.exists(idx_file) and os.path.exists(map_file)):
        return None

    try:
        idx = faiss.read_index(idx_file)
        # Context manager so the mapping file handle is not leaked.
        with open(map_file, "r", encoding="utf-8") as fh:
            mapping = json.load(fh)
    except Exception:
        # A corrupt index or mapping disables FAISS for this call; the
        # failure is logged rather than hidden.
        traceback.print_exc()
        return None

    def search(q, k=3):
        """Embed *q* and return the metadata of its nearest indexed items."""
        emb = EMBEDDER.encode([q])
        D, I = idx.search(emb, k)
        res = []
        for score, raw_id in zip(D[0], I[0]):
            # Plain int: numpy integers are not JSON-serializable when used
            # as the fallback "id" below.
            item_id = int(raw_id)
            if item_id < 0:
                # FAISS pads with label -1 when fewer than k neighbours exist.
                continue
            meta = mapping.get(str(item_id), {})
            res.append({
                "id": meta.get("id", item_id),
                "excerpt": meta.get("text", "")[:300].replace("\n", " "),
                # Raw value reported by the index; whether it is a distance
                # or a similarity depends on the index type (not visible here).
                "score": float(score),
            })
        return res

    return search

# ============================================================
# Summarize Wrapper
# ============================================================
def fallback_summarize(question):
    """Answer *question* from the glossary and local documents only.

    Glossary terms that occur in the question (case-insensitive substring
    match) are reported first. Documents follow, retrieved through FAISS
    when an index can be loaded and through TF-IDF otherwise.

    Args:
        question: The user's question text.

    Returns:
        ``{"answer": str, "citations": list}``. When nothing matches, the
        answer says so and the citation list is empty.
    """
    glossary = load_glossary()
    if not isinstance(glossary, dict):
        # A glossary file holding a non-object JSON value has no terms.
        glossary = {}

    # Lower-case the question once rather than once per glossary term.
    question_lc = question.lower()
    g_hits = [
        # str() keeps non-string definitions from crashing the slice.
        {"source": f"glossary:{term}", "excerpt": str(definition)[:300]}
        for term, definition in glossary.items()
        # An empty term is a substring of every question; skip it.
        if term and term.lower() in question_lc
    ]

    faiss_srch = load_faiss()
    if faiss_srch:
        doc_hits = faiss_srch(question)
    else:
        # Read the documents from disk only when TF-IDF actually needs them.
        doc_hits = SimpleRetriever(load_docs()).query(question)

    parts = []
    if g_hits:
        parts.append("Glossary matches:\n" + "\n".join(f"- {h['source']}: {h['excerpt']}" for h in g_hits))
    if doc_hits:
        parts.append("Top documents:\n" + "\n".join(f"- ({d['id']}) {d['excerpt']}" for d in doc_hits))

    if not parts:
        return {"answer": f"No sources found for: {question}", "citations": []}

    return {
        "answer": "\n\n".join(parts),
        "citations": g_hits + doc_hits,
    }

def summarize_combined_wrapper(q):
    """Answer *q* with the core pipeline, falling back to local retrieval.

    Args:
        q: The question text.

    Returns:
        ``{"answer": str, "citations": list}``. A dict result from the core
        pipeline is reduced to these two keys; any other result is
        stringified into "answer" with no citations. If the core pipeline is
        unavailable or raises, ``fallback_summarize`` provides the result.
    """
    if CORE_AVAILABLE and core_summarize_combined:
        try:
            res = core_summarize_combined(q)
            if isinstance(res, dict):
                return {"answer": res.get("answer", ""), "citations": res.get("citations", [])}
            return {"answer": str(res), "citations": []}
        except Exception:
            # Not a bare except: KeyboardInterrupt/SystemExit must propagate.
            # Log the failure, then use the fallback below.
            traceback.print_exc()
    return fallback_summarize(q)

# ============================================================
# FastAPI - Inner App (CT-Chat API)
# ============================================================
# Application object on which every route in this file is registered.
app = FastAPI(title="CT-Chat API", description="API endpoint for Clinical Trial Chatbot")

# NOTE(review): CORS is fully open here: every origin, method and header is
# allowed, with credentials enabled. Confirm this is intended before
# exposing the admin endpoints publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"]
)

# Request body of POST /chat.
class Query(BaseModel):
    question: str  # text handed to summarize_combined_wrapper

# Request body shared by the POST /admin/* endpoints.
class AdminPayload(BaseModel):
    password: str  # compared against ADMIN_PASS
    force: typing.Optional[bool] = False  # read by /admin/safe_rebuild_index to bypass its disk-usage guard

# ---------------- Chat Endpoint ----------------
@app.post("/chat")
async def chat(q: Query):
    """Answer a chat question.

    Returns:
        ``{"answer", "citations", "status"}``. "status" is "success"
        normally. On failure it is "error" and "answer" carries the
        exception text.
    """
    # Local import: fastapi is already a dependency of this module, and
    # importing here keeps the change confined to this function.
    from fastapi.concurrency import run_in_threadpool

    try:
        # Retrieval and summarization are synchronous and can be slow. Running
        # them directly in a coroutine would block the event loop for every
        # other request, so hand the call to a worker thread.
        r = await run_in_threadpool(summarize_combined_wrapper, q.question)
        return {"answer": r["answer"], "citations": r.get("citations", []), "status": "success"}
    except Exception as e:
        # Record the failure server-side; the client still gets a 200 payload.
        traceback.print_exc()
        return {"answer": str(e), "citations": [], "status": "error"}

# ============================================================
# Disk Usage
# ============================================================
@app.get("/admin/disk_usage")
def api_disk_usage():
    # Reports usage of the root filesystem and the size of the FAISS index
    # folder as human-readable strings, together with the configured safety
    # threshold. Unlike the POST admin routes, no password is required.
    root_usage = get_disk_usage("/")
    index_bytes = get_folder_size_bytes(FAISS_INDEX_DIR)

    report = {}
    report["disk_total_human"] = bytes_to_human(root_usage["total"])
    report["disk_used_human"] = bytes_to_human(root_usage["used"])
    report["disk_free_human"] = bytes_to_human(root_usage["free"])
    report["faiss_index_size"] = bytes_to_human(index_bytes)
    report["faiss_index_dir"] = FAISS_INDEX_DIR
    report["threshold_gb"] = DISK_USAGE_THRESHOLD_GB
    return report

# ============================================================
# Safe Rebuild Index
# ============================================================
def _check(p: AdminPayload):
    """Reject the request unless the payload carries the admin password.

    Args:
        p: Admin request body whose ``password`` is verified.

    Raises:
        HTTPException: 401 "Unauthorized" when the password does not match
            ADMIN_PASS.
    """
    # Local import keeps this fix self-contained.
    import hmac

    # Constant-time comparison, so response timing does not reveal how many
    # leading characters of a guess were correct. Both sides are encoded to
    # bytes because compare_digest rejects non-ASCII str; "surrogatepass"
    # keeps unusual input from raising during encoding.
    supplied = p.password.encode("utf-8", "surrogatepass")
    expected = ADMIN_PASS.encode("utf-8", "surrogatepass")
    if not hmac.compare_digest(supplied, expected):
        raise HTTPException(status_code=401, detail="Unauthorized")

@app.post("/admin/safe_rebuild_index")
def admin_safe_rebuild(p: AdminPayload):
    # Password-gated index rebuild with a disk-space guard:
    #   1. refuse when used disk space has reached the threshold (unless forced),
    #   2. empty the FAISS index folder, creating it when absent,
    #   3. run rebuild_index() and report its result.
    _check(p)

    disk = get_disk_usage("/")
    used_gb = disk["used"] / (1024 ** 3)
    over_threshold = used_gb >= DISK_USAGE_THRESHOLD_GB
    if over_threshold and not p.force:
        return {
            "status": "error",
            "reason": f"Disk usage {used_gb:.2f}GB is above safety threshold {DISK_USAGE_THRESHOLD_GB}GB. Use force:true to override."
        }

    try:
        if not os.path.exists(FAISS_INDEX_DIR):
            os.makedirs(FAISS_INDEX_DIR, exist_ok=True)
        else:
            for entry in os.listdir(FAISS_INDEX_DIR):
                target = os.path.join(FAISS_INDEX_DIR, entry)
                try:
                    remover = shutil.rmtree if os.path.isdir(target) else os.remove
                    remover(target)
                except Exception as e:
                    # One undeletable entry must not stop the clean-up.
                    print(f"Warning: could not delete {target}: {e}")
    except Exception as e:
        return {"status": "error", "reason": f"Failed to clear FAISS index folder: {e}"}

    try:
        rebuild_result = rebuild_index()
    except Exception as e:
        traceback.print_exc()
        return {"status": "error", "reason": str(e)}
    return {"status": "ok", "result": rebuild_result}

# ============================================================
# Password Validation
# ============================================================
@app.post("/admin/validate_password")
def api_validate_password(p: AdminPayload):
    """Tell the caller whether the supplied password is the admin password.

    Returns:
        ``{"valid": True}`` on a match, ``{"valid": False}`` otherwise.
        A mismatch does not produce an HTTP error.
    """
    # Local import keeps this fix self-contained.
    import hmac

    # Constant-time comparison, so response timing does not reveal how close
    # a guess was. Bytes are used because compare_digest rejects non-ASCII
    # str; "surrogatepass" keeps unusual input from raising during encoding.
    matches = hmac.compare_digest(
        p.password.encode("utf-8", "surrogatepass"),
        ADMIN_PASS.encode("utf-8", "surrogatepass"),
    )
    return {"valid": matches}

# ============================================================
# Existing Admin Endpoints
# ============================================================
@app.post("/admin/rebuild_index")
def api_rebuild_index(p: AdminPayload):
    # Password-gated trigger for rebuild_index(); whatever it returns is
    # passed back under "result". Exceptions are not caught here.
    _check(p)
    outcome = rebuild_index()
    return {"status": "ok", "result": outcome}

@app.post("/admin/rebuild_glossary")
def api_rebuild_glossary(p: AdminPayload):
    # Password-gated trigger for rebuild_glossary(); whatever it returns is
    # passed back under "result". Exceptions are not caught here.
    _check(p)
    outcome = rebuild_glossary()
    return {"status": "ok", "result": outcome}

@app.post("/admin/reset_faiss")
def api_reset_faiss(p: AdminPayload):
    # Password-gated trigger for reset_faiss_cache(); whatever it returns is
    # passed back under "result". Exceptions are not caught here.
    _check(p)
    outcome = reset_faiss_cache()
    return {"status": "ok", "result": outcome}

@app.post("/admin/clear_index")
def api_clear_index(p: AdminPayload):
    """Run ``clear_index()`` after verifying the admin password.

    Returns:
        ``{"status": "ok", "result": ...}`` on success. When ``clear_index``
        raises, "status" is "error" and the exception text is given under
        both "reason" and "result".
    """
    _check(p)
    try:
        return {"status": "ok", "result": clear_index()}
    except Exception as e:
        traceback.print_exc()
        # A failure used to be reported as "ok"; report it as an error.
        # "result" is kept next to "reason" so clients that read "result"
        # keep working.
        message = str(e)
        return {"status": "error", "result": message, "reason": message}

# ============================================================
# ✔✔ MOUNT API UNDER /api (Fix Android 404)
# ============================================================
# FastAPI is imported a second time here under a private alias.
from fastapi import FastAPI as _FastAPI

# Outer application that mounts the API application at /api, so the routes
# registered above are served below that prefix (e.g. /api/chat).
root_app = _FastAPI(title="Root Server", description="API root router")
root_app.mount("/api", app)

# Rebind the module-level name `app` to the outer application, so anything
# that loads `app` from this module gets the /api-prefixed server.
app = root_app

# ============================================================
# Local Run (now serves root_app correctly)
# ============================================================
if __name__ == "__main__":
    import uvicorn

    # Serve the application object directly instead of the import string
    # "api:app". The string form makes uvicorn import this module a second
    # time, repeating the heavy import-time work, and only resolves when the
    # file is named api.py.
    uvicorn.run(app, host="0.0.0.0", port=7861)