atkiya110 commited on
Commit
34ca45d
Β·
verified Β·
1 Parent(s): 5a3b8ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +247 -640
app.py CHANGED
@@ -1,59 +1,15 @@
1
- """
2
- EWU RAG Server β€” v2.1 (Fast-Restart Edition)
3
- ══════════════════════════════════════════════════════════════════════
4
- Key fixes over v2
5
- ─────────────────
6
- 1. DISK CACHE β€” chunks, FAISS index, BM25, KG, entity_index and doc
7
- embeddings are all saved to ./cache/ on first boot. Subsequent
8
- restarts load from disk in ~5-10s instead of ~5 minutes.
9
-
10
- 2. PRIORITY BOOT ORDER
11
- Phase 1 (blocking) : load embedder + cross-encoder β†’ build/load
12
- all indexes β†’ server returns 200 immediately.
13
- Phase 2 (background): load TinyLlama in a background task.
14
- /rag returns context-only answer until
15
- the generator is ready, then full answer.
16
-
17
- 3. DETAIL-ENDPOINT CIRCUIT BREAKER
18
- The API detail pages (/faculty/<id>, /programs/<id>, /documents/<slug>)
19
- all returned HTTP 500 in the logs. We now track failures per URL
20
- and skip detail fetches after 3 consecutive 500s to avoid wasting
21
- 30+ seconds on guaranteed failures every boot.
22
-
23
- 4. HyDE DISABLED ON CPU
24
- HyDE costs one full TinyLlama forward pass per query. On CPU that
25
- is ~15-30 seconds of extra latency with minimal accuracy gain.
26
- Set ENABLE_HYDE = True if you have a GPU.
27
-
28
- 5. GRACEFUL DEGRADATION
29
- Every component (embedder, reranker, generator, FAISS, BM25, KG)
30
- is independent. The server works β€” at reduced quality β€” if any
31
- single component fails to load.
32
- """
33
-
34
- import asyncio
35
- import json
36
- import logging
37
  import os
38
- import pickle
39
- import re
40
- import string
41
- import time
42
- from contextlib import asynccontextmanager
43
- from typing import Any, Dict, List, Optional
44
-
45
- import httpx
46
  import numpy as np
47
  import uvicorn
 
 
 
48
  from fastapi import FastAPI, HTTPException
49
  from fastapi.responses import JSONResponse
50
  from pydantic import BaseModel
51
 
52
- logging.basicConfig(level=logging.INFO,
53
- format="%(asctime)s %(levelname)s %(message)s")
54
- logger = logging.getLogger(__name__)
55
-
56
- # ── optional heavy deps ───────────────────────────────────────────────────────
57
  try:
58
  import faiss
59
  FAISS_OK = True
@@ -61,7 +17,7 @@ except ImportError:
61
  FAISS_OK = False
62
 
63
  try:
64
- from sentence_transformers import SentenceTransformer, CrossEncoder
65
  ST_OK = True
66
  except ImportError:
67
  ST_OK = False
@@ -77,49 +33,29 @@ try:
77
  HF_OK = True
78
  except ImportError:
79
  HF_OK = False
 
80
 
81
- try:
82
- import networkx as nx
83
- NX_OK = True
84
- except ImportError:
85
- NX_OK = False
86
-
87
- DEVICE = "cpu"
88
- try:
89
- import torch
90
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
91
- except ImportError:
92
- pass
93
-
94
- # ═════════════════════════════════════════════════════════════════════════════
95
  # CONFIG
96
- # ═════════════════════════════════════════════════════════════════════════════
97
 
98
  API_BASE = "https://ewu-server.onrender.com/api"
99
  API_KEY = "i6EDytaX4E2jI6GvZQc0b1RSZHTI5_wVRa2rfL7rLpk"
100
  API_HEADERS = {"x-api-key": API_KEY}
101
- GITHUB_BASE = "https://raw.githubusercontent.com/Atkiya/jsonfiles/main/"
102
-
103
- EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
104
- RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
105
- GEN_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
106
 
107
- CHUNK_SIZE = 512
108
- CHUNK_OVERLAP = 100
109
- TOP_K_RETRIEVE = 20
110
- TOP_K_FINAL = 5
111
- RERANK_THRESH = -5.0
112
- MMR_LAMBDA = 0.65
113
- COMPRESS_THRESH= 0.25
114
 
115
- # ── Performance switches ──────────────────────────────────────────────────────
116
- ENABLE_HYDE = (DEVICE == "cuda") # disabled on CPU β€” too slow
117
- CACHE_DIR = "./cache"
118
- CACHE_TTL_H = 24 # rebuild disk cache every 24 h
119
 
120
- # ── API circuit-breaker ───────────────────────────────────────────────────────
121
- _api_fail_count: Dict[str, int] = {}
122
- API_FAIL_LIMIT = 3
 
 
123
 
