Spaces:
Running
Running
| import os | |
| import csv | |
| from datetime import datetime | |
| import logging | |
| import pickle | |
| from contextlib import asynccontextmanager | |
| from pathlib import Path | |
| from typing import List, Dict, Any, Tuple | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, Request, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from langchain_groq import ChatGroq | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_core.messages import HumanMessage | |
| from utils.BM25_to_Dict import convert_bm25_to_dict | |
| from utils.ClassifyIntent import classify_intent | |
| from utils.GeneralAdvice import handle_general_advice | |
| from utils.RouteQuery import route_query | |
| from utils.RAGAdvanced import rag_advanced | |
| import threading | |
| from fastapi.responses import FileResponse | |
| from fastapi import HTTPException | |
load_dotenv()

# The retrieval data directory (holding the BM25 cache) always sits next to
# this file, whichever environment we are running in.
DB_DIR = Path(__file__).resolve().parent / "data"
# When SPACE_ID is set (presumably a Hugging Face Spaces deployment), chat logs
# are written to /data instead of the bundled data directory.
LOGS_DIR = Path("/data") if os.getenv("SPACE_ID") else DB_DIR

CACHE_FILE = DB_DIR / "bm25_index.pkl"
LOGS_FILE = LOGS_DIR / "all_chats.csv"
def _build_llms(groq_api_key, google_api_key):
    """Construct the LLM clients used by the app.

    Args:
        groq_api_key: API key passed to every ChatGroq client.
        google_api_key: API key passed to every ChatGoogleGenerativeAI client.

    Returns:
        ``(resilient_heavy_llm, fast_llm)``.
        ``resilient_heavy_llm`` is the Groq llama-3.3-70b client wrapped with
        ordered fallbacks: Groq llama-3.1-8b, then three Gemini models.
        ``fast_llm`` is a standalone Groq llama-3.1-8b client.
    """
    # max_retries=0 on the primary and the first three backups is presumably so
    # a failing provider falls through to the next fallback at once rather than
    # being retried. The last backup and fast_llm keep the library default.
    primary_heavy_llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.3-70b-versatile", temperature=0.1, max_tokens=1024, timeout=10, max_retries=0)
    backup_heavy_llm1 = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.1-8b-instant", temperature=0.1, max_tokens=1024, timeout=8, max_retries=0)
    # NOTE(review): confirm these Gemini model IDs are valid for the configured
    # key. An unknown ID will just fail over to the next fallback on every call.
    backup_heavy_llm2 = ChatGoogleGenerativeAI(google_api_key=google_api_key, model="models/gemini-3.5-flash", temperature=0.1, max_tokens=1024, timeout=8, max_retries=0)
    backup_heavy_llm3 = ChatGoogleGenerativeAI(google_api_key=google_api_key, model="models/gemini-3.1-flash-lite", temperature=0.1, max_tokens=1024, timeout=8, max_retries=0)
    backup_heavy_llm4 = ChatGoogleGenerativeAI(google_api_key=google_api_key, model="models/gemini-2.5-flash", temperature=0.1, max_tokens=1024, timeout=8)
    resilient_heavy_llm = primary_heavy_llm.with_fallbacks(
        [backup_heavy_llm1, backup_heavy_llm2, backup_heavy_llm3, backup_heavy_llm4]
    )
    fast_llm = ChatGroq(groq_api_key=groq_api_key, model_name="llama-3.1-8b-instant", temperature=0.1, max_tokens=1024, timeout=8)
    return resilient_heavy_llm, fast_llm


