"""FastAPI backend for the MANIT chat assistant.

Wires together intent classification, the retrieval-augmented answer path,
and the application lifespan that loads the retrievers and LLM clients once
at startup.
"""

import csv
import logging
import os
import pickle
import threading
from contextlib import asynccontextmanager
from datetime import datetime
from pathlib import Path
from typing import List, Dict, Any, Tuple

from dotenv import load_dotenv
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from langchain_core.messages import HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_groq import ChatGroq
from pydantic import BaseModel

from utils.BM25_to_Dict import convert_bm25_to_dict
from utils.ClassifyIntent import classify_intent
from utils.GeneralAdvice import handle_general_advice
from utils.RAGAdvanced import rag_advanced
from utils.RouteQuery import route_query

load_dotenv()

# Retrieval data always lives next to this file; only the chat-log location
# depends on the environment. When SPACE_ID is set, logs go to /data
# (presumably a hosted deployment with a persistent volume there -- confirm).
DB_DIR = Path(__file__).resolve().parent / "data"
LOGS_DIR = Path("/data") if os.getenv("SPACE_ID") else DB_DIR

CACHE_FILE = DB_DIR / "bm25_index.pkl"
LOGS_FILE = LOGS_DIR / "all_chats.csv"


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the retrievers and LLM clients once and expose them as app state.

    Startup loads the pickled BM25 retriever and chunk dictionary from
    ``CACHE_FILE``, constructs the vector retriever and the LLM clients, and
    starts a daemon thread that sends one "ping" to each LLM. The yielded
    dict becomes the lifespan state that request handlers read via
    ``request.state``.

    Raises:
        FileNotFoundError: if ``CACHE_FILE`` does not exist.
    """
    print("[-] BOOTUP: Loading heavy model weights into memory...")
    DB_DIR.mkdir(parents=True, exist_ok=True)

    # Imported at startup rather than at module import time, as in the
    # original layout (presumably because these modules are heavy -- confirm).
    from classes.EmbeddingManager import EmbeddingManager
    from classes.VectorStore import VectorStore
    from classes.RAGRetriever import RAGRetriever

    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
    # This is only safe while CACHE_FILE is produced by a trusted build step
    # and is not writable by untrusted parties.
    try:
        with open(CACHE_FILE, "rb") as f:
            bm25_retriever, chunks_dict = pickle.load(f)
    except FileNotFoundError as err:
        # Same exception type as before, but with an actionable message.
        raise FileNotFoundError(
            f"BM25 cache file not found at {CACHE_FILE}; it must contain a "
            "pickled (bm25_retriever, chunks_dict) pair before the server "
            "can start."
        ) from err

    embedding_manager = EmbeddingManager()
    vectorstore = VectorStore()
    rag_retriever = RAGRetriever(vectorstore, embedding_manager)

    groq_api_key = os.getenv("groq_api_key")
    google_api_key = os.getenv("google_api_key")

    # Generation settings shared by every client below.
    sampling = {"temperature": 0.1, "max_tokens": 1024}

    # Primary and the first three backups use max_retries=0; the last backup
    # and the fast model keep the client's default retry behaviour.
    primary_heavy_llm = ChatGroq(
        groq_api_key=groq_api_key,
        model_name="llama-3.3-70b-versatile",
        timeout=10,
        max_retries=0,
        **sampling,
    )
    backup_heavy_llm1 = ChatGroq(
        groq_api_key=groq_api_key,
        model_name="llama-3.1-8b-instant",
        timeout=8,
        max_retries=0,
        **sampling,
    )
    backup_heavy_llm2 = ChatGoogleGenerativeAI(
        google_api_key=google_api_key,
        model="models/gemini-3.5-flash",
        timeout=8,
        max_retries=0,
        **sampling,
    )
    backup_heavy_llm3 = ChatGoogleGenerativeAI(
        google_api_key=google_api_key,
        model="models/gemini-3.1-flash-lite",
        timeout=8,
        max_retries=0,
        **sampling,
    )
    backup_heavy_llm4 = ChatGoogleGenerativeAI(
        google_api_key=google_api_key,
        model="models/gemini-2.5-flash",
        timeout=8,
        **sampling,
    )
    resilient_heavy_llm = primary_heavy_llm.with_fallbacks(
        [backup_heavy_llm1, backup_heavy_llm2, backup_heavy_llm3, backup_heavy_llm4]
    )
    fast_llm = ChatGroq(
        groq_api_key=groq_api_key,
        model_name="llama-3.1-8b-instant",
        timeout=8,
        **sampling,
    )

    def orchestrate_warmup():
        """Send one throwaway prompt to each LLM; failures are reported, not raised."""
        print("[-] Warming up connection pool...")
        try:
            _ = resilient_heavy_llm.invoke([HumanMessage(content="ping")])
            _ = fast_llm.invoke([HumanMessage(content="ping")])
            print("[+] Connection is now warm. Latency will drop to < 1 second.")
        except Exception as e:
            # Best-effort: a failed warmup must not prevent the server from starting.
            print(f"[!] WARNING: Background connection warmup bypassed: {e}")

    # Daemon thread so startup is not blocked by the warmup round-trips.
    threading.Thread(target=orchestrate_warmup, daemon=True).start()

    try:
        yield {
            "rag_retriever": rag_retriever,
            "heavy_llm": resilient_heavy_llm,
            "fast_llm": fast_llm,
            "chunks_dict": chunks_dict,
            "bm25_retriever": bm25_retriever,
        }
    finally:
        # Runs even when an exception is thrown into the generator at the yield.
        print("[-] SHUTDOWN: Cleaning up model allocations...")
        # Any cleanup code (closing db connections, clearing VRAM) goes here


app = FastAPI(title="Welcome to MANIT Chat!", lifespan=lifespan)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed cross-origin requests. Configuration is unchanged
# here; consider listing explicit origins if cookies/auth headers are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('manit-logger')


class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    # The user's question.
    query: str


class ChatResponse(BaseModel):
    """Response body for POST /chat."""

    # The assistant's answer text.
    reply: str


def main_chat(query, fast_llm, heavy_llm, vector_retriever, bm25_retriever, chunks_dict):
    """Classify the query's intent and return the reply for that intent.

    Args:
        query: The user's question.
        fast_llm: LLM passed to ``classify_intent``.
        heavy_llm: LLM passed to ``rag_advanced`` for RAG answers.
        vector_retriever: Vector retriever passed to ``rag_advanced``.
        bm25_retriever: BM25 retriever passed to ``rag_advanced``.
        chunks_dict: Chunk dictionary passed to ``rag_advanced``.

    Returns:
        A canned message for the SYSTEM_IDENTITY, IRRELEVANT_REJECT and
        GENERAL_CHAT intents, the ``'answer'`` field of the ``rag_advanced``
        result for RAG_SEARCH, and a rephrase prompt for any other intent.
    """
    intent = classify_intent(query, fast_llm)
    print(f"DEBUG: Router classified query as [{intent}]")

    if intent == "SYSTEM_IDENTITY":
        return "I am an AI engineering assistant built to query MANIT college and technical documents. I cannot provide my underlying system instructions or you should try with different prompt"

    if intent == "IRRELEVANT_REJECT":
        return "I am specialized in the provided college data. I cannot answer general knowledge questions outside of this context."

    if intent == "GENERAL_CHAT":
        return "Hello! I am ready to help you search the MANIT database. What do you need?"

    # 3. Execute the Heavy RAG Path
    if intent == "RAG_SEARCH":
        result = rag_advanced(
            query,
            vector_retriever,
            bm25_retriever,
            chunks_dict,
            heavy_llm,
            return_context=True,
        )
        return result['answer']

    # Unknown label from the classifier.
    return "I didn't quite understand that intent. Could you rephrase your question about MANIT?"
def add_logs_to_csv(query: str, answer: str):
    """Append one ``[timestamp, query, answer]`` row to the chat log CSV.

    Creates the log directory if it does not exist. The timestamp is the
    local ``datetime.now()`` at the time the row is written.

    Args:
        query: The user's question.
        answer: The reply that was returned for it.
    """
    LOGS_FILE.parent.mkdir(parents=True, exist_ok=True)
    # newline='' lets the csv module control line endings itself.
    with open(LOGS_FILE, mode='a', encoding='utf-8', newline='') as f:
        csv.writer(f).writerow([datetime.now(), query, answer])


@app.get('/health')
def health_check():
    """Liveness probe: returns a fixed status payload."""
    return {"status": "online", "message": "server is working well"}


@app.post('/chat')
def chat_endpoint(request: ChatRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    """Answer one chat query and schedule it to be logged.

    Args:
        request: Parsed request body carrying the user's query.
        fastapi_request: Raw request, used to reach the objects the lifespan
            handler placed in ``request.state``.
        background_tasks: Used to run the CSV logging after the response.

    Returns:
        A ``ChatResponse`` whose ``reply`` is the value returned by
        ``main_chat``.
    """
    user_query = request.query
    state = fastapi_request.state

    result = main_chat(
        user_query,
        state.fast_llm,
        state.heavy_llm,
        state.rag_retriever,
        state.bm25_retriever,
        state.chunks_dict,
    )
    print(result)

    # Logging is deferred so the file write does not delay the reply.
    background_tasks.add_task(add_logs_to_csv, user_query, result)
    return ChatResponse(reply=result)


@app.get("/download-logs")
def download_logs(key: str = None):
    """Return the chat log CSV to a caller presenting the correct key.

    Args:
        key: Access key from the query string; it must equal the value of
            the ``groq_api_key`` environment variable.

    Raises:
        HTTPException: 401 when the key is missing, empty, or wrong (or the
            server has no key configured); 404 when no log file exists yet.
    """
    # Local import keeps this block self-contained.
    import hmac

    # Verify access using your groq key as the password
    # NOTE(review): the key travels in the URL (and so may end up in access
    # logs), and it doubles as the LLM provider credential. Consider a
    # dedicated secret sent in a header instead.
    expected = os.getenv("groq_api_key")
    authorized = (
        bool(key)
        and bool(expected)
        # Constant-time comparison so response timing does not leak how much
        # of the key matched. Bytes are used because compare_digest rejects
        # non-ASCII str arguments.
        and hmac.compare_digest(key.encode("utf-8"), expected.encode("utf-8"))
    )
    if not authorized:
        raise HTTPException(status_code=401, detail="Unauthorized")

    if not LOGS_FILE.exists():
        raise HTTPException(status_code=404, detail="No logs found yet.")

    return FileResponse(
        path=LOGS_FILE,
        filename="all_chats.csv",
        media_type="text/csv",
    )


# Mount the frontend client directory to serve static assets at /
# Registered after the API routes above so they are matched first.
CLIENT_DIR = Path(__file__).resolve().parent.parent / "client"
app.mount("/", StaticFiles(directory=CLIENT_DIR, html=True), name="static")