# EWU RAG Server — FastAPI retrieval-augmented-generation service
# (hybrid FAISS + BM25 retrieval over EWU API/GitHub data, TinyLlama generation)
import os
import json
import asyncio
import numpy as np
import uvicorn
import httpx
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Optional dependencies: each *_OK flag records whether the feature is
# available so the rest of the app can degrade gracefully instead of crashing.
try:
    import faiss  # dense vector index (FAISS)
    FAISS_OK = True
except ImportError:
    FAISS_OK = False
try:
    from sentence_transformers import SentenceTransformer  # embedding model
    ST_OK = True
except ImportError:
    ST_OK = False
try:
    from rank_bm25 import BM25Okapi  # sparse keyword retrieval
    BM25_OK = True
except ImportError:
    BM25_OK = False
try:
    from transformers import pipeline as hf_pipeline  # local text generation
    HF_OK = True
except ImportError:
    HF_OK = False
    print("[WARN] transformers not installed β generation disabled.")
# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
API_BASE = "https://ewu-server.onrender.com/api"
# SECURITY(review): API key hardcoded in source — move to an environment
# variable (e.g. os.environ["EWU_API_KEY"]) and rotate the exposed key.
API_KEY = "i6EDytaX4E2jI6GvZQc0b1RSZHTI5_wVRa2rfL7rLpk"
API_HEADERS = {"x-api-key": API_KEY}
GITHUB_BASE = "https://raw.githubusercontent.com/Atkiya/jsonfiles/main/"
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # dense embedder
GEN_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"        # local generator
CHUNK_SIZE = 400      # characters per chunk
CHUNK_OVERLAP = 80    # characters shared between consecutive chunks
DEVICE = "cpu"
try:
    import torch
    # Prefer GPU when torch is installed and CUDA is visible.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
    pass
# Collection endpoints fetched in one concurrent pass.
API_LIST_ENDPOINTS = [
    "admission-deadlines", "academic-calendar", "grade-scale",
    "tuition-fees", "scholarships", "clubs", "notices", "partnerships",
    "governance", "alumni", "helpdesk", "policies", "proctor-schedule",
    "documents", "newsletters", "programs", "faculty", "departments",
]
# Endpoints with per-item detail pages; id_field names the key holding the id.
API_DETAIL_ENDPOINTS = [
    {"list": "programs", "id_field": "id"},
    {"list": "faculty", "id_field": "id"},
    {"list": "documents", "id_field": "slug"},
]
# Static JSON knowledge-base files fetched from the GitHub raw mirror.
GITHUB_FILES = [
    "admission_deadlines.json", "dynamic_admission_process.json",
    "dynamic_admission_requirements.json", "dynamic_tution_fees.json",
    "dynamic_events_workshops.json", "ewu_faculty_complete.json",
    "dynamic_grading.json", "dynamic_facilites.json",
    "ewu_proctor_schedule.json", "ewu_newsletters_complete.json",
    "static_aboutEWU.json", "static_Admin.json",
    "static_AllAvailablePrograms.json", "static_alumni.json",
    "static_campus_life.json", "static_Career_Counseling_Center.json",
    "static_clubs.json", "static_depts.json", "static_facilities.json",
    "static_helpdesk.json", "static_payment_procedure.json",
    "static_Policy.json", "static_Programs.json", "static_Rules.json",
    "static_Sexual_harassment.json", "static_Tuition_fees.json",
    "ma_english.json", "mba_emba.json", "ms_cse.json", "ms_dsa.json",
    "mds.json", "mphil_pharmacy.json", "mss_eco.json",
    "scholarships_and_financial_aids.json",
    "st_ba.json", "st_ce.json", "st_cse.json", "st_ece.json",
    "st_economics.json", "st_eee.json", "st_english.json", "st_geb.json",
    "st_information_studies.json", "st_law.json", "st_math.json",
    "st_pharmacy.json", "st_social_relations.json", "st_sociology.json",
    "syndicate.json", "tesol.json", "ewu_board_of_trustees.json",
]
# ─────────────────────────────────────────────
# APP STATE
# ─────────────────────────────────────────────
class AppState:
    """Mutable module-level container for models, corpus and indexes."""
    embedder = None      # SentenceTransformer instance (dense retrieval)
    generator = None     # TinyLlama pipeline
    documents : list = []  # chunked corpus: {"content": str, "source": str}
    faiss_index = None   # faiss.IndexFlatIP over embedded chunks
    bm25 = None          # BM25Okapi over whitespace-tokenized chunks
    ready : bool = False   # True once _boot() completed successfully
    error : str = ""       # last boot error message, if any