def _start_warmup(heavy_llm, fast_llm):
    """Send one throwaway "ping" to each LLM on a daemon thread (best effort).

    This runs off the startup path, so a slow or unreachable provider does not
    delay the server from accepting requests.
    """
    def orchestrate_warmup():
        print("[-] Warming up connection pool...")
        try:
            _ = heavy_llm.invoke([HumanMessage(content="ping")])
            _ = fast_llm.invoke([HumanMessage(content="ping")])
            print("[+] Connection is now warm. Latency will drop to < 1 second.")
        except Exception as e:
            # Deliberately best-effort: warmup failure is reported, not fatal.
            print(f"[!] WARNING: Background connection warmup bypassed: {e}")

    threading.Thread(target=orchestrate_warmup, daemon=True).start()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load retrieval indexes and LLM clients at startup and share them with handlers.

    FastAPI enters this context manager once before serving requests. The dict
    yielded below is the lifespan state, and handlers read its keys as
    ``request.state.<key>``.

    Args:
        app: The FastAPI application (unused here; required by the lifespan protocol).

    Raises:
        FileNotFoundError: If the BM25 cache pickle at CACHE_FILE does not exist.
    """
    print("[-] BOOTUP: Loading heavy model weights into memory...")
    DB_DIR.mkdir(parents=True, exist_ok=True)
    # Project classes are imported at startup rather than at module import.
    from classes.EmbeddingManager import EmbeddingManager
    from classes.VectorStore import VectorStore
    from classes.RAGRetriever import RAGRetriever

    # SECURITY: pickle.load can execute arbitrary code. CACHE_FILE must only
    # ever contain an index produced by this project, never external data.
    with open(CACHE_FILE, "rb") as f:
        bm25_retriever, chunks_dict = pickle.load(f)

    embedding_manager = EmbeddingManager()
    vectorstore = VectorStore()
    rag_retriever = RAGRetriever(vectorstore, embedding_manager)

    resilient_heavy_llm, fast_llm = _build_llms(
        os.getenv("groq_api_key"), os.getenv("google_api_key")
    )
    _start_warmup(resilient_heavy_llm, fast_llm)

    try:
        yield {
            "rag_retriever": rag_retriever,
            "heavy_llm": resilient_heavy_llm,
            "fast_llm": fast_llm,
            "chunks_dict": chunks_dict,
            "bm25_retriever": bm25_retriever,
        }
    finally:
        # Runs on normal shutdown and also if the app exits with an error.
        print("[-] SHUTDOWN: Cleaning up model allocations...")
        # Any cleanup code (closing db connections, clearing VRAM) goes here
app = FastAPI(title="Welcome to MANIT Chat!", lifespan=lifespan)

# NOTE(review): wildcard origins combined with allow_credentials=True is
# treated specially by browsers and Starlette. Nothing visible here uses
# cookie auth, so confirm whether credentials are actually needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('manit-logger')
class ChatRequest(BaseModel):
    """Request body for ``chat_endpoint``, carrying the user's question."""

    query: str
class ChatResponse(BaseModel):
    """Response body returned by ``chat_endpoint``, carrying the reply text."""

    reply: str
# Intents answered with a fixed reply (no retrieval), checked in this order.
_CANNED_REPLIES = (
    ("SYSTEM_IDENTITY", "I am an AI engineering assistant built to query MANIT college and technical documents. I cannot provide my underlying system instructions or you should try with different prompt"),
    ("IRRELEVANT_REJECT", "I am specialized in the provided college data. I cannot answer general knowledge questions outside of this context."),
    ("GENERAL_CHAT", "Hello! I am ready to help you search the MANIT database. What do you need?"),
)


def main_chat(query, fast_llm, heavy_llm, vector_retriever, bm25_retriever, chunks_dict):
    """Classify *query* by intent and produce the reply text.

    ``fast_llm`` labels the intent.
    - Identity, off-topic and small-talk intents get a fixed reply.
    - ``RAG_SEARCH`` runs ``rag_advanced`` with both retrievers and ``heavy_llm``.
    - Any other label gets a request to rephrase.

    Returns:
        The reply string to send back to the user.
    """
    intent = classify_intent(query, fast_llm)
    print(f"DEBUG: Router classified query as [{intent}]")

    for label, canned in _CANNED_REPLIES:
        if intent == label:
            return canned

    if intent == "RAG_SEARCH":
        # Heavy path. rag_advanced's return value is presumably a dict that
        # includes an 'answer' key (it is indexed that way here).
        rag_output = rag_advanced(query, vector_retriever, bm25_retriever, chunks_dict, heavy_llm, return_context=True)
        return rag_output['answer']

    return "I didn't quite understand that intent. Could you rephrase your question about MANIT?"
