# rag_server / app.py -- Hugging Face Space artifact
# Last change: "Update app.py" by atkiya110 (commit 34ca45d, verified)
import os
import json
import asyncio
import numpy as np
import uvicorn
import httpx
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Optional dependencies: each capability degrades gracefully when the
# package is missing; the *_OK flags are checked before use elsewhere.
try:
    import faiss  # dense vector index
    FAISS_OK = True
except ImportError:
    FAISS_OK = False
try:
    from sentence_transformers import SentenceTransformer  # query/doc embedder
    ST_OK = True
except ImportError:
    ST_OK = False
try:
    from rank_bm25 import BM25Okapi  # sparse lexical ranking
    BM25_OK = True
except ImportError:
    BM25_OK = False
try:
    from transformers import pipeline as hf_pipeline  # local LLM generation
    HF_OK = True
except ImportError:
    HF_OK = False
    print("[WARN] transformers not installed β€” generation disabled.")
# ─────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────
# Backend REST API (render.com free tier β€” may be asleep; see _wake_api_server).
API_BASE = "https://ewu-server.onrender.com/api"
# SECURITY FIX: the API key can now be supplied via the EWU_API_KEY
# environment variable. The hard-coded fallback keeps existing deployments
# working, but the literal should be rotated and removed from source control.
API_KEY = os.environ.get(
    "EWU_API_KEY", "i6EDytaX4E2jI6GvZQc0b1RSZHTI5_wVRa2rfL7rLpk"
)
API_HEADERS = {"x-api-key": API_KEY}

# Raw JSON knowledge files mirrored on GitHub.
GITHUB_BASE = "https://raw.githubusercontent.com/Atkiya/jsonfiles/main/"

# Small embedder + 1.1B chat model so the Space can run on CPU.
EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Character-based chunking parameters (consumed by chunk_documents).
CHUNK_SIZE = 400
CHUNK_OVERLAP = 80

# Use CUDA when torch is importable and a GPU is visible; otherwise CPU.
DEVICE = "cpu"
try:
    import torch
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
except ImportError:
    pass

# Collection endpoints fetched wholesale at boot.
API_LIST_ENDPOINTS = [
    "admission-deadlines", "academic-calendar", "grade-scale",
    "tuition-fees", "scholarships", "clubs", "notices", "partnerships",
    "governance", "alumni", "helpdesk", "policies", "proctor-schedule",
    "documents", "newsletters", "programs", "faculty", "departments",
]
# For these collections, each item's detail page is also fetched,
# addressed by the given id_field of the list item.
API_DETAIL_ENDPOINTS = [
    {"list": "programs", "id_field": "id"},
    {"list": "faculty", "id_field": "id"},
    {"list": "documents", "id_field": "slug"},
]
GITHUB_FILES = [
    "admission_deadlines.json", "dynamic_admission_process.json",
    "dynamic_admission_requirements.json", "dynamic_tution_fees.json",
    "dynamic_events_workshops.json", "ewu_faculty_complete.json",
    "dynamic_grading.json", "dynamic_facilites.json",
    "ewu_proctor_schedule.json", "ewu_newsletters_complete.json",
    "static_aboutEWU.json", "static_Admin.json",
    "static_AllAvailablePrograms.json", "static_alumni.json",
    "static_campus_life.json", "static_Career_Counseling_Center.json",
    "static_clubs.json", "static_depts.json", "static_facilities.json",
    "static_helpdesk.json", "static_payment_procedure.json",
    "static_Policy.json", "static_Programs.json", "static_Rules.json",
    "static_Sexual_harassment.json", "static_Tuition_fees.json",
    "ma_english.json", "mba_emba.json", "ms_cse.json", "ms_dsa.json",
    "mds.json", "mphil_pharmacy.json", "mss_eco.json",
    "scholarships_and_financial_aids.json",
    "st_ba.json", "st_ce.json", "st_cse.json", "st_ece.json",
    "st_economics.json", "st_eee.json", "st_english.json", "st_geb.json",
    "st_information_studies.json", "st_law.json", "st_math.json",
    "st_pharmacy.json", "st_social_relations.json", "st_sociology.json",
    "syndicate.json", "tesol.json", "ewu_board_of_trustees.json",
]
# ─────────────────────────────────────────────
# APP STATE
# ─────────────────────────────────────────────
class AppState:
    """Mutable singleton holding models, corpus and search indexes.

    All fields are populated by the background boot task; `ready` flips
    to True only after indexes are built, and `error` records a boot
    failure message for /health and /rag to report.
    """

    def __init__(self):
        self.embedder = None        # SentenceTransformer instance (or None)
        self.generator = None       # TinyLlama text-generation pipeline
        self.documents: list = []   # chunked docs: {"content", "source"}
        self.faiss_index = None     # dense index over documents
        self.bm25 = None            # sparse index over documents
        self.ready: bool = False    # True once boot completed successfully
        self.error: str = ""        # boot failure message, if any


