| """ |
| EpiRAG -- server.py |
| ------------------ |
| Flask server. Detects environment and chooses ChromaDB mode: |
| |
| LOCAL (EPIRAG_ENV != "cloud"): |
| Reads from persistent ./chroma_db/ - fast restarts, no re-embedding. |
| |
| CLOUD (EPIRAG_ENV=cloud, e.g. HF Spaces): |
| Builds in-memory ChromaDB from ./papers/ at startup. |
| Takes ~2-3 min cold start. No disk writes needed. |
| """ |
|
|
| import os |
| import time |
| import chromadb |
| from flask import Flask, jsonify, request, send_from_directory |
| from flask_cors import CORS |
| from query import rag_query, set_components |
|
|
# Flask application object; static assets are served from ./static and
# flask_cors is applied to the whole app with its default settings.
app = Flask(__name__, static_folder="static")
CORS(app)

# Name of the ChromaDB collection that holds the corpus.
COLLECTION_NAME = "epirag"

# True only when EPIRAG_ENV equals "cloud" (case-insensitive); any other
# value, including an unset variable, selects local mode.
IS_CLOUD = os.getenv("EPIRAG_ENV", "").lower() == "cloud"

# Module-level state filled in by init_corpus().
_collection = _embedder = None
CORPUS_STATS = dict()
|
|
def _paper_label(meta):
    """Return the display name for one chunk's metadata.

    Prefers ``paper_name``, then ``source``, then ``"Unknown"``. A ``None``
    metadata entry is treated as empty so it maps to ``"Unknown"`` instead
    of raising ``AttributeError``.
    """
    meta = meta or {}
    return meta.get("paper_name", meta.get("source", "Unknown"))


def init_corpus():
    """Load the corpus, hand it to the query module, and publish its stats.

    In cloud mode (``IS_CLOUD`` is true) the collection and embedder are
    obtained from ``ingest.build_collection_in_memory()``. Otherwise the
    collection named ``COLLECTION_NAME`` is opened from the persistent store
    at ``./chroma_db`` and an ``all-MiniLM-L6-v2`` SentenceTransformer is
    loaded.

    Side effects:
        Rebinds the module globals ``_collection`` and ``_embedder``, passes
        both to ``set_components``, and updates ``CORPUS_STATS`` in place with
        the chunk count, paper count, sorted paper list, status and mode.
    """
    global _collection, _embedder, CORPUS_STATS

    if IS_CLOUD:
        print("Cloud mode — building in-memory corpus from ./papers/", flush=True)
        # Imported lazily so local runs do not need the ingest module loaded.
        from ingest import build_collection_in_memory
        _collection, _embedder = build_collection_in_memory()
    else:
        print("Local mode — loading from ./chroma_db/", flush=True)
        # Imported lazily: only the local branch needs to construct the
        # embedder itself.
        from sentence_transformers import SentenceTransformer
        client = chromadb.PersistentClient(path="./chroma_db")
        _collection = client.get_collection(COLLECTION_NAME)
        _embedder = SentenceTransformer("all-MiniLM-L6-v2")

    # Make the embedder and collection available to rag_query().
    set_components(_embedder, _collection)

    count = _collection.count()
    results = _collection.get(limit=count, include=["metadatas"])
    # Individual metadata entries may be None for records stored without
    # metadata; _paper_label() tolerates that so one bare record cannot
    # abort startup.
    metadatas = results["metadatas"] or []
    papers = sorted({_paper_label(meta) for meta in metadatas})

    CORPUS_STATS.update({
        "chunks": count,
        "papers": len(papers),
        "paperList": papers,
        "status": "online",
        "mode": "cloud (in-memory)" if IS_CLOUD else "local (persistent)",
    })
    print(f"Corpus ready: {count} chunks / {len(papers)} papers", flush=True)


# Runs at import time (outside the __main__ guard), so the corpus is loaded
# whenever this module is imported, not only when it is executed as a script.
init_corpus()
|
|
|
|
| |
@app.route("/")
def index():
    """Serve the frontend entry page, ``index.html``, from ``static``."""
    page = "index.html"
    return send_from_directory("static", page)
|
|
|
|
@app.route("/api/stats")
def stats():
    """Return the current contents of ``CORPUS_STATS`` as a JSON response."""
    response = jsonify(CORPUS_STATS)
    return response
|
|
|
|
@app.route("/api/query", methods=["POST"])
def query():
    """Answer a question through the RAG pipeline.

    Expects a JSON object body containing a non-empty string ``question``.
    ``GROQ_API_KEY`` must be set in the environment; ``TAVILY_API_KEY`` is
    read as well and passed through as-is (it may be ``None``).

    Returns:
        200 with a JSON object holding ``answer``, ``sources``, ``mode`` and
        ``avg_sim`` from ``rag_query``, plus ``latency_ms``, a rough
        ``tokens`` estimate (answer length // 4) and the echoed ``question``.
        400 with ``{"error": ...}`` when the body is not a JSON object or
        ``question`` is missing, blank, or not a string.
        500 with ``{"error": ...}`` when ``GROQ_API_KEY`` is not set.
    """
    # silent=True makes a wrong Content-Type or malformed JSON yield None
    # rather than Flask's own HTML 415/400 response, so clients always get
    # the JSON error below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        # A JSON array/scalar body has no "question" field to read.
        data = {}

    question = data.get("question")
    # Reject non-string values here; calling .strip() on them would raise
    # and surface as a 500 instead of a client error.
    if not isinstance(question, str) or not question.strip():
        return jsonify({"error": "No question provided"}), 400
    question = question.strip()

    groq_key = os.environ.get("GROQ_API_KEY")
    tavily_key = os.environ.get("TAVILY_API_KEY")
    if not groq_key:
        return jsonify({"error": "GROQ_API_KEY not set on server"}), 500

    # perf_counter is monotonic, so the reported latency cannot be skewed
    # by wall-clock adjustments during the request.
    start = time.perf_counter()
    result = rag_query(question, groq_key, tavily_key)
    elapsed_ms = int((time.perf_counter() - start) * 1000)

    return jsonify({
        "answer": result["answer"],
        "sources": result["sources"],
        "mode": result["mode"],
        "avg_sim": result["avg_sim"],
        "latency_ms": elapsed_ms,
        # Crude estimate: roughly four characters per token.
        "tokens": len(result["answer"]) // 4,
        "question": question,
    })
|
|
|
|
@app.route("/api/health")
def health():
    """Report liveness plus the corpus status recorded in ``CORPUS_STATS``.

    The ``corpus`` field falls back to ``"unknown"`` when no status has been
    recorded.
    """
    corpus_state = CORPUS_STATS.get("status", "unknown")
    payload = {"status": "ok", "corpus": corpus_state}
    return jsonify(payload)
|
|
|
|
if __name__ == "__main__":
    # The PORT environment variable overrides the default of 7860; the
    # server listens on all interfaces with debug mode off.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, debug=False)