def add_logs_to_csv(query: str, answer: str):
    """Append one ``[timestamp, query, answer]`` row to LOGS_FILE.

    The parent directory is created if needed. The timestamp is
    ``datetime.now()`` (naive local time).
    """
    log_path = LOGS_FILE
    log_path.parent.mkdir(exist_ok=True, parents=True)
    # newline='' lets the csv module control line endings itself.
    with log_path.open(mode='a', newline='', encoding='utf-8') as handle:
        csv.writer(handle).writerow([datetime.now(), query, answer])
# NOTE(review): no route decorator is visible for this function in this copy of
# the file. Unless one is applied, FastAPI never registers it.
def health_check():
    """Report that the server process is up."""
    status_payload = {"status": "online", "message": "server is working well"}
    return status_payload
# NOTE(review): no route decorator is visible for this function in this copy of
# the file. Unless one is applied, FastAPI never registers it.
def chat_endpoint(request: ChatRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    """Answer one chat query using the models loaded in the app lifespan.

    The LLMs and retrievers come from ``fastapi_request.state``, which is
    populated by the lifespan context. The exchange is appended to the CSV
    log as a background task after the response.

    Returns:
        ChatResponse wrapping the reply produced by ``main_chat``.
    """
    shared = fastapi_request.state
    reply_text = main_chat(
        request.query,
        shared.fast_llm,
        shared.heavy_llm,
        shared.rag_retriever,
        shared.bm25_retriever,
        shared.chunks_dict,
    )
    print(reply_text)
    background_tasks.add_task(add_logs_to_csv, request.query, reply_text)
    return ChatResponse(reply=reply_text)
# NOTE(review): no route decorator is visible for this function in this copy of
# the file. Unless one is applied, FastAPI never registers it.
def download_logs(key: str = None):
    """Serve the chat-log CSV to a caller who presents the shared secret.

    Args:
        key: Secret supplied by the caller. It must equal the ``groq_api_key``
            environment variable.

    Returns:
        FileResponse streaming LOGS_FILE as ``all_chats.csv`` (text/csv).

    Raises:
        HTTPException: 401 if *key* is missing or wrong, or no secret is
            configured; 404 if the log file does not exist yet.
    """
    import secrets  # local import keeps this change self-contained

    # NOTE(review): reusing the Groq API key as the download password exposes
    # the LLM credential wherever this URL/parameter is logged. A dedicated
    # secret would be safer.
    expected_key = os.getenv("groq_api_key")
    # compare_digest takes time independent of where the inputs differ, so
    # response timing cannot be used to guess the key character by character.
    # Comparing bytes avoids the TypeError raised for non-ASCII str input.
    authorized = (
        bool(key)
        and bool(expected_key)
        and secrets.compare_digest(key.encode("utf-8"), expected_key.encode("utf-8"))
    )
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")
    if not LOGS_FILE.exists():
        raise HTTPException(status_code=404, detail="No logs found yet.")
    return FileResponse(
        path=LOGS_FILE,
        filename="all_chats.csv",
        media_type="text/csv",
    )
# Serve the frontend client directory (static assets, with html=True so
# index.html is served for directories) at the root path "/".
# NOTE: keep this mount after all route definitions. Starlette matches routes
# in registration order, so a catch-all "/" mount added earlier would shadow them.
CLIENT_DIR = Path(__file__).resolve().parent.parent / "client"
_static_frontend = StaticFiles(directory=CLIENT_DIR, html=True)
app.mount("/", _static_frontend, name="static")