state = AppState()
# ─────────────────────────────────────────────
# DATA LOADING
# ─────────────────────────────────────────────
async def fetch_json(url: str, headers: dict = None, timeout: int = 60):
    """HTTP GET *url* and return the parsed JSON body.

    Returns None on any failure (non-200 status, network error, bad JSON),
    logging a one-line warning instead of raising.
    """
    request_headers = {} if headers is None else headers
    try:
        async with httpx.AsyncClient(timeout=timeout) as client:
            response = await client.get(url, headers=request_headers)
            if response.status_code == 200:
                return response.json()
        print(f"[WARN] {url} β†’ HTTP {response.status_code}")
    except Exception as exc:
        print(f"[WARN] {url} β†’ {exc}")
    return None
def _unwrap(data) -> list:
if isinstance(data, list):
return data
if isinstance(data, dict):
for key in ("data", "results", "items"):
if key in data and isinstance(data[key], list):
return data[key]
return [data]
return []
async def _wake_api_server():
    """Ping the backend until it responds, returning True once awake.

    render.com free tier sleeps after inactivity; a cheap endpoint is
    polled up to 3 times (10 s apart) before giving up with False.
    """
    print(" [API] Waking render.com server (free tier may be sleeping)…")
    attempt = 0
    while attempt < 3:
        probe = await fetch_json(f"{API_BASE}/grade-scale", API_HEADERS, timeout=60)
        if probe is not None:
            print(" [API] Server awake.")
            return True
        print(f" [API] Wake attempt {attempt+1}/3 failed, retrying…")
        await asyncio.sleep(10)
        attempt += 1
    print(" [API] Server did not wake β€” skipping API data.")
    return False
async def load_api() -> list:
    """Fetch all list endpoints (plus selected per-item detail pages) from
    the EWU REST API and return raw docs as {"content": json_str, "source": tag}.

    Returns [] when the render.com server cannot be woken.
    """
    awake = await _wake_api_server()
    if not awake:
        return []
    # Fetch every list endpoint concurrently; exceptions stay in-place so
    # one failing endpoint does not abort the rest.
    list_results = await asyncio.gather(
        *[fetch_json(f"{API_BASE}/{ep}", API_HEADERS) for ep in API_LIST_ENDPOINTS],
        return_exceptions=True,
    )
    docs, list_cache = [], {}
    for ep, data in zip(API_LIST_ENDPOINTS, list_results):
        if not data or isinstance(data, Exception):
            continue
        items = _unwrap(data)
        list_cache[ep] = items  # kept for the detail-page pass below
        for item in items:
            text = json.dumps(item, ensure_ascii=False)
            if text.strip():
                docs.append({"content": text, "source": f"api:{ep}"})
    print(f" [API lists] {len(docs)} docs")
    # Build (url, source-tag) pairs for the per-item detail pages declared
    # in API_DETAIL_ENDPOINTS, keyed by each item's id_field.
    detail_tasks = []
    for cfg in API_DETAIL_ENDPOINTS:
        for item in list_cache.get(cfg["list"], []):
            item_id = item.get(cfg["id_field"]) if isinstance(item, dict) else None
            if item_id is not None:
                url = f"{API_BASE}/{cfg['list']}/{item_id}"
                detail_tasks.append((url, f"api:{cfg['list']}/{item_id}"))
    if detail_tasks:
        detail_results = await asyncio.gather(
            *[fetch_json(url, API_HEADERS) for url, _ in detail_tasks],
            return_exceptions=True,
        )
        n = 0
        for (_, source), data in zip(detail_tasks, detail_results):
            if not data or isinstance(data, Exception):
                continue
            for item in _unwrap(data):
                text = json.dumps(item, ensure_ascii=False)
                if text.strip():
                    docs.append({"content": text, "source": source})
                    n += 1
        print(f" [API details] {n} docs from {len(detail_tasks)} pages")
    print(f" [API total] {len(docs)} raw docs")
    return docs