124
  API_LIST_ENDPOINTS = [
125
  "admission-deadlines", "academic-calendar", "grade-scale",
@@ -127,11 +63,13 @@ API_LIST_ENDPOINTS = [
127
  "governance", "alumni", "helpdesk", "policies", "proctor-schedule",
128
  "documents", "newsletters", "programs", "faculty", "departments",
129
  ]
 
130
  API_DETAIL_ENDPOINTS = [
131
  {"list": "programs", "id_field": "id"},
132
  {"list": "faculty", "id_field": "id"},
133
  {"list": "documents", "id_field": "slug"},
134
  ]
 
135
  GITHUB_FILES = [
136
  "admission_deadlines.json", "dynamic_admission_process.json",
137
  "dynamic_admission_requirements.json", "dynamic_tution_fees.json",
@@ -141,8 +79,7 @@ GITHUB_FILES = [
141
  "static_aboutEWU.json", "static_Admin.json",
142
  "static_AllAvailablePrograms.json", "static_alumni.json",
143
  "static_campus_life.json", "static_Career_Counseling_Center.json",
144
- "static_clubs.json", "static_depts.json",
145
- # "static_facilities.json", <- 404, skipped
146
  "static_helpdesk.json", "static_payment_procedure.json",
147
  "static_Policy.json", "static_Programs.json", "static_Rules.json",
148
  "static_Sexual_harassment.json", "static_Tuition_fees.json",
@@ -156,307 +93,158 @@ GITHUB_FILES = [
156
  "syndicate.json", "tesol.json", "ewu_board_of_trustees.json",
157
  ]
158
 
159
- # ═════════════════════════════════════════════════════════════════════════════
160
  # APP STATE
161
- # ═════════════════════════════════════════════════════════════════════════════
162
 
163
  class AppState:
164
- embedder = None
165
- reranker = None
166
- generator = None
167
- documents : List[Dict] = []
168
- faiss_index = None
169
- doc_embeddings: Optional[np.ndarray] = None
170
- bm25 = None
171
- kg = None
172
- entity_index : Dict[str, List[int]] = {}
173
- ready : bool = False
174
- gen_ready : bool = False
175
- error : str = ""
176
 
177
  state = AppState()
178
 
179
- # ═════════════════════════════════════════════════════════════════════════════
180
- # DISK CACHE HELPERS
181
- # ═════════════════════════════════════════════════════════════════════════════
182
-
183
- os.makedirs(CACHE_DIR, exist_ok=True)
184
-
185
- def _cp(name: str) -> str:
186
- return os.path.join(CACHE_DIR, name)
187
-
188
- def _cache_fresh(name: str) -> bool:
189
- p = _cp(name)
190
- if not os.path.exists(p):
191
- return False
192
- return (time.time() - os.path.getmtime(p)) / 3600 < CACHE_TTL_H
193
-
194
- def _save(name: str, obj: Any) -> None:
195
- try:
196
- with open(_cp(name), "wb") as f:
197
- pickle.dump(obj, f, protocol=5)
198
- logger.info(f"[cache] saved {name}")
199
- except Exception as e:
200
- logger.warning(f"[cache] save {name} failed: {e}")
201
-
202
- def _load(name: str) -> Optional[Any]:
203
- try:
204
- with open(_cp(name), "rb") as f:
205
- return pickle.load(f)
206
- except Exception as e:
207
- logger.warning(f"[cache] load {name} failed: {e}")
208
- return None
209
-
210
- def _save_faiss(idx) -> None:
211
- try:
212
- faiss.write_index(idx, _cp("faiss.index"))
213
- logger.info("[cache] saved faiss.index")
214
- except Exception as e:
215
- logger.warning(f"[cache] faiss save failed: {e}")
216
-
217
- def _load_faiss():
218
- p = _cp("faiss.index")
219
- if not os.path.exists(p):
220
- return None
221
- try:
222
- idx = faiss.read_index(p)
223
- logger.info(f"[cache] loaded faiss.index ({idx.ntotal} vectors)")
224
- return idx
225
- except Exception as e:
226
- logger.warning(f"[cache] faiss load failed: {e}")
227
- return None
228
-
229
- # ═════════════════════════════════════════════════════════════════════════════
230
  # DATA LOADING
231
- # ═════════════════════════════════════════════════════════════════════════════
232
 
233
- async def _fetch(url: str, headers: dict = None, timeout: int = 60) -> Optional[Any]:
234
- key = url.split("?")[0]
235
- if _api_fail_count.get(key, 0) >= API_FAIL_LIMIT:
236
- return None
237
  try:
238
  async with httpx.AsyncClient(timeout=timeout) as client:
239
  r = await client.get(url, headers=headers or {})
240
  if r.status_code == 200:
241
- _api_fail_count[key] = 0
242
  return r.json()
243
- if r.status_code in (404, 500, 502, 503):
244
- _api_fail_count[key] = _api_fail_count.get(key, 0) + 1
245
- logger.warning(f"[WARN] {url} β†’ HTTP {r.status_code} "
246
- f"(fail #{_api_fail_count[key]})")
247
  except Exception as e:
248
- _api_fail_count[key] = _api_fail_count.get(key, 0) + 1
249
- logger.warning(f"[WARN] {url} β†’ {e}")
250
  return None
251
 
252
 
253
- def _unwrap(data: Any) -> list:
254
- if isinstance(data, list): return data
 
255
  if isinstance(data, dict):
256
- for k in ("data", "results", "items"):
257
- if k in data and isinstance(data[k], list):
258
- return data[k]
259
  return [data]
260
  return []
261
 
262
 
263
- async def _wake_api() -> bool:
264
- logger.info("[API] Waking render.com server…")
 
 
 
 
265
  for attempt in range(3):
266
- if await _fetch(f"{API_BASE}/grade-scale", API_HEADERS, timeout=60):
267
- logger.info("[API] Server awake.")
 
268
  return True
269
- logger.info(f"[API] Wake attempt {attempt+1}/3 failed…")
270
  await asyncio.sleep(10)
271
- logger.warning("[API] Server did not wake β€” skipping API data.")
272
  return False
273
 
274
 
275
- async def load_api() -> List[Dict]:
276
- if not await _wake_api():
 
277
  return []
278
 
279
  list_results = await asyncio.gather(
280
- *[_fetch(f"{API_BASE}/{ep}", API_HEADERS) for ep in API_LIST_ENDPOINTS],
281
  return_exceptions=True,
282
  )
283
  docs, list_cache = [], {}
284
  for ep, data in zip(API_LIST_ENDPOINTS, list_results):
285
- if not data or isinstance(data, Exception): continue
 
286
  items = _unwrap(data)
287
  list_cache[ep] = items
288
  for item in items:
289
  text = json.dumps(item, ensure_ascii=False)
290
  if text.strip():
291
  docs.append({"content": text, "source": f"api:{ep}"})
292
- logger.info(f"[API lists] {len(docs)} docs")
293
 
294
- # Detail pages β€” bounded concurrency, circuit-breaker aware
295
  detail_tasks = []
296
  for cfg in API_DETAIL_ENDPOINTS:
297
  for item in list_cache.get(cfg["list"], []):
298
  item_id = item.get(cfg["id_field"]) if isinstance(item, dict) else None
299
- if item_id is None: continue
300
- url = f"{API_BASE}/{cfg['list']}/{item_id}"
301
- if _api_fail_count.get(url, 0) < API_FAIL_LIMIT:
302
  detail_tasks.append((url, f"api:{cfg['list']}/{item_id}"))
303
 
304
  if detail_tasks:
305
- sem = asyncio.Semaphore(5)
306
- async def _bounded(url, source):
307
- async with sem:
308
- return await _fetch(url, API_HEADERS), source
309
-
310
- results = await asyncio.gather(
311
- *[_bounded(u, s) for u, s in detail_tasks],
312
  return_exceptions=True,
313
  )
314
  n = 0
315
- for res in results:
316
- if isinstance(res, Exception): continue
317
- data, source = res
318
- if not data: continue
319
  for item in _unwrap(data):
320
  text = json.dumps(item, ensure_ascii=False)
321
  if text.strip():
322
  docs.append({"content": text, "source": source})
323
  n += 1
324
- logger.info(f"[API details] {n} docs")
325
 
326
- logger.info(f"[API total] {len(docs)} raw docs")
327
  return docs
328
 
329
 
330
- async def load_github() -> List[Dict]:
331
  responses = await asyncio.gather(
332
- *[_fetch(GITHUB_BASE + f) for f in GITHUB_FILES],
333
  return_exceptions=True,
334
  )
335
  docs = []
336
  for fname, data in zip(GITHUB_FILES, responses):
337
- if not data or isinstance(data, Exception): continue
 
338
  for item in (data if isinstance(data, list) else [data]):
339
  text = json.dumps(item, ensure_ascii=False)
340
  if text.strip():
341
  docs.append({"content": text, "source": f"github:{fname}"})
342
- logger.info(f"[GitHub] {len(docs)} raw docs")
343
  return docs
344
 
345
- # ═════════════════════════════════════════════════════════════════════════════
346
  # CHUNKING
347
- # ═════════════════════════════════════════════════════════════════════════════
348
-
349
- def _flatten_json(obj, path="", sep=" > ") -> List[str]:
350
- lines = []
351
- if isinstance(obj, dict):
352
- for k, v in obj.items():
353
- np_ = f"{path}{sep}{k}" if path else k
354
- if isinstance(v, (dict, list)):
355
- lines.extend(_flatten_json(v, np_, sep))
356
- else:
357
- val = str(v).strip()
358
- if val and val.lower() not in ("null", "none", "", "[]", "{}"):
359
- lines.append(f"{np_}: {val}")
360
- elif isinstance(obj, list):
361
- for i, item in enumerate(obj):
362
- if isinstance(item, (dict, list)):
363
- lines.extend(_flatten_json(item, f"{path}[{i}]", sep))
364
- else:
365
- val = str(item).strip()
366
- if val: lines.append(f"{path}[{i}]: {val}")
367
- return lines
368
-
369
-
370
- def _json_chunks(text: str, source: str) -> List[Dict]:
371
- try:
372
- obj = json.loads(text)
373
- lines = _flatten_json(obj)
374
- if not lines: return []
375
- chunks, buf, length = [], [], 0
376
- for line in lines:
377
- if length + len(line) + 1 > CHUNK_SIZE and buf:
378
- chunks.append(" | ".join(buf))
379
- keep = max(1, len(buf) // 5)
380
- buf = buf[-keep:]
381
- length = sum(len(l) + 1 for l in buf)
382
- buf.append(line); length += len(line) + 1
383
- if buf: chunks.append(" | ".join(buf))
384
- return [{"content": c, "source": source} for c in chunks if c.strip()]
385
- except Exception:
386
- return []
387
 
388
-
389
- def chunk_documents(docs: List[Dict]) -> List[Dict]:
390
- step, out = max(1, CHUNK_SIZE - CHUNK_OVERLAP), []
391
  for d in docs:
392
- text, source = d["content"], d["source"]
393
- if not text.strip(): continue
394
- jc = _json_chunks(text, source)
395
- if jc: out.extend(jc); continue
396
- if len(text) <= CHUNK_SIZE: out.append(d); continue
 
397
  start = 0
398
  while start < len(text):
399
- chunk = text[start:start + CHUNK_SIZE]
400
- if chunk.strip(): out.append({"content": chunk, "source": source})
 
401
  start += step
402
  return out
403
 
404
- # ═════════════════════════════════════════════════════════════════════════════
405
- # KNOWLEDGE GRAPH
406
- # ═════════════════════════════════════════════════════════════════════════════
407
-
408
- _STOP = set(string.punctuation) | {
409
- "the","a","an","is","are","was","were","of","in","at","to","for",
410
- "and","or","not","this","that","it","its","with","as","by","on",
411
- "from","all","be","been","has","have","had","will","would","can",
412
- "could","do","does","did","he","she","they","we","you","i","me",
413
- }
414
-
415
-
416
- def build_knowledge_graph(docs: List[Dict]):
417
- if not NX_OK: return None, {}
418
- G = nx.DiGraph()
419
- entity_index: Dict[str, List[int]] = {}
420
- for ci, doc in enumerate(docs):
421
- for line in doc["content"].split(" | "):
422
- parts = line.split(": ", 1)
423
- if len(parts) != 2: continue
424
- key, val = parts[0].strip().lower(), parts[1].strip().lower()
425
- if not G.has_node(key): G.add_node(key, type="field")
426
- if not G.has_node(val): G.add_node(val, type="value")
427
- G.add_edge(key, val, chunk=ci)
428
- for tok in val.split():
429
- tok = tok.strip(string.punctuation).lower()
430
- if tok and tok not in _STOP and len(tok) > 2:
431
- entity_index.setdefault(tok, []).append(ci)
432
- logger.info(f"[KG] nodes={G.number_of_nodes()}, edges={G.number_of_edges()}, "
433
- f"tokens={len(entity_index)}")
434
- return G, entity_index
435
-
436
-
437
- def kg_search(query: str, k: int = 5) -> List[int]:
438
- if not state.kg or not state.entity_index: return []
439
- tokens = [t.strip(string.punctuation).lower()
440
- for t in query.split() if t.lower() not in _STOP]
441
- scores: Dict[int, int] = {}
442
- for tok in tokens:
443
- for idx in state.entity_index.get(tok, []):
444
- scores[idx] = scores.get(idx, 0) + 1
445
- if state.kg.has_node(tok):
446
- for nbr in state.kg.successors(tok):
447
- ed = state.kg[tok].get(nbr, {})
448
- ci = ed.get("chunk") if isinstance(ed, dict) else None
449
- if ci is not None:
450
- scores[ci] = scores.get(ci, 0) + 1
451
- return sorted(scores, key=scores.get, reverse=True)[:k]
452
-
453
- # ═════════════════════════════════════════════════════════════════════════════
454
- # INDEX BUILDING + DISK CACHE
455
- # ═════════════════════════════════════════════════════════════════════════════
456
-
457
- def build_indexes_from_scratch() -> bool:
458
  if not state.documents:
459
- logger.warning("[WARN] No documents to index.")
460
  return False
461
  texts = [d["content"] for d in state.documents]
462
 
@@ -464,446 +252,265 @@ def build_indexes_from_scratch() -> bool:
464
  try:
465
  emb = state.embedder.encode(
466
  texts, normalize_embeddings=True,
467
- show_progress_bar=True, batch_size=64,
468
  )
469
  emb = np.array(emb, dtype="float32")
470
  if emb.ndim == 2 and emb.shape[0] > 0:
471
- idx = faiss.IndexFlatIP(emb.shape[1])
472
- idx.add(emb)
473
- state.faiss_index = idx
474
- state.doc_embeddings = emb
475
- _save_faiss(idx)
476
- _save("doc_embeddings.pkl", emb)
477
- logger.info(f"[FAISS] {idx.ntotal} vectors (dim={emb.shape[1]})")
478
  except Exception as e:
479
- logger.error(f"[ERROR] FAISS: {e}")
 
480
 
481
  if BM25_OK:
482
  try:
483
  tok = [t.lower().split() for t in texts if t.strip()]
484
  if tok:
485
- b = BM25Okapi(tok)
486
- state.bm25 = b
487
- _save("bm25.pkl", b)
488
- logger.info(f"[BM25] {len(tok)} docs")
489
  except Exception as e:
490
- logger.error(f"[ERROR] BM25: {e}")
491
-
492
- kg, ei = build_knowledge_graph(state.documents)
493
- state.kg = kg
494
- state.entity_index = ei
495
- if kg:
496
- _save("kg.pkl", kg)
497
- _save("entity_index.pkl", ei)
498
-
499
- _save("documents.pkl", state.documents)
500
  return True
501
 
502
-
503
- def load_indexes_from_cache() -> bool:
504
- docs = _load("documents.pkl")
505
- if not docs: return False
506
- state.documents = docs
507
-
508
- if FAISS_OK:
509
- idx = _load_faiss()
510
- if idx: state.faiss_index = idx
511
- emb = _load("doc_embeddings.pkl")
512
- if emb is not None: state.doc_embeddings = emb
513
-
514
- bm25 = _load("bm25.pkl")
515
- if bm25: state.bm25 = bm25
516
-
517
- kg = _load("kg.pkl")
518
- ei = _load("entity_index.pkl")
519
- if kg:
520
- state.kg = kg
521
- state.entity_index = ei or {}
522
-
523
- return bool(state.documents) and (
524
- state.faiss_index is not None or state.bm25 is not None
525
- )
526
-
527
- # ═════════════════════════════════════════════════════════════════════════════
528
  # RETRIEVAL
529
- # ═════════════════════════════════════════════════════════════════════════════
530
 
531
- def _encode_query(query: str, hyde_text: str = "") -> np.ndarray:
532
- q_emb = state.embedder.encode([query], normalize_embeddings=True)
533
- if hyde_text and ENABLE_HYDE:
534
- h_emb = state.embedder.encode([hyde_text], normalize_embeddings=True)
535
- blended = 0.6 * q_emb + 0.4 * h_emb
536
- blended = blended / (np.linalg.norm(blended, axis=1, keepdims=True) + 1e-9)
537
- return np.array(blended, dtype="float32")
538
- return np.array(q_emb, dtype="float32")
539
-
540
-
541
- def _dense(q_vec: np.ndarray, k: int = TOP_K_RETRIEVE) -> List[Dict]:
542
- if not state.faiss_index: return []
543
  try:
 
 
 
544
  k_a = min(k, state.faiss_index.ntotal)
545
- scores, ids = state.faiss_index.search(q_vec, k_a)
 
 
546
  return [{**state.documents[i], "score": float(s)}
547
  for s, i in zip(scores[0], ids[0]) if i >= 0]
548
  except Exception as e:
549
- logger.error(f"[ERROR] dense: {e}"); return []
 
550
 
551
 
552
- def _sparse(query: str, k: int = TOP_K_RETRIEVE) -> List[Dict]:
553
- if not state.bm25: return []
 
554
  try:
555
  tokens = query.lower().split()
556
- if not tokens: return []
 
557
  scores = np.array(state.bm25.get_scores(tokens), dtype="float32")
558
- idx = np.argsort(scores)[::-1][:min(k, len(scores))]
559
  return [{**state.documents[i], "score": float(scores[i])}
560
  for i in idx if scores[i] > 0]
561
  except Exception as e:
562
- logger.error(f"[ERROR] sparse: {e}"); return []
 
563
 
564
 
565
- def rrf_fuse(lists: List[List[Dict]], weights: List[float], rrf_k=60) -> List[Dict]:
566
- merged, doc_map = {}, {}
567
- for lst, w in zip(lists, weights):
568
- for rank, d in enumerate(lst):
569
- key = d["content"]
570
- merged[key] = merged.get(key, 0.0) + w / (rrf_k + rank + 1)
571
- doc_map[key] = d
 
 
 
 
 
 
 
572
  return [{**doc_map[c], "rrf_score": round(s, 6)}
573
- for c, s in sorted(merged.items(), key=lambda x: x[1], reverse=True)]
574
-
575
-
576
- def expand_queries(query: str) -> List[str]:
577
- variants = [query]
578
- ACRONYMS = {
579
- "ewu": "East West University",
580
- "cse": "Computer Science Engineering",
581
- "eee": "Electrical Electronic Engineering",
582
- "ece": "Electronic Communication Engineering",
583
- "mba": "Master of Business Administration",
584
- "gpa": "grade point average",
585
- "cgpa": "cumulative grade point average",
586
- "vc": "Vice Chancellor",
587
- "dept": "department",
588
- }
589
- q_low = query.lower()
590
- expanded = q_low
591
- for abbr, full in ACRONYMS.items():
592
- expanded = re.sub(r"\b" + abbr + r"\b", full, expanded)
593
- if expanded != q_low:
594
- variants.append(expanded)
595
- _QW = {"what","who","when","where","how","why","is","are","does",
596
- "do","the","a","an","tell","me","about"}
597
- kw = [w for w in re.findall(r"\w+", q_low) if w not in _QW and len(w) > 2]
598
- if kw and " ".join(kw) != q_low:
599
- variants.append(" ".join(kw))
600
- return list(dict.fromkeys(variants))[:3]
601
-
602
-
603
- def rerank(query: str, candidates: List[Dict], top_n: int) -> List[Dict]:
604
- if not state.reranker or not candidates:
605
- return candidates[:top_n]
606
- try:
607
- pairs = [(query, d["content"]) for d in candidates]
608
- scores = state.reranker.predict(pairs, batch_size=32, show_progress_bar=False)
609
- scored = sorted(zip(scores, candidates), key=lambda x: x[0], reverse=True)
610
- out = [{**doc, "rerank_score": float(sc)}
611
- for sc, doc in scored[:top_n] if sc >= RERANK_THRESH]
612
- return out or [{**doc, "rerank_score": float(sc)} for sc, doc in scored[:top_n]]
613
- except Exception as e:
614
- logger.error(f"[ERROR] rerank: {e}")
615
- return candidates[:top_n]
616
-
617
-
618
- def mmr_select(q_vec: np.ndarray, candidates: List[Dict], k: int) -> List[Dict]:
619
- if state.doc_embeddings is None or not candidates:
620
- return candidates[:k]
621
- c2i = {d["content"]: i for i, d in enumerate(state.documents)}
622
- idxs = [c2i[d["content"]] for d in candidates if d["content"] in c2i]
623
- if not idxs: return candidates[:k]
624
- ce = state.doc_embeddings[idxs]
625
- q = q_vec[0]
626
- rel = ce @ q
627
- selected, sel_embs, remaining = [], [], list(range(len(idxs)))
628
- for _ in range(min(k, len(remaining))):
629
- if not remaining: break
630
- if not sel_embs:
631
- best = max(remaining, key=lambda i: rel[i])
632
- else:
633
- S = np.array(sel_embs)
634
- best, bs = remaining[0], -1e9
635
- for i in remaining:
636
- score = MMR_LAMBDA * rel[i] - (1 - MMR_LAMBDA) * float(np.max(S @ ce[i]))
637
- if score > bs: bs, best = score, i
638
- selected.append(best); sel_embs.append(ce[best]); remaining.remove(best)
639
- return [candidates[i] for i in selected]
640
-
641
-
642
- def compress_chunk(q_vec: np.ndarray, text: str) -> str:
643
- if state.embedder is None: return text
644
- lines = [l.strip() for l in re.split(r"[|\n]|(?<=[.!?])\s+", text) if l.strip()]
645
- if len(lines) <= 2: return text
646
- try:
647
- embs = state.embedder.encode(lines, normalize_embeddings=True)
648
- sims = embs @ q_vec[0]
649
- kept = [l for l, s in zip(lines, sims) if s >= COMPRESS_THRESH]
650
- return " | ".join(kept) if kept else text
651
- except Exception:
652
- return text
653
 
654
 
655
- async def full_retrieval(query: str, k: int = TOP_K_FINAL) -> List[Dict]:
656
- variants = await asyncio.to_thread(expand_queries, query)
657
 
658
- all_dense, all_sparse = [], []
659
- for v in variants:
660
- if state.embedder:
661
- vec = await asyncio.to_thread(_encode_query, v)
662
- all_dense.append(await asyncio.to_thread(_dense, vec, TOP_K_RETRIEVE))
663
- all_sparse.append(await asyncio.to_thread(_sparse, v, TOP_K_RETRIEVE))
664
-
665
- weights = [1.0 / (i + 1) for i in range(len(variants))]
666
- fused = rrf_fuse(all_dense + all_sparse, weights + weights)
667
-
668
- kg_idxs = await asyncio.to_thread(kg_search, query, k * 2)
669
- existing = {d["content"] for d in fused}
670
- for i in kg_idxs:
671
- if 0 <= i < len(state.documents):
672
- d = state.documents[i]
673
- if d["content"] not in existing:
674
- fused.append({**d, "rrf_score": 0.0, "kg_injected": True})
675
-
676
- reranked = await asyncio.to_thread(rerank, query, fused, top_n=k * 3)
677
-
678
- if state.embedder:
679
- q_vec = await asyncio.to_thread(_encode_query, query)
680
- final_set = await asyncio.to_thread(mmr_select, q_vec, reranked, k)
681
- compressed = []
682
- for doc in final_set:
683
- ct = await asyncio.to_thread(compress_chunk, q_vec, doc["content"])
684
- compressed.append({**doc, "content": ct})
685
- return compressed
686
-
687
- return reranked[:k]
688
-
689
- # ═════════════════════════════════════════════════════════════════════════════
690
- # GENERATION
691
- # ═════════════════════════════════════════════════════��═══════════════════════
692
-
693
- SYSTEM_PROMPT = """You are EWU Assistant for East West University.
694
- RULES:
695
- 1. Answer ONLY from the provided context.
696
- 2. If the context lacks the answer, say "I don't have that information."
697
- 3. Be specific β€” include numbers, names, dates when present.
698
- 4. Do NOT repeat context verbatim. Summarise clearly.
699
- 5. Never hallucinate facts not in the context."""
700
 
 
 
 
 
 
 
701
 
702
  def _run_tinyllama(query: str, context: str) -> str:
 
 
 
 
 
703
  if state.generator is None:
704
- return f"Based on available information:\n\n{context[:800]}"
705
- trimmed = context[:2000] + ("…" if len(context) > 2000 else "")
 
 
 
 
706
  messages = [
707
- {"role": "system", "content": SYSTEM_PROMPT},
708
- {"role": "user", "content": f"Context:\n{trimmed}\n\nQuestion: {query}\n\nAnswer:"},
709
  ]
 
710
  try:
711
- out = state.generator(messages, max_new_tokens=280,
712
- do_sample=True, temperature=0.25,
713
- top_p=0.90, repetition_penalty=1.15)
714
- generated = out[0]["generated_text"]
 
 
 
 
 
 
 
 
715
  if isinstance(generated, list):
716
  for turn in reversed(generated):
717
  if isinstance(turn, dict) and turn.get("role") == "assistant":
718
  return turn.get("content", "").strip()
 
 
719
  return str(generated).strip()
 
720
  except Exception as e:
721
- logger.error(f"[ERROR] TinyLlama: {e}")
722
  return f"[Generation error: {e}]"
723
 
724
 
725
  async def generate(query: str, context: str) -> str:
 
726
  return await asyncio.to_thread(_run_tinyllama, query, context)
727
 
728
- # ═════════════════════════════════════════════════════════════════════════════
729
- # BOOT β€” two-phase, cache-aware
730
- # ═════════════════════════════════════════════════════════════════════════════
731
-
732
- def _load_models():
733
- emb, ce = None, None
734
- if ST_OK:
735
- try:
736
- logger.info(f" Loading embedder ({EMBED_MODEL}) on {DEVICE}…")
737
- emb = SentenceTransformer(EMBED_MODEL, device=DEVICE)
738
- logger.info(" Embedder ready.")
739
- except Exception as e:
740
- logger.error(f"[ERROR] Embedder: {e}")
741
- try:
742
- logger.info(f" Loading cross-encoder ({RERANK_MODEL})…")
743
- ce = CrossEncoder(RERANK_MODEL, device=DEVICE, max_length=512)
744
- logger.info(" Cross-encoder ready.")
745
- except Exception as e:
746
- logger.warning(f"[WARN] Cross-encoder: {e}")
747
- return emb, ce
748
-
749
 
750
  def _load_generator():
751
- if not HF_OK: return None
 
 
 
752
  try:
753
- logger.info(f" Loading TinyLlama ({GEN_MODEL}) on {DEVICE}…")
754
- gen = hf_pipeline("text-generation", model=GEN_MODEL,
755
- device=0 if DEVICE == "cuda" else -1, dtype="auto")
756
- logger.info(" TinyLlama ready.")
 
 
 
 
757
  return gen
758
  except Exception as e:
759
- logger.error(f"[ERROR] TinyLlama: {e}")
760
  return None
761
 
762
 
763
- async def _boot_phase1():
764
- """Load models + indexes. Sets state.ready = True when done."""
765
- logger.info(f"=== PHASE 1: Models + Indexes (device={DEVICE}) ===")
766
- emb, ce = await asyncio.to_thread(_load_models)
767
- state.embedder = emb
768
- state.reranker = ce
769
-
770
- # Try disk cache
771
- cache_ok = (
772
- _cache_fresh("documents.pkl")
773
- and _cache_fresh("faiss.index")
774
- and _cache_fresh("bm25.pkl")
775
- )
776
- if cache_ok:
777
- logger.info("[cache] Loading from disk…")
778
- if await asyncio.to_thread(load_indexes_from_cache):
779
- logger.info(f"[cache] {len(state.documents)} chunks loaded from disk.")
780
- state.ready = True
781
- logger.info("βœ“ Phase 1 complete (cache hit).")
782
- return
783
-
784
- logger.info("Fetching knowledge base (API + GitHub)…")
785
- api_docs, gh_docs = await asyncio.gather(load_api(), load_github())
786
- raw = api_docs + gh_docs
787
- logger.info(f" Raw docs combined: {len(raw)}")
788
-
789
- if not raw:
790
- logger.warning("[WARN] No documents fetched.")
791
- state.ready = True
792
- return
793
-
794
- logger.info("Chunking…")
795
- state.documents = await asyncio.to_thread(chunk_documents, raw)
796
- logger.info(f" Total chunks: {len(state.documents)}")
797
 
798
- logger.info("Building indexes…")
799
- await asyncio.to_thread(build_indexes_from_scratch)
 
 
 
 
 
800
 
801
- state.ready = True
802
- logger.info("βœ“ Phase 1 complete β€” server accepting queries.")
803
 
 
 
 
804
 
805
- async def _boot_phase2():
806
- """Load TinyLlama in the background β€” does not block /rag."""
807
- logger.info("=== PHASE 2: TinyLlama (background) ===")
808
- gen = await asyncio.to_thread(_load_generator)
809
- state.generator = gen
810
- state.gen_ready = gen is not None
811
- if gen:
812
- logger.info("βœ“ Phase 2 complete β€” full LLM answers active.")
813
- else:
814
- logger.warning("Phase 2: TinyLlama unavailable β€” context-only mode.")
815
 
 
 
816
 
817
- async def _boot():
818
- try:
819
- await _boot_phase1()
820
- asyncio.create_task(_boot_phase2()) # fire-and-forget
821
  except Exception as e:
822
  state.error = str(e)
823
  state.ready = False
824
- logger.error(f"[ERROR] Boot failed: {e}")
825
  import traceback; traceback.print_exc()
826
 
827
 
828
  @asynccontextmanager
829
  async def lifespan(app: FastAPI):
830
- task = asyncio.create_task(_boot())
831
  try:
832
  yield
833
  finally:
834
- task.cancel()
835
  try:
836
- await task
837
  except asyncio.CancelledError:
838
  pass
839
 
840
- # ═════════════════════════════════════════════════════════════════════════════
841
- # ENDPOINTS
842
- # ═════════════════════════════════════════════════════════════════════════════
843
 
844
- app = FastAPI(title="EWU RAG Server v2.1", lifespan=lifespan)
845
 
846
 
847
  class Query(BaseModel):
848
  query : str
849
- top_k : int = TOP_K_FINAL
850
 
851
 
852
  @app.post("/rag")
853
  async def rag_endpoint(q: Query):
854
  if not state.ready:
855
- raise HTTPException(503, detail=state.error or "Initialising β€” retry in 30s.")
856
  if not q.query.strip():
857
  raise HTTPException(400, detail="Query must not be empty.")
858
- results = await full_retrieval(q.query, k=q.top_k)
859
  if not results:
860
  return {"answer": "No relevant information found.", "sources": []}
861
  context = "\n\n---\n\n".join(r["content"] for r in results)
862
  answer = await generate(q.query, context)
863
  return {
864
- "answer" : answer,
865
- "gen_ready": state.gen_ready,
866
- "sources" : [
867
- {
868
- "source" : r.get("source"),
869
- "rerank_score": round(r.get("rerank_score", 0), 4),
870
- "rrf_score" : round(r.get("rrf_score", 0), 6),
871
- "kg_injected" : r.get("kg_injected", False),
872
- }
873
- for r in results
874
- ],
875
  }
876
 
877
 
878
  @app.get("/health")
879
  async def health():
880
  return JSONResponse(200, {
881
- "status" : "ready" if state.ready else ("error" if state.error else "loading"),
882
- "gen_ready" : state.gen_ready,
883
- "docs" : len(state.documents),
884
- "device" : DEVICE,
885
- "faiss" : state.faiss_index is not None,
886
- "bm25" : state.bm25 is not None,
887
- "reranker" : state.reranker is not None,
888
- "generator" : state.generator is not None,
889
- "hyde_enabled": ENABLE_HYDE,
890
- "kg_nodes" : state.kg.number_of_nodes() if state.kg else 0,
891
- "kg_edges" : state.kg.number_of_edges() if state.kg else 0,
892
- "error" : state.error or None,
893
  })
894
 
895
 
896
- @app.post("/cache/clear")
897
- async def clear_cache():
898
- """Delete disk cache β€” server will rebuild from scratch on next restart."""
899
- import shutil
900
- try:
901
- shutil.rmtree(CACHE_DIR, ignore_errors=True)
902
- os.makedirs(CACHE_DIR, exist_ok=True)
903
- return {"status": "cache cleared β€” restart the server to rebuild"}
904
- except Exception as e:
905
- raise HTTPException(500, detail=str(e))
906
-
907
-
908
  if __name__ == "__main__":
909
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import json
3
+ import asyncio
 
 
 
 
 
 
4
  import numpy as np
5
  import uvicorn
6
+ import httpx
7
+
8
+ from contextlib import asynccontextmanager
9
  from fastapi import FastAPI, HTTPException
10
  from fastapi.responses import JSONResponse
11
  from pydantic import BaseModel
12
 
 
 
 
 
 
13
  try:
14
  import faiss
15
  FAISS_OK = True
 
17
  FAISS_OK = False
18
 
19
  try:
20
+ from sentence_transformers import SentenceTransformer
21
  ST_OK = True
22
  except ImportError:
23
  ST_OK = False
 
33
  HF_OK = True
34
  except ImportError:
35
  HF_OK = False
36
+ print("[WARN] transformers not installed β€” generation disabled.")
37
 
38
+ # ─────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # CONFIG
40
+ # ─────────────────────────────────────────────
41
 
42
  API_BASE = "https://ewu-server.onrender.com/api"
43
  API_KEY = "i6EDytaX4E2jI6GvZQc0b1RSZHTI5_wVRa2rfL7rLpk"
44
  API_HEADERS = {"x-api-key": API_KEY}
 
 
 
 
 
45
 
46
+ GITHUB_BASE = "https://raw.githubusercontent.com/Atkiya/jsonfiles/main/"
47
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
48
+ GEN_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
 
 
 
49
 
50
+ CHUNK_SIZE = 400
51
+ CHUNK_OVERLAP = 80
52
+ DEVICE = "cpu"
 
53
 
54
+ try:
55
+ import torch
56
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
57
+ except ImportError:
58
+ pass
59
 
60
  API_LIST_ENDPOINTS = [
61
  "admission-deadlines", "academic-calendar", "grade-scale",
 
63
  "governance", "alumni", "helpdesk", "policies", "proctor-schedule",
64
  "documents", "newsletters", "programs", "faculty", "departments",
65
  ]
66
+
67
  API_DETAIL_ENDPOINTS = [
68
  {"list": "programs", "id_field": "id"},
69
  {"list": "faculty", "id_field": "id"},
70
  {"list": "documents", "id_field": "slug"},
71
  ]
72
+
73
  GITHUB_FILES = [
74
  "admission_deadlines.json", "dynamic_admission_process.json",
75
  "dynamic_admission_requirements.json", "dynamic_tution_fees.json",
 
79
  "static_aboutEWU.json", "static_Admin.json",
80
  "static_AllAvailablePrograms.json", "static_alumni.json",
81
  "static_campus_life.json", "static_Career_Counseling_Center.json",
82
+ "static_clubs.json", "static_depts.json", "static_facilities.json",
 
83
  "static_helpdesk.json", "static_payment_procedure.json",
84
  "static_Policy.json", "static_Programs.json", "static_Rules.json",
85
  "static_Sexual_harassment.json", "static_Tuition_fees.json",
 
93
  "syndicate.json", "tesol.json", "ewu_board_of_trustees.json",
94
  ]
95
 
96
+ # ─────────────────────────────────────────────
97
  # APP STATE
98
+ # ─────────────────────────────────────────────
99
 
100
class AppState:
    """Mutable singleton holding every lazily-initialised RAG component.

    All fields start empty/None and are populated by the background
    ``_boot`` task; ``/health`` reports which ones actually came up.
    NOTE(review): the class-level mutable default for ``documents`` is
    shared across instances — acceptable only because exactly one
    ``AppState`` is ever created.
    """
    embedder = None          # SentenceTransformer, set during boot (or left None)
    generator = None         # TinyLlama text-generation pipeline
    documents : list = []    # chunked {"content", "source"} records
    faiss_index = None       # dense vector index over the chunk embeddings
    bm25 = None              # sparse keyword index over the same chunks
    ready : bool = False     # flips True once _boot finishes successfully
    error : str = ""         # last boot failure message, surfaced via /health


# Process-wide singleton used by every endpoint and boot helper.
state = AppState()
110
 
111
+ # ─────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  # DATA LOADING
113
+ # ─────────────────────────────────────────────
114
 
115
async def fetch_json(url: str, headers: dict = None, timeout: int = 60):
    """GET *url* and return its decoded JSON payload.

    Any failure — HTTP error status, network problem, or JSON decoding
    error — is logged as a warning and collapses to ``None``; callers
    treat ``None`` as "no data for this source".
    """
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            resp = await client.get(url, headers=headers or {})
            if resp.status_code == 200:
                return resp.json()
            print(f"[WARN] {url} β†’ HTTP {resp.status_code}")
    except Exception as e:
        print(f"[WARN] {url} β†’ {e}")
    return None
125
 
126
 
127
+ def _unwrap(data) -> list:
128
+ if isinstance(data, list):
129
+ return data
130
  if isinstance(data, dict):
131
+ for key in ("data", "results", "items"):
132
+ if key in data and isinstance(data[key], list):
133
+ return data[key]
134
  return [data]
135
  return []
136
 
137
 
138
async def _wake_api_server():
    """Poke a cheap endpoint until the render.com server responds.

    The free tier sleeps after inactivity and can take ~50s to wake.
    Returns True as soon as any attempt gets data back, False after
    three failures (API-backed loading is then skipped entirely).
    """
    print(" [API] Waking render.com server (free tier may be sleeping)…")
    attempt = 0
    while attempt < 3:
        probe = await fetch_json(f"{API_BASE}/grade-scale", API_HEADERS, timeout=60)
        if probe is not None:
            print(" [API] Server awake.")
            return True
        print(f" [API] Wake attempt {attempt+1}/3 failed, retrying…")
        await asyncio.sleep(10)
        attempt += 1
    print(" [API] Server did not wake β€” skipping API data.")
    return False
153
 
154
 
155
async def load_api() -> list:
    """Pull the EWU API into raw ``{"content", "source"}`` docs.

    Fetches every list endpoint concurrently, then every per-item
    detail page (programs/faculty/documents) discovered in those lists.
    Each record is serialised to a JSON string so downstream chunking
    and indexing can treat all sources uniformly.  Returns [] when the
    API server cannot be woken.
    """
    if not await _wake_api_server():
        return []

    docs: list = []
    cache: dict = {}

    # 1. List endpoints, all in parallel; failures are skipped silently
    #    (fetch_json already logged them).
    list_payloads = await asyncio.gather(
        *(fetch_json(f"{API_BASE}/{ep}", API_HEADERS) for ep in API_LIST_ENDPOINTS),
        return_exceptions=True,
    )
    for ep, payload in zip(API_LIST_ENDPOINTS, list_payloads):
        if isinstance(payload, Exception) or not payload:
            continue
        records = _unwrap(payload)
        cache[ep] = records  # kept so detail URLs can be derived below
        for record in records:
            body = json.dumps(record, ensure_ascii=False)
            if body.strip():
                docs.append({"content": body, "source": f"api:{ep}"})
    print(f" [API lists] {len(docs)} docs")

    # 2. Derive the per-item detail URLs from the cached list payloads.
    pending = []
    for cfg in API_DETAIL_ENDPOINTS:
        for rec in cache.get(cfg["list"], []):
            rec_id = rec.get(cfg["id_field"]) if isinstance(rec, dict) else None
            if rec_id is None:
                continue
            pending.append((f"{API_BASE}/{cfg['list']}/{rec_id}",
                            f"api:{cfg['list']}/{rec_id}"))

    # 3. Fetch all detail pages in parallel and fold them in.
    if pending:
        detail_payloads = await asyncio.gather(
            *(fetch_json(url, API_HEADERS) for url, _ in pending),
            return_exceptions=True,
        )
        added = 0
        for (_, source), payload in zip(pending, detail_payloads):
            if isinstance(payload, Exception) or not payload:
                continue
            for record in _unwrap(payload):
                body = json.dumps(record, ensure_ascii=False)
                if body.strip():
                    docs.append({"content": body, "source": source})
                    added += 1
        print(f" [API details] {added} docs from {len(pending)} pages")

    print(f" [API total] {len(docs)} raw docs")
    return docs
202
 
203
 
204
async def load_github() -> list:
    """Download every static JSON file from the GitHub raw mirror.

    Each file is fetched concurrently; every record (list element or
    single object) is serialised into a ``{"content", "source"}`` doc.
    Failed or empty downloads are skipped.
    """
    payloads = await asyncio.gather(
        *(fetch_json(GITHUB_BASE + name) for name in GITHUB_FILES),
        return_exceptions=True,
    )
    docs = []
    for name, payload in zip(GITHUB_FILES, payloads):
        if isinstance(payload, Exception) or not payload:
            continue
        records = payload if isinstance(payload, list) else [payload]
        for record in records:
            body = json.dumps(record, ensure_ascii=False)
            if body.strip():
                docs.append({"content": body, "source": f"github:{name}"})
    print(f" [GitHub] {len(docs)} raw docs")
    return docs
219
 
220
+ # ─────────────────────────────────────────────
221
  # CHUNKING
222
+ # ─────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
def chunk_documents(docs, size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split each doc's content into overlapping character windows.

    Docs at or under *size* characters pass through untouched; longer
    ones become windows of *size* characters advancing by
    ``size - overlap`` (at least 1).  Whitespace-only docs and
    whitespace-only windows are dropped.
    """
    stride = max(1, size - overlap)
    chunks = []
    for doc in docs:
        body = doc["content"]
        if not body.strip():
            continue
        if len(body) <= size:
            chunks.append(doc)
            continue
        for offset in range(0, len(body), stride):
            window = body[offset:offset + size]
            if window.strip():
                chunks.append({"content": window, "source": doc["source"]})
    return chunks
240
 
241
+ # ─────────────────────────────────────────────
242
+ # INDEX BUILDING
243
+ # ─────────────────────────────────────────────
244
+
245
+ def build_indexes():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
246
  if not state.documents:
247
+ print("[WARN] No documents to index.")
248
  return False
249
  texts = [d["content"] for d in state.documents]
250
 
 
252
  try:
253
  emb = state.embedder.encode(
254
  texts, normalize_embeddings=True,
255
+ show_progress_bar=False, batch_size=64,
256
  )
257
  emb = np.array(emb, dtype="float32")
258
  if emb.ndim == 2 and emb.shape[0] > 0:
259
+ state.faiss_index = faiss.IndexFlatIP(emb.shape[1])
260
+ state.faiss_index.add(emb)
261
+ print(f" [FAISS] {state.faiss_index.ntotal} vectors (dim={emb.shape[1]})")
 
 
 
 
262
  except Exception as e:
263
+ print(f"[ERROR] FAISS: {e}")
264
+ state.faiss_index = None
265
 
266
  if BM25_OK:
267
  try:
268
  tok = [t.lower().split() for t in texts if t.strip()]
269
  if tok:
270
+ state.bm25 = BM25Okapi(tok)
271
+ print(f" [BM25] {len(tok)} docs")
 
 
272
  except Exception as e:
273
+ print(f"[ERROR] BM25: {e}")
274
+ state.bm25 = None
 
 
 
 
 
 
 
 
275
  return True
276
 
277
+ # ─────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  # RETRIEVAL
279
+ # ─────────────────────────────────────────────
280
 
281
def search_dense(query, k=8):
    """Dense retrieval via FAISS inner-product search.

    Embeds *query* with the shared SentenceTransformer and returns up
    to *k* documents, each tagged with a float ``"score"``.  Returns []
    when the index or embedder is missing, the index is empty, or any
    step fails (failure is logged, never raised).
    """
    if not state.faiss_index or not state.embedder:
        return []
    try:
        query_vec = state.embedder.encode([query], normalize_embeddings=True)
        query_vec = np.array(query_vec, dtype="float32")
        limit = min(k, state.faiss_index.ntotal)
        if not limit:
            return []
        sims, doc_ids = state.faiss_index.search(query_vec, limit)
        hits = []
        for sim, doc_id in zip(sims[0], doc_ids[0]):
            if doc_id >= 0:  # FAISS pads missing results with -1
                hits.append({**state.documents[doc_id], "score": float(sim)})
        return hits
    except Exception as e:
        print(f"[ERROR] dense: {e}")
        return []
297
 
298
 
299
def search_sparse(query, k=8):
    """Sparse keyword retrieval via BM25.

    Tokenises *query* by lowercased whitespace split, scores every
    document, and returns up to *k* positive-scoring docs tagged with a
    float ``"score"``.  Returns [] when BM25 is unavailable, the query
    has no tokens, or scoring fails (failure is logged, never raised).
    """
    if not state.bm25 or not state.documents:
        return []
    try:
        terms = query.lower().split()
        if not terms:
            return []
        weights = np.array(state.bm25.get_scores(terms), dtype="float32")
        top = np.argsort(weights)[::-1][:min(k, len(weights))]
        return [
            {**state.documents[i], "score": float(weights[i])}
            for i in top
            if weights[i] > 0
        ]
    except Exception as e:
        print(f"[ERROR] sparse: {e}")
        return []
313
 
314
 
315
def hybrid_search(query, k=5, alpha=0.65):
    """Weighted reciprocal-rank fusion of dense and sparse retrieval.

    Both retrievers are over-fetched (3k each) and fused with RRF:
    each hit contributes ``weight / (60 + rank + 1)`` keyed by its
    content, where the dense list carries weight *alpha* and the
    sparse list ``1 - alpha``.  Returns the top-*k* fused docs, each
    tagged with a rounded ``"rrf_score"``.
    """
    dense_hits = search_dense(query, k * 3)
    sparse_hits = search_sparse(query, k * 3)
    if not dense_hits and not sparse_hits:
        return []

    RRF_K = 60
    fused: dict = {}
    by_content: dict = {}

    def _accumulate(hits, weight):
        # Fold one ranked list into the fusion tables.
        for rank, hit in enumerate(hits):
            content = hit["content"]
            fused[content] = fused.get(content, 0.0) + weight / (RRF_K + rank + 1)
            by_content[content] = hit

    _accumulate(dense_hits, alpha)
    _accumulate(sparse_hits, 1 - alpha)

    ranked = sorted(fused.items(), key=lambda pair: pair[1], reverse=True)[:k]
    return [{**by_content[content], "rrf_score": round(score, 6)}
            for content, score in ranked]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
331
 
332
 
333
async def async_hybrid_search(query, k=5):
    """Run the blocking hybrid search in a worker thread so the event
    loop stays responsive while FAISS/BM25 crunch."""
    return await asyncio.to_thread(hybrid_search, query, k)
335
 
336
+ # ─────────────────────────────────────────────
337
+ # GENERATION β€” TinyLlama (local, no API key)
338
+ # ─────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
339
 
340
# Grounding instructions prepended (as the system turn) to every
# TinyLlama chat request built in _run_tinyllama.
SYSTEM_PROMPT = (
    "You are a helpful assistant for East West University (EWU). "
    "Answer using ONLY the context provided. "
    "If the context does not contain enough information, say so honestly. "
    "Be concise and accurate. Do not repeat the context."
)
346
 
347
def _run_tinyllama(query: str, context: str) -> str:
    """
    Blocking TinyLlama inference using the chat-template message format
    the 1.1B-Chat model was trained on.  Must always be invoked through
    asyncio.to_thread — never directly from async code — so the event
    loop is not stalled by a CPU-bound forward pass.
    """
    if state.generator is None:
        # No model loaded: surface the retrieved context as-is.
        return f"[Generator not loaded]\n\nContext:\n{context}"

    # Keep the prompt inside TinyLlama's 2048-token window:
    # cap the context at ~1500 characters.
    if len(context) > 1500:
        context_snippet = context[:1500] + "…"
    else:
        context_snippet = context

    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Context:\n{context_snippet}\n\nQuestion: {query}"},
    ]

    try:
        result = state.generator(
            chat,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.3,        # low temperature keeps answers grounded
            top_p=0.9,
            repetition_penalty=1.1,
        )
        generated = result[0]["generated_text"]

        # The pipeline may echo the whole conversation back as a list of
        # role/content turns; if so, return the assistant's final reply.
        if isinstance(generated, list):
            for turn in reversed(generated):
                if isinstance(turn, dict) and turn.get("role") == "assistant":
                    return turn.get("content", "").strip()

        # Otherwise fall back to the raw generated string.
        return str(generated).strip()

    except Exception as e:
        print(f"[ERROR] TinyLlama inference: {e}")
        return f"[Generation error: {e}]"
389
 
390
 
391
async def generate(query: str, context: str) -> str:
    """Non-blocking facade over _run_tinyllama: the synchronous model
    call runs in a worker thread so the event loop stays free."""
    return await asyncio.to_thread(_run_tinyllama, query, context)
394
 
395
+ # ─────────────────────────────────────────────
396
+ # BOOT
397
+ # ─────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
398
 
399
def _load_generator():
    """Build the TinyLlama text-generation pipeline.

    Blocking — intended to run inside a thread during boot.  Returns
    the pipeline, or None when transformers is unavailable or the
    model fails to load (the server then degrades to context-only
    answers instead of crashing).
    """
    if not HF_OK:
        print("[WARN] transformers unavailable β€” generation disabled.")
        return None
    try:
        print(f" Loading TinyLlama on {DEVICE}…")
        pipeline_obj = hf_pipeline(
            "text-generation",
            model=GEN_MODEL,
            device=0 if DEVICE == "cuda" else -1,  # transformers pipeline: -1 selects CPU
            dtype="auto",
        )
        print(" TinyLlama ready.")
        return pipeline_obj
    except Exception as e:
        print(f"[ERROR] Could not load TinyLlama: {e}")
        return None
417
 
418
 
419
async def _boot():
    """One-shot startup sequence run as a background task.

    Order: load both models concurrently → fetch API + GitHub sources
    concurrently → chunk → build FAISS/BM25 indexes → flip state.ready.
    Any failure is captured in state.error (and /health) rather than
    crashing the process.
    """
    try:
        # 1. Both models load concurrently in worker threads.
        print(f"Loading models on {DEVICE}…")
        embedder_job = (
            asyncio.to_thread(SentenceTransformer, EMBED_MODEL, device=DEVICE)
            if ST_OK
            else asyncio.to_thread(lambda: None)
        )
        state.embedder, state.generator = await asyncio.gather(
            embedder_job,
            asyncio.to_thread(_load_generator),
        )
        if state.embedder:
            print(" Embedder ready.")

        # 2. Knowledge base: API and GitHub fetched concurrently.
        print("Fetching knowledge base (API + GitHub)…")
        api_docs, gh_docs = await asyncio.gather(
            load_api(), load_github(), return_exceptions=False,
        )
        raw_docs = api_docs + gh_docs
        print(f" Combined raw docs: {len(raw_docs)}")

        if not raw_docs:
            print("[WARN] No documents fetched.")

        # 3. Chunking is CPU work — keep it off the event loop.
        state.documents = await asyncio.to_thread(chunk_documents, raw_docs)
        print(f" Total chunks: {len(state.documents)}")

        # 4. Dense + sparse indexes, also off the event loop.
        print("Building indexes…")
        await asyncio.to_thread(build_indexes)

        state.ready = True
        print("βœ“ RAG server ready.")

    except Exception as e:
        state.error = str(e)
        state.ready = False
        print(f"[ERROR] Boot failed: {e}")
        import traceback
        traceback.print_exc()
458
 
459
 
460
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: start _boot in the background so the
    server answers immediately, and cancel it cleanly at shutdown so a
    still-running boot never outlives the process."""
    startup = asyncio.create_task(_boot())
    try:
        yield
    finally:
        startup.cancel()
        try:
            await startup
        except asyncio.CancelledError:
            pass
471
 
472
+ # ─────────────────────────────────────────────
473
+ # APP + ENDPOINTS
474
+ # ─────────────────────────────────────────────
475
 
476
+ app = FastAPI(title="EWU RAG Server", lifespan=lifespan)
477
 
478
 
479
class Query(BaseModel):
    """Request body for POST /rag: the user question plus how many
    fused chunks to retrieve as generation context."""
    query : str
    top_k : int = 5
482
 
483
 
484
@app.post("/rag")
async def rag_endpoint(q: Query):
    """Answer a question over the EWU knowledge base.

    Hybrid-retrieves the top chunks, joins them into one context block,
    and asks the generator.  Responds 503 while the server is still
    booting and 400 on a blank query.
    """
    if not state.ready:
        raise HTTPException(503, detail=state.error or "Still initializing β€” retry shortly.")
    if not q.query.strip():
        raise HTTPException(400, detail="Query must not be empty.")

    hits = await async_hybrid_search(q.query, k=q.top_k)
    if not hits:
        return {"answer": "No relevant information found.", "sources": []}

    combined_context = "\n\n---\n\n".join(hit["content"] for hit in hits)
    reply = await generate(q.query, combined_context)
    source_info = [
        {"source": hit.get("source"), "rrf_score": hit.get("rrf_score", 0)}
        for hit in hits
    ]
    return {"answer": reply, "sources": source_info}
500
 
501
 
502
@app.get("/health")
async def health():
    """Liveness/readiness probe.

    Reports the boot status and which RAG components actually came up,
    so operators can tell a sleeping boot from a failed one.

    BUG FIX: the original called ``JSONResponse(200, {...})``.
    Starlette's signature is ``JSONResponse(content, status_code=200)``,
    so ``200`` became the response *content* and the payload dict the
    *status_code*, breaking response rendering.  Arguments are now
    passed by keyword.
    """
    return JSONResponse(
        status_code=200,
        content={
            "status"    : "ready" if state.ready else ("error" if state.error else "loading"),
            "docs"      : len(state.documents),
            "device"    : DEVICE,
            "faiss"     : state.faiss_index is not None,
            "bm25"      : state.bm25 is not None,
            "generator" : state.generator is not None,
            "error"     : state.error or None,
        },
    )
513
 
514
 
 
 
 
 
 
 
 
 
 
 
 
 
515
  if __name__ == "__main__":
516
  uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)