# Single shared state object used by all handlers and background tasks.
state = AppState()
# ─────────────────────────────────────────────
# DATA LOADING
# ─────────────────────────────────────────────
async def fetch_json(url: str, headers: dict = None, timeout: int = 60):
    """GET *url* and return the parsed JSON body; None on any failure (logged)."""
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(url, headers=headers or {})
            if response.status_code != 200:
                print(f"[WARN] {url} β HTTP {response.status_code}")
                return None
            return response.json()
    except Exception as exc:
        print(f"[WARN] {url} β {exc}")
        return None
| def _unwrap(data) -> list: | |
| if isinstance(data, list): | |
| return data | |
| if isinstance(data, dict): | |
| for key in ("data", "results", "items"): | |
| if key in data and isinstance(data[key], list): | |
| return data[key] | |
| return [data] | |
| return [] | |
async def _wake_api_server():
    """
    render.com free tier sleeps after inactivity.

    Poll a cheap endpoint until it responds (up to 3 attempts, 10 s apart)
    and report whether the API server woke up.
    """
    print(" [API] Waking render.com server (free tier may be sleeping)β¦")
    attempt = 0
    while attempt < 3:
        response = await fetch_json(f"{API_BASE}/grade-scale", API_HEADERS, timeout=60)
        if response is not None:
            print(" [API] Server awake.")
            return True
        print(f" [API] Wake attempt {attempt+1}/3 failed, retryingβ¦")
        await asyncio.sleep(10)
        attempt += 1
    print(" [API] Server did not wake β skipping API data.")
    return False
async def load_api() -> list:
    """Fetch every configured EWU API endpoint and flatten into raw docs.

    Returns a list of {"content": <json string>, "source": "api:..."} dicts;
    returns [] when the render.com server cannot be woken.
    """
    awake = await _wake_api_server()
    if not awake:
        return []
    # Fetch all list endpoints concurrently; failures surface as Exceptions.
    list_results = await asyncio.gather(
        *[fetch_json(f"{API_BASE}/{ep}", API_HEADERS) for ep in API_LIST_ENDPOINTS],
        return_exceptions=True,
    )
    docs, list_cache = [], {}
    for ep, data in zip(API_LIST_ENDPOINTS, list_results):
        if not data or isinstance(data, Exception):
            continue
        items = _unwrap(data)
        list_cache[ep] = items  # kept for the detail-page pass below
        for item in items:
            text = json.dumps(item, ensure_ascii=False)
            if text.strip():
                docs.append({"content": text, "source": f"api:{ep}"})
    print(f" [API lists] {len(docs)} docs")
    # Build detail-page URLs from ids found in the cached list responses.
    detail_tasks = []
    for cfg in API_DETAIL_ENDPOINTS:
        for item in list_cache.get(cfg["list"], []):
            item_id = item.get(cfg["id_field"]) if isinstance(item, dict) else None
            if item_id is not None:
                url = f"{API_BASE}/{cfg['list']}/{item_id}"
                detail_tasks.append((url, f"api:{cfg['list']}/{item_id}"))
    if detail_tasks:
        # Fetch all detail pages concurrently as well.
        detail_results = await asyncio.gather(
            *[fetch_json(url, API_HEADERS) for url, _ in detail_tasks],
            return_exceptions=True,
        )
        n = 0
        for (_, source), data in zip(detail_tasks, detail_results):
            if not data or isinstance(data, Exception):
                continue
            for item in _unwrap(data):
                text = json.dumps(item, ensure_ascii=False)
                if text.strip():
                    docs.append({"content": text, "source": source})
                    n += 1
        print(f" [API details] {n} docs from {len(detail_tasks)} pages")
    print(f" [API total] {len(docs)} raw docs")
    return docs