async def load_github() -> list:
    """Download every file in GITHUB_FILES concurrently and return raw docs
    as {"content": json_str, "source": "github:<filename>"}.

    Failed or empty downloads are skipped silently (fetch_json already warns).
    """
    responses = await asyncio.gather(
        *[fetch_json(GITHUB_BASE + f) for f in GITHUB_FILES],
        return_exceptions=True,
    )
    docs = []
    for fname, payload in zip(GITHUB_FILES, responses):
        if not payload or isinstance(payload, Exception):
            continue
        records = payload if isinstance(payload, list) else [payload]
        for record in records:
            text = json.dumps(record, ensure_ascii=False)
            if text.strip():
                docs.append({"content": text, "source": f"github:{fname}"})
    print(f" [GitHub] {len(docs)} raw docs")
    return docs
# ─────────────────────────────────────────────
# CHUNKING
# ─────────────────────────────────────────────
def chunk_documents(docs, size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split oversized docs into overlapping character windows.

    Docs whose content fits in *size* pass through untouched; longer ones
    are sliced into windows of *size* chars advancing by size - overlap.
    Blank docs and blank slices are dropped.
    """
    stride = max(1, size - overlap)
    chunks = []
    for doc in docs:
        body = doc["content"]
        if not body.strip():
            continue
        if len(body) <= size:
            chunks.append(doc)
            continue
        for pos in range(0, len(body), stride):
            window = body[pos:pos + size]
            if window.strip():
                chunks.append({"content": window, "source": doc["source"]})
    return chunks
# ─────────────────────────────────────────────
# INDEX BUILDING
# ─────────────────────────────────────────────
def build_indexes():
    """Build the FAISS dense index and BM25 sparse index over state.documents.

    Each index is best-effort: a failure logs an error and leaves that
    index as None rather than aborting. Returns False only when there are
    no documents at all, True otherwise.
    """
    if not state.documents:
        print("[WARN] No documents to index.")
        return False
    corpus = [doc["content"] for doc in state.documents]
    if FAISS_OK and ST_OK and state.embedder:
        try:
            vectors = np.array(
                state.embedder.encode(
                    corpus, normalize_embeddings=True,
                    show_progress_bar=False, batch_size=64,
                ),
                dtype="float32",
            )
            if vectors.ndim == 2 and vectors.shape[0] > 0:
                # Inner product on normalized embeddings == cosine similarity.
                state.faiss_index = faiss.IndexFlatIP(vectors.shape[1])
                state.faiss_index.add(vectors)
                print(f" [FAISS] {state.faiss_index.ntotal} vectors (dim={vectors.shape[1]})")
        except Exception as exc:
            print(f"[ERROR] FAISS: {exc}")
            state.faiss_index = None
    if BM25_OK:
        try:
            tokenized = [text.lower().split() for text in corpus if text.strip()]
            if tokenized:
                state.bm25 = BM25Okapi(tokenized)
                print(f" [BM25] {len(tokenized)} docs")
        except Exception as exc:
            print(f"[ERROR] BM25: {exc}")
            state.bm25 = None
    return True
# ─────────────────────────────────────────────
# RETRIEVAL
# ─────────────────────────────────────────────
def search_dense(query, k=8):
    """Return up to *k* chunks nearest to *query* by embedding similarity.

    Each hit is the stored doc dict plus a float "score". Returns [] when
    the dense index or embedder is unavailable, or on any error.
    """
    if not state.faiss_index or not state.embedder:
        return []
    try:
        query_vec = np.array(
            state.embedder.encode([query], normalize_embeddings=True),
            dtype="float32",
        )
        top = min(k, state.faiss_index.ntotal)
        if not top:
            return []
        scores, ids = state.faiss_index.search(query_vec, top)
        hits = []
        for score, doc_id in zip(scores[0], ids[0]):
            if doc_id >= 0:  # FAISS pads missing results with -1
                hits.append({**state.documents[doc_id], "score": float(score)})
        return hits
    except Exception as exc:
        print(f"[ERROR] dense: {exc}")
        return []
def search_sparse(query, k=8):
    """Return up to *k* chunks ranked by BM25 lexical overlap with *query*.

    Hits carry a float "score"; zero-score matches are dropped. Returns []
    when BM25 is unavailable, the query is blank, or on any error.
    """
    if not state.bm25 or not state.documents:
        return []
    try:
        query_tokens = query.lower().split()
        if not query_tokens:
            return []
        raw_scores = np.array(state.bm25.get_scores(query_tokens), dtype="float32")
        ranked = np.argsort(raw_scores)[::-1][: min(k, len(raw_scores))]
        return [
            {**state.documents[pos], "score": float(raw_scores[pos])}
            for pos in ranked
            if raw_scores[pos] > 0
        ]
    except Exception as exc:
        print(f"[ERROR] sparse: {exc}")
        return []
def hybrid_search(query, k=5, alpha=0.65):
    """Fuse dense and sparse rankings via weighted reciprocal-rank fusion.

    Each retriever contributes weight / (60 + rank + 1) per hit, with the
    dense side weighted *alpha* and the sparse side 1 - alpha. Hits are
    deduplicated by chunk content; the top *k* fused results are returned
    with a rounded "rrf_score".
    """
    dense_hits = search_dense(query, k * 3)
    sparse_hits = search_sparse(query, k * 3)
    if not dense_hits and not sparse_hits:
        return []
    rrf_base = 60
    fused, by_content = {}, {}
    # Process dense first, then sparse, so a sparse duplicate's doc dict
    # wins in by_content (matches the original accumulation order).
    for weight, ranking in ((alpha, dense_hits), (1 - alpha, sparse_hits)):
        for rank, hit in enumerate(ranking):
            key = hit["content"]
            fused[key] = fused.get(key, 0.0) + weight / (rrf_base + rank + 1)
            by_content[key] = hit
    best = sorted(fused.items(), key=lambda kv: kv[1], reverse=True)[:k]
    return [
        {**by_content[content], "rrf_score": round(score, 6)}
        for content, score in best
    ]
async def async_hybrid_search(query, k=5):
    """Run the CPU-bound hybrid_search in a worker thread so the event loop stays responsive."""
    return await asyncio.to_thread(hybrid_search, query, k)
# ─────────────────────────────────────────────
# GENERATION β€” TinyLlama (local, no API key)
# ─────────────────────────────────────────────
# System instruction prepended to every TinyLlama chat; constrains answers
# to the retrieved context to reduce hallucination.
SYSTEM_PROMPT = (
    "You are a helpful assistant for East West University (EWU). "
    "Answer using ONLY the context provided. "
    "If the context does not contain enough information, say so honestly. "
    "Be concise and accurate. Do not repeat the context."
)
def _run_tinyllama(query: str, context: str) -> str:
    """Blocking TinyLlama chat-completion call.

    Builds the system+user message pair TinyLlama-1.1B-Chat was trained on,
    trims the context so the prompt fits in the 2048-token window, and
    extracts the final assistant turn from the pipeline output. Always
    invoke via asyncio.to_thread - never directly from async code.
    """
    if state.generator is None:
        return f"[Generator not loaded]\n\nContext:\n{context}"
    # Keep ~1500 chars of context; mark truncation with an ellipsis.
    clipped = context[:1500]
    if len(context) > 1500:
        clipped += "…"
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Context:\n{clipped}\n\nQuestion: {query}"},
    ]
    try:
        output = state.generator(
            chat,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.3,  # low = more factual, less hallucination
            top_p=0.9,
            repetition_penalty=1.1,
        )
        reply = output[0]["generated_text"]
        # Chat pipelines return the whole conversation as a message list;
        # take the last assistant turn when present.
        if isinstance(reply, list):
            for turn in reversed(reply):
                if isinstance(turn, dict) and turn.get("role") == "assistant":
                    return turn.get("content", "").strip()
        # Fallback: stringify whatever came back.
        return str(reply).strip()
    except Exception as exc:
        print(f"[ERROR] TinyLlama inference: {exc}")
        return f"[Generation error: {exc}]"
async def generate(query: str, context: str) -> str:
    """Async wrapper β€” runs the blocking TinyLlama call in a worker thread
    so the event loop stays free while the model generates."""
    return await asyncio.to_thread(_run_tinyllama, query, context)
# ─────────────────────────────────────────────
# BOOT
# ─────────────────────────────────────────────
def _load_generator():
    """Build the TinyLlama text-generation pipeline.

    Blocking; intended to run in a thread during boot. Returns the pipeline,
    or None when transformers is unavailable or loading fails.
    """
    if not HF_OK:
        print("[WARN] transformers unavailable β€” generation disabled.")
        return None
    try:
        print(f" Loading TinyLlama on {DEVICE}…")
        device_index = 0 if DEVICE == "cuda" else -1  # -1 = CPU for transformers pipeline
        pipe = hf_pipeline(
            "text-generation",
            model=GEN_MODEL,
            device=device_index,
            dtype="auto",
        )
        print(" TinyLlama ready.")
        return pipe
    except Exception as exc:
        print(f"[ERROR] Could not load TinyLlama: {exc}")
        return None
async def _boot():
    """Background startup: load models, fetch the corpus, chunk, and index.

    Runs as a fire-and-forget task from the lifespan handler so the HTTP
    server (and /health) is responsive during initialization. Sets
    state.ready on success; records the failure message in state.error.
    """
    try:
        # 1. Load both models concurrently in threads
        print(f"Loading models on {DEVICE}…")
        state.embedder, state.generator = await asyncio.gather(
            asyncio.to_thread(SentenceTransformer, EMBED_MODEL, device=DEVICE) if ST_OK
            else asyncio.to_thread(lambda: None),  # placeholder when sentence-transformers is absent
            asyncio.to_thread(_load_generator),
        )
        if state.embedder:
            print(" Embedder ready.")
        # 2. Fetch API + GitHub concurrently
        print("Fetching knowledge base (API + GitHub)…")
        api_docs, gh_docs = await asyncio.gather(
            load_api(), load_github(), return_exceptions=False,
        )
        raw_docs = api_docs + gh_docs
        print(f" Combined raw docs: {len(raw_docs)}")
        if not raw_docs:
            print("[WARN] No documents fetched.")
        # 3. Chunk
        state.documents = await asyncio.to_thread(chunk_documents, raw_docs)
        print(f" Total chunks: {len(state.documents)}")
        # 4. Build indexes
        print("Building indexes…")
        await asyncio.to_thread(build_indexes)
        state.ready = True
        print("βœ“ RAG server ready.")
    except Exception as e:
        # Record the failure so /health and /rag can surface it.
        state.error = str(e)
        state.ready = False
        print(f"[ERROR] Boot failed: {e}")
        import traceback; traceback.print_exc()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: start boot in the background, cancel it on shutdown."""
    background = asyncio.create_task(_boot())
    try:
        yield
    finally:
        # Stop a still-running boot cleanly when the server shuts down.
        background.cancel()
        try:
            await background
        except asyncio.CancelledError:
            pass
# ─────────────────────────────────────────────
# APP + ENDPOINTS
# ─────────────────────────────────────────────
app = FastAPI(title="EWU RAG Server", lifespan=lifespan)


class Query(BaseModel):
    """Request body for POST /rag."""
    query: str      # user question; must be non-empty
    top_k: int = 5  # number of chunks to retrieve and cite
@app.post("/rag")
async def rag_endpoint(q: Query):
if not state.ready:
raise HTTPException(503, detail=state.error or "Still initializing β€” retry shortly.")
if not q.query.strip():
raise HTTPException(400, detail="Query must not be empty.")
results = await async_hybrid_search(q.query, k=q.top_k)
if not results:
return {"answer": "No relevant information found.", "sources": []}
context = "\n\n---\n\n".join(r["content"] for r in results)
answer = await generate(q.query, context)
return {
"answer": answer,
"sources": [{"source": r.get("source"), "rrf_score": r.get("rrf_score", 0)}
for r in results],
}
@app.get("/health")
async def health():
return JSONResponse(200, {
"status" : "ready" if state.ready else ("error" if state.error else "loading"),
"docs" : len(state.documents),
"device" : DEVICE,
"faiss" : state.faiss_index is not None,
"bm25" : state.bm25 is not None,
"generator" : state.generator is not None,
"error" : state.error or None,
})
if __name__ == "__main__":
    # Hugging Face Spaces expects the app to listen on port 7860.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)