"""WebSocket chatbot backend: a LangGraph single-node graph that answers
questions about a student CSV dataset via exact register-number lookup
with a vector-retrieval (RAG) fallback, served over FastAPI.

Module-level side effects (intentional, run once at import):
  - loads environment variables (API keys for Gemini),
  - loads/embeds the student dataset into a vectorstore,
  - compiles the LangGraph workflow with a SQLite checkpointer.
"""

import json
import os
import re
import sqlite3
from typing import TypedDict, Annotated

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dotenv import load_dotenv
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import (
    HumanMessage,
    BaseMessage,
    SystemMessage,
)
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.sqlite import SqliteSaver

# RAG imports
from rag import load_student_documents, create_vectorstore, get_retriever

load_dotenv(override=True)

# ------------------ RAG SETUP ------------------
# Keep the raw documents around: exact register-number search scans them
# directly, bypassing the vector index.
documents = load_student_documents("studentDataset.csv")
vectorstore = create_vectorstore(documents)
retriever = get_retriever(vectorstore)

# ------------------ LLM ------------------
llm = ChatGoogleGenerativeAI(
    model="models/gemini-2.5-flash",
    temperature=0.3,
)


# ------------------ STATE ------------------
class ChatState(TypedDict):
    """Conversation state: `add_messages` appends rather than replaces."""

    messages: Annotated[list[BaseMessage], add_messages]


# ------------------ GRAPH NODE ------------------
def chat_node(state: ChatState) -> dict:
    """Answer the latest user message using grounded student data.

    Retrieval strategy: if the message contains a register-number-like
    token (4+ consecutive digits), do an exact substring match over the
    raw documents; otherwise (or when that yields nothing) fall back to
    vector similarity search.

    Returns a partial state update: {"messages": [AIMessage]}.
    """
    user_msg = state["messages"][-1].content

    # Exact register number search (4+ digit token).
    match = re.search(r"\b\d{4,}\b", user_msg)
    docs = []
    if match:
        reg_no = match.group()
        docs = [d for d in documents if reg_no in d.page_content]

    # Vector fallback when no exact hit.
    if not docs:
        docs = retriever.invoke(user_msg)

    context = "\n\n".join(d.page_content for d in docs)

    # Grounding prompt: constrain the model to the retrieved context only.
    system_prompt = f"""
You are a student database assistant.
Answer ONLY using the information below.
If the answer is not present, say "I don't have that information".

DATA:
{context}
"""

    response = llm.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_msg),
    ])

    return {"messages": [response]}


# ------------------ GRAPH ------------------
graph = StateGraph(ChatState)
graph.add_node("chat_node", chat_node)
graph.add_edge(START, "chat_node")
graph.add_edge("chat_node", END)

# check_same_thread=False: FastAPI may service requests from different
# threads; SqliteSaver serializes access to the shared connection.
conn = sqlite3.connect("a.db", check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)

workflow = graph.compile(checkpointer=checkpointer)

# ------------------ FASTAPI ------------------
app = FastAPI()


@app.websocket("/chat")
async def chat_ws(websocket: WebSocket):
    """WebSocket chat endpoint.

    Protocol: client sends plain-text messages; server streams back
    {"type": "response", "content": ...} chunks followed by a single
    {"type": "complete"}. On backend failure it attempts to send
    {"type": "error", ...} and closes. The `session_id` query parameter
    selects the LangGraph checkpoint thread (conversation memory).
    """
    await websocket.accept()
    try:
        session_id = websocket.query_params.get("session_id", "default")

        while True:
            user_text = await websocket.receive_text()

            config = {"configurable": {"thread_id": session_id}}

            # STREAM RESPONSE: stream_mode="messages" yields
            # (message_chunk, metadata) pairs as tokens arrive.
            for msg, _ in workflow.stream(
                {"messages": [HumanMessage(content=user_text)]},
                config=config,
                stream_mode="messages",
            ):
                if msg.content:
                    await websocket.send_json({
                        "type": "response",
                        "content": msg.content,
                    })

            # SIGNAL COMPLETION
            await websocket.send_json({
                "type": "complete",
            })

    except WebSocketDisconnect:
        print("Middleware disconnected")
    except Exception as e:
        print(f"Backend error: {e}")
        # Best-effort error report: the socket may already be closed or
        # broken (that may be why we got here), in which case send_json
        # itself raises — swallow that so the handler exits cleanly
        # instead of crashing with a second unhandled exception.
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Backend error occurred",
            })
        except Exception:
            pass