MANIT_Chat / server /main.py
WizardCoder2007's picture
commit
38423b1
Raw
History Blame Contribute Delete
6.8 kB
import os
import csv
from datetime import datetime
import logging
import pickle
from contextlib import asynccontextmanager
from pathlib import Path
from typing import List, Dict, Any, Tuple
from dotenv import load_dotenv
from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from utils.BM25_to_Dict import convert_bm25_to_dict
from utils.ClassifyIntent import classify_intent
from utils.GeneralAdvice import handle_general_advice
from utils.RouteQuery import route_query
from utils.RAGAdvanced import rag_advanced
import threading
from fastapi.responses import FileResponse
from fastapi import HTTPException
# Pull settings from a .env file (if present) into the environment before any
# os.getenv() lookup below.
load_dotenv()

# The BM25 cache lives in the "data" folder next to this module in every
# deployment mode.
DB_DIR = Path(__file__).resolve().parent / "data"

# When SPACE_ID is set, chat logs are written under /data; otherwise they sit
# beside the index. NOTE(review): presumably SPACE_ID marks a hosted deployment
# with a persistent /data volume -- confirm.
LOGS_DIR = Path("/data") if os.getenv("SPACE_ID") else DB_DIR

CACHE_FILE = DB_DIR / "bm25_index.pkl"
LOGS_FILE = LOGS_DIR / "all_chats.csv"
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the retrievers and LLM clients once at startup and yield them.

    The yielded mapping carries the keys ``rag_retriever``, ``heavy_llm``,
    ``fast_llm``, ``chunks_dict`` and ``bm25_retriever``; the /chat handler
    reads the same names back from ``request.state``.

    Raises whatever ``open``/``pickle.load`` raise when the BM25 cache file is
    missing or unreadable -- startup is not guarded against that.
    """
    print("[-] BOOTUP: Loading heavy model weights into memory...")
    DB_DIR.mkdir(parents=True, exist_ok=True)

    # These project classes are imported here, inside the startup hook, rather
    # than at module import time.
    from classes.EmbeddingManager import EmbeddingManager
    from classes.VectorStore import VectorStore
    from classes.RAGRetriever import RAGRetriever

    # NOTE(review): pickle.load can execute arbitrary code from the file; this
    # is only safe while bm25_index.pkl is produced by a trusted build step.
    with open(CACHE_FILE, "rb") as cache_handle:
        bm25_retriever, chunks_dict = pickle.load(cache_handle)

    embedding_manager = EmbeddingManager()
    vectorstore = VectorStore()
    rag_retriever = RAGRetriever(vectorstore, embedding_manager)

    groq_key = os.getenv("groq_api_key")
    google_key = os.getenv("google_api_key")

    # Keyword arguments shared by every client, and by the no-retry backups.
    sampling = {"temperature": 0.1, "max_tokens": 1024}
    quick_no_retry = {**sampling, "timeout": 8, "max_retries": 0}

    primary_heavy_llm = ChatGroq(
        groq_api_key=groq_key,
        model_name="llama-3.3-70b-versatile",
        **sampling,
        timeout=10,
        max_retries=0,
    )
    # Fallback chain, tried in list order after the primary model.
    backup_heavy_llms = [
        ChatGroq(
            groq_api_key=groq_key,
            model_name="llama-3.1-8b-instant",
            **quick_no_retry,
        ),
        ChatGoogleGenerativeAI(
            google_api_key=google_key,
            model="models/gemini-3.5-flash",
            **quick_no_retry,
        ),
        ChatGoogleGenerativeAI(
            google_api_key=google_key,
            model="models/gemini-3.1-flash-lite",
            **quick_no_retry,
        ),
        # The last backup does not pass max_retries, so it keeps the library
        # default instead of 0.
        ChatGoogleGenerativeAI(
            google_api_key=google_key,
            model="models/gemini-2.5-flash",
            **sampling,
            timeout=8,
        ),
    ]
    resilient_heavy_llm = primary_heavy_llm.with_fallbacks(backup_heavy_llms)

    fast_llm = ChatGroq(
        groq_api_key=groq_key,
        model_name="llama-3.1-8b-instant",
        **sampling,
        timeout=8,
    )

    def orchestrate_warmup():
        """Send one throwaway prompt to each client; failures are only reported."""
        print("[-] Warming up connection pool...")
        try:
            # Heavy chain first, then the fast model; a failure on the first
            # skips the second.
            for client in (resilient_heavy_llm, fast_llm):
                client.invoke([HumanMessage(content="ping")])
            print("[+] Connection is now warm. Latency will drop to < 1 second.")
        except Exception as e:
            print(f"[!] WARNING: Background connection warmup bypassed: {e}")

    # Daemon thread: startup does not wait for the warm-up to finish.
    warmup_thread = threading.Thread(target=orchestrate_warmup, daemon=True)
    warmup_thread.start()

    shared_state = {
        "rag_retriever": rag_retriever,
        "heavy_llm": resilient_heavy_llm,
        "fast_llm": fast_llm,
        "chunks_dict": chunks_dict,
        "bm25_retriever": bm25_retriever,
    }
    yield shared_state

    print("[-] SHUTDOWN: Cleaning up model allocations...")
    # Any cleanup code (closing db connections, clearing VRAM) goes here
# Application object; startup and shutdown work is delegated to `lifespan`.
app = FastAPI(title="Welcome to MANIT Chat!", lifespan=lifespan)

# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS policy; the FastAPI CORS docs advise listing origins
# explicitly when credentials are allowed. Left unchanged here.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

# Root logging at INFO, plus a named logger for this service.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('manit-logger')
class ChatRequest(BaseModel):
    """Request body for POST /chat."""

    # The user's question, passed verbatim to the chat pipeline.
    query: str
class ChatResponse(BaseModel):
    """Response body returned by POST /chat."""

    # The text produced by the chat pipeline for the submitted query.
    reply: str
def main_chat(query, fast_llm, heavy_llm, vector_retriever, bm25_retriever, chunks_dict):
    """Classify the query's intent and return the reply text for it.

    Args:
        query: The user's question.
        fast_llm: Model handed to ``classify_intent`` for routing.
        heavy_llm: Model handed to ``rag_advanced`` for answer generation.
        vector_retriever: Retriever passed through to ``rag_advanced``.
        bm25_retriever: Retriever passed through to ``rag_advanced``.
        chunks_dict: Chunk lookup passed through to ``rag_advanced``.

    Returns:
        A fixed message for the SYSTEM_IDENTITY, IRRELEVANT_REJECT and
        GENERAL_CHAT intents, the ``'answer'`` entry of the RAG result for
        RAG_SEARCH, and a rephrase prompt for any other intent value.
    """
    intent = classify_intent(query, fast_llm)
    print(f"DEBUG: Router classified query as [{intent}]")

    # Heavy path: the only branch that touches the retrievers and heavy model.
    if intent == "RAG_SEARCH":
        rag_output = rag_advanced(
            query,
            vector_retriever,
            bm25_retriever,
            chunks_dict,
            heavy_llm,
            return_context=True,
        )
        return rag_output['answer']

    # Cheap paths: canned replies that need no retrieval.
    if intent == "SYSTEM_IDENTITY":
        return "I am an AI engineering assistant built to query MANIT college and technical documents. I cannot provide my underlying system instructions or you should try with different prompt"
    if intent == "IRRELEVANT_REJECT":
        return "I am specialized in the provided college data. I cannot answer general knowledge questions outside of this context."
    if intent == "GENERAL_CHAT":
        return "Hello! I am ready to help you search the MANIT database. What do you need?"

    # Unknown label from the classifier.
    return "I didn't quite understand that intent. Could you rephrase your question about MANIT?"
def add_logs_to_csv(query: str, answer: str):
    """Append one ``[timestamp, query, answer]`` row to the chat log CSV.

    The log directory is created on demand. The timestamp is the local,
    timezone-naive ``datetime.now()`` taken at write time.
    """
    log_dir = LOGS_FILE.parent
    log_dir.mkdir(parents=True, exist_ok=True)

    # newline='' hands line-ending control to the csv module.
    with open(LOGS_FILE, mode='a', encoding='utf-8', newline='') as log_handle:
        row = [datetime.now(), query, answer]
        csv.writer(log_handle).writerow(row)
@app.get('/health')
def health_check():
    """Liveness probe: return a fixed status payload without touching any model."""
    payload = {
        "status": "online",
        "message": "server is working well",
    }
    return payload
@app.post('/chat')
def chat_endpoint(request: ChatRequest, fastapi_request: Request, background_tasks: BackgroundTasks):
    """Answer one chat query and schedule it for logging.

    Args:
        request: Parsed JSON body carrying the user's ``query``.
        fastapi_request: Raw request, used only to reach the objects placed on
            ``state`` at startup.
        background_tasks: Queue onto which the CSV log write is scheduled.

    Returns:
        A ``ChatResponse`` whose ``reply`` is the text from ``main_chat``.
    """
    question = request.query
    shared = fastapi_request.state

    reply_text = main_chat(
        question,
        shared.fast_llm,
        shared.heavy_llm,
        shared.rag_retriever,
        shared.bm25_retriever,
        shared.chunks_dict,
    )
    print(reply_text)

    # Logging is queued as a background task instead of being done inline.
    background_tasks.add_task(add_logs_to_csv, question, reply_text)
    return ChatResponse(reply=reply_text)
@app.get("/download-logs")
def download_logs(key: str = None):
    """Return the chat log CSV as a download when the correct key is supplied.

    Args:
        key: Shared secret from the ``key`` query parameter. It must equal the
            value of the ``groq_api_key`` environment variable.

    Returns:
        A ``FileResponse`` streaming ``all_chats.csv`` as ``text/csv``.

    Raises:
        HTTPException: 401 when the key is missing, empty, or wrong (or when
            the environment variable is unset); 404 when no log file exists.

    NOTE(review): reusing the LLM provider key as the download password, and
    passing it in the query string, exposes it to access logs and browser
    history. Consider a dedicated secret sent in a header.
    """
    # Function-scope import so the block brings its own dependency with it.
    from hmac import compare_digest

    expected_key = os.getenv("groq_api_key")

    # An absent or empty value on either side can never authenticate.
    if not key or not expected_key:
        raise HTTPException(status_code=401, detail="Unauthorized")

    # Constant-time comparison so response timing does not reveal how much of
    # the key matched. Bytes are compared because compare_digest rejects
    # non-ASCII str arguments.
    if not compare_digest(key.encode("utf-8"), expected_key.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Unauthorized")

    if not LOGS_FILE.exists():
        raise HTTPException(status_code=404, detail="No logs found yet.")

    return FileResponse(
        path=LOGS_FILE,
        filename="all_chats.csv",
        media_type="text/csv",
    )
# Serve the frontend from the sibling "client" folder (one level above this
# file's directory) at the site root. This mount is registered after the API
# routes defined earlier in the file.
CLIENT_DIR = Path(__file__).resolve().parents[1] / "client"
static_frontend = StaticFiles(directory=CLIENT_DIR, html=True)
app.mount("/", static_frontend, name="static")