async def load_github() -> list:
    """Download every file in GITHUB_FILES and flatten into raw docs."""
    payloads = await asyncio.gather(
        *(fetch_json(GITHUB_BASE + name) for name in GITHUB_FILES),
        return_exceptions=True,
    )
    docs = []
    for name, payload in zip(GITHUB_FILES, payloads):
        if not payload or isinstance(payload, Exception):
            continue
        items = payload if isinstance(payload, list) else [payload]
        for item in items:
            serialized = json.dumps(item, ensure_ascii=False)
            if serialized.strip():
                docs.append({"content": serialized, "source": f"github:{name}"})
    print(f" [GitHub] {len(docs)} raw docs")
    return docs
# ─────────────────────────────────────────────
# CHUNKING
# ─────────────────────────────────────────────
def chunk_documents(docs, size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split oversized docs into overlapping character windows.

    Docs at or under *size* chars pass through unchanged; larger ones are
    cut into windows of *size* chars advancing by (size - overlap).
    Blank docs and blank windows are dropped.
    """
    stride = max(1, size - overlap)
    chunks = []
    for doc in docs:
        body = doc["content"]
        if not body.strip():
            continue
        if len(body) <= size:
            chunks.append(doc)
            continue
        for offset in range(0, len(body), stride):
            piece = body[offset:offset + size]
            if piece.strip():
                chunks.append({"content": piece, "source": doc["source"]})
    return chunks
# ─────────────────────────────────────────────
# INDEX BUILDING
# ─────────────────────────────────────────────
def build_indexes():
    """Build FAISS (dense) and BM25 (sparse) indexes over state.documents.

    Each index is optional and independently skipped if its dependency is
    missing or building fails. Returns False only when there is nothing to
    index at all.
    """
    if not state.documents:
        print("[WARN] No documents to index.")
        return False
    texts = [d["content"] for d in state.documents]
    if FAISS_OK and ST_OK and state.embedder:
        try:
            emb = state.embedder.encode(
                texts, normalize_embeddings=True,
                show_progress_bar=False, batch_size=64,
            )
            emb = np.array(emb, dtype="float32")
            if emb.ndim == 2 and emb.shape[0] > 0:
                # Normalized embeddings + inner product = cosine similarity.
                state.faiss_index = faiss.IndexFlatIP(emb.shape[1])
                state.faiss_index.add(emb)
                print(f" [FAISS] {state.faiss_index.ntotal} vectors (dim={emb.shape[1]})")
        except Exception as e:
            print(f"[ERROR] FAISS: {e}")
            state.faiss_index = None
    if BM25_OK:
        try:
            # Simple lowercase whitespace tokenization (matches search_sparse).
            tok = [t.lower().split() for t in texts if t.strip()]
            if tok:
                state.bm25 = BM25Okapi(tok)
                print(f" [BM25] {len(tok)} docs")
        except Exception as e:
            print(f"[ERROR] BM25: {e}")
            state.bm25 = None
    return True
# ─────────────────────────────────────────────
# RETRIEVAL
# ─────────────────────────────────────────────
def search_dense(query, k=8):
    """Embed *query* and return top-k chunks via inner-product FAISS search.

    Scores are cosine similarities (embeddings are L2-normalized). Returns
    [] when the index or embedder is unavailable, or on any error.
    """
    if not state.faiss_index or not state.embedder:
        return []
    try:
        vec = np.array(
            state.embedder.encode([query], normalize_embeddings=True), dtype="float32"
        )
        # Cannot request more neighbors than vectors in the index.
        k_a = min(k, state.faiss_index.ntotal)
        if not k_a:
            return []
        scores, ids = state.faiss_index.search(vec, k_a)
        # FAISS pads missing results with id -1; skip those.
        return [{**state.documents[i], "score": float(s)}
                for s, i in zip(scores[0], ids[0]) if i >= 0]
    except Exception as e:
        print(f"[ERROR] dense: {e}")
        return []
def search_sparse(query, k=8):
    """BM25 keyword search over the corpus; top-k hits with positive scores."""
    if not state.bm25 or not state.documents:
        return []
    try:
        terms = query.lower().split()
        if not terms:
            return []
        raw = np.array(state.bm25.get_scores(terms), dtype="float32")
        ranked = np.argsort(raw)[::-1][:min(k, len(raw))]
        hits = []
        for i in ranked:
            if raw[i] > 0:
                hits.append({**state.documents[i], "score": float(raw[i])})
        return hits
    except Exception as e:
        print(f"[ERROR] sparse: {e}")
        return []
def hybrid_search(query, k=5, alpha=0.65):
    """Fuse dense and sparse results with weighted reciprocal-rank fusion.

    *alpha* weights the dense ranking; (1 - alpha) weights the sparse one.
    Over-fetches k*3 from each retriever, merges by chunk content, and
    returns the top-k fused hits with an added "rrf_score" field.
    """
    dense_hits = search_dense(query, k * 3)
    sparse_hits = search_sparse(query, k * 3)
    if not (dense_hits or sparse_hits):
        return []
    RRF_K = 60  # standard RRF damping constant
    fused, by_content = {}, {}
    for weight, hits in ((alpha, dense_hits), (1 - alpha, sparse_hits)):
        for rank, hit in enumerate(hits):
            content = hit["content"]
            fused[content] = fused.get(content, 0.0) + weight / (RRF_K + rank + 1)
            by_content[content] = hit
    ranked = sorted(fused.items(), key=lambda kv: kv[1], reverse=True)[:k]
    return [{**by_content[c], "rrf_score": round(score, 6)} for c, score in ranked]
async def async_hybrid_search(query, k=5):
    """Run hybrid_search in a worker thread so the event loop is not blocked."""
    return await asyncio.to_thread(hybrid_search, query, k)
# ─────────────────────────────────────────────
# GENERATION — TinyLlama (local, no API key)
# ─────────────────────────────────────────────
# System instruction prepended to every generation request; constrains the
# model to the retrieved context to limit hallucination.
SYSTEM_PROMPT = (
    "You are a helpful assistant for East West University (EWU). "
    "Answer using ONLY the context provided. "
    "If the context does not contain enough information, say so honestly. "
    "Be concise and accurate. Do not repeat the context."
)
def _run_tinyllama(query: str, context: str) -> str:
    """
    Synchronous TinyLlama call.

    Builds a system+user chat in the template TinyLlama-1.1B-Chat was trained
    on and returns the assistant's reply text. Always call via
    asyncio.to_thread — never directly from async code (inference blocks).
    On a missing generator or inference failure, returns a bracketed
    placeholder string instead of raising.
    """
    if state.generator is None:
        return f"[Generator not loaded]\n\nContext:\n{context}"
    # TinyLlama chat template: <|system|>...<|user|>...<|assistant|>
    # Trim context to ~1500 chars so it fits in the 2048-token window
    trimmed_context = context[:1500] + ("β¦" if len(context) > 1500 else "")
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Context:\n{trimmed_context}\n\nQuestion: {query}"},
    ]
    try:
        outputs = state.generator(
            messages,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.3,  # low = more factual, less hallucination
            top_p=0.9,
            repetition_penalty=1.1,
        )
        # transformers pipeline returns list of dicts with generated_text
        generated = outputs[0]["generated_text"]
        # generated_text is the full conversation list; grab the last assistant turn
        if isinstance(generated, list):
            for turn in reversed(generated):
                if isinstance(turn, dict) and turn.get("role") == "assistant":
                    return turn.get("content", "").strip()
        # Fallback: return raw string
        return str(generated).strip()
    except Exception as e:
        print(f"[ERROR] TinyLlama inference: {e}")
        return f"[Generation error: {e}]"
async def generate(query: str, context: str) -> str:
    """Async wrapper — run TinyLlama in a worker thread so the event loop stays free."""
    return await asyncio.to_thread(_run_tinyllama, query, context)
# ─────────────────────────────────────────────
# BOOT
# ─────────────────────────────────────────────
def _load_generator():
    """Load the TinyLlama pipeline (runs in a thread during boot).

    Returns the pipeline, or None when transformers is missing or loading
    fails — callers treat None as "generation disabled".
    """
    if not HF_OK:
        print("[WARN] transformers unavailable β generation disabled.")
        return None
    try:
        print(f" Loading TinyLlama on {DEVICE}β¦")
        gen = hf_pipeline(
            "text-generation",
            model=GEN_MODEL,
            device=0 if DEVICE == "cuda" else -1,  # -1 = CPU for transformers pipeline
            # NOTE(review): older transformers versions spell this kwarg
            # `torch_dtype=` — confirm the installed version accepts `dtype`.
            dtype="auto",
        )
        print(" TinyLlama ready.")
        return gen
    except Exception as e:
        print(f"[ERROR] Could not load TinyLlama: {e}")
        return None
async def _boot():
    """Background startup: load models, fetch the corpus, chunk, and index.

    Sets state.ready on success; records the failure message in state.error
    otherwise. All heavy work is pushed to threads so the event loop (and
    the /health endpoint) stays responsive during boot.
    """
    try:
        # 1. Load both models concurrently in threads
        print(f"Loading models on {DEVICE}β¦")
        state.embedder, state.generator = await asyncio.gather(
            asyncio.to_thread(SentenceTransformer, EMBED_MODEL, device=DEVICE) if ST_OK
            else asyncio.to_thread(lambda: None),
            asyncio.to_thread(_load_generator),
        )
        if state.embedder:
            print(" Embedder ready.")
        # 2. Fetch API + GitHub concurrently
        print("Fetching knowledge base (API + GitHub)β¦")
        api_docs, gh_docs = await asyncio.gather(
            load_api(), load_github(), return_exceptions=False,
        )
        raw_docs = api_docs + gh_docs
        print(f" Combined raw docs: {len(raw_docs)}")
        if not raw_docs:
            print("[WARN] No documents fetched.")
        # 3. Chunk
        state.documents = await asyncio.to_thread(chunk_documents, raw_docs)
        print(f" Total chunks: {len(state.documents)}")
        # 4. Build indexes
        print("Building indexesβ¦")
        await asyncio.to_thread(build_indexes)
        state.ready = True
        print("β RAG server ready.")
    except Exception as e:
        state.error = str(e)
        state.ready = False
        print(f"[ERROR] Boot failed: {e}")
        import traceback; traceback.print_exc()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: start the background boot task, cancel it on shutdown.

    BUG FIX: the original was passed to FastAPI(lifespan=...) as a bare async
    generator function. FastAPI/Starlette expect an async context-manager
    factory, so the @asynccontextmanager decorator (already imported at the
    top of the file) is required for startup to work on current versions.
    """
    boot_task = asyncio.create_task(_boot())
    try:
        yield  # application serves requests while boot continues in background
    finally:
        # Stop the boot task if it is still running when the app shuts down.
        boot_task.cancel()
        try:
            await boot_task
        except asyncio.CancelledError:
            pass
# ─────────────────────────────────────────────
# APP + ENDPOINTS
# ─────────────────────────────────────────────
# ASGI application; `lifespan` schedules the background boot task.
app = FastAPI(title="EWU RAG Server", lifespan=lifespan)

class Query(BaseModel):
    """Request schema for the RAG endpoint."""
    query : str      # user question (must be non-empty)
    top_k : int = 5  # number of chunks to retrieve
# BUG FIX: the handler was never registered on the app (no route decorator),
# so the endpoint was unreachable. NOTE(review): path "/query" is inferred —
# confirm against clients.
@app.post("/query")
async def rag_endpoint(q: Query):
    """Answer a question via hybrid retrieval + local generation.

    Raises HTTP 503 while the background boot is still running and HTTP 400
    for blank queries. Returns {"answer": str, "sources": [...]}.
    """
    if not state.ready:
        raise HTTPException(503, detail=state.error or "Still initializing β retry shortly.")
    if not q.query.strip():
        raise HTTPException(400, detail="Query must not be empty.")
    results = await async_hybrid_search(q.query, k=q.top_k)
    if not results:
        return {"answer": "No relevant information found.", "sources": []}
    # Join the retrieved chunks into one context string for the generator.
    context = "\n\n---\n\n".join(r["content"] for r in results)
    answer = await generate(q.query, context)
    return {
        "answer": answer,
        "sources": [{"source": r.get("source"), "rrf_score": r.get("rrf_score", 0)}
                    for r in results],
    }
# BUG FIX (x2): the handler was never registered on the app (no route
# decorator), and JSONResponse(200, {...}) passed 200 as `content` and the
# dict as `status_code` (JSONResponse's signature is (content, status_code)).
@app.get("/health")
async def health():
    """Liveness/readiness probe reporting boot progress and index status."""
    return JSONResponse(
        status_code=200,
        content={
            "status"    : "ready" if state.ready else ("error" if state.error else "loading"),
            "docs"      : len(state.documents),
            "device"    : DEVICE,
            "faiss"     : state.faiss_index is not None,
            "bm25"      : state.bm25 is not None,
            "generator" : state.generator is not None,
            "error"     : state.error or None,
        },
    )
if __name__ == "__main__":
    # NOTE(review): "app:app" assumes this file is saved as app.py — confirm.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)