# NOTE(review): the lines "Spaces: Sleeping" above this file's code were
# Hugging Face Spaces UI status residue captured alongside the source, not
# Python code; kept only as a comment so the module parses.
| import json | |
| import os | |
| import re | |
| import sqlite3 | |
| from typing import TypedDict, Annotated | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from dotenv import load_dotenv | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_core.messages import ( | |
| HumanMessage, | |
| BaseMessage, | |
| SystemMessage, | |
| ) | |
| from langgraph.graph.message import add_messages | |
| from langgraph.graph import StateGraph, START, END | |
| from langgraph.checkpoint.sqlite import SqliteSaver | |
| # RAG imports | |
| from rag import load_student_documents, create_vectorstore, get_retriever | |
# Load environment variables (e.g. GOOGLE_API_KEY) before any client is built;
# override=True lets the .env file win over already-set shell variables.
load_dotenv(override=True)
# ------------------ RAG SETUP ------------------
# Build the retrieval pipeline once at import time; `documents` is also used
# directly by chat_node for exact register-number matching.
documents = load_student_documents("studentDataset.csv")
vectorstore = create_vectorstore(documents)
retriever = get_retriever(vectorstore)
# ------------------ LLM ------------------
# Low temperature keeps answers close to the retrieved data.
llm = ChatGoogleGenerativeAI(
    model="models/gemini-2.5-flash",
    temperature=0.3,
)
# ------------------ STATE ------------------
class ChatState(TypedDict):
    """Graph state: the running conversation.

    `add_messages` makes LangGraph append each node's returned messages to
    the existing history instead of replacing it.
    """
    messages: Annotated[list[BaseMessage], add_messages]
# ------------------ GRAPH NODE ------------------
def chat_node(state: ChatState):
    """Answer the latest user message from the student dataset.

    Retrieval strategy: if the query contains a register number (4+ digits),
    filter the raw documents for an exact substring match; otherwise fall
    back to vector similarity search. The retrieved text is injected into a
    system prompt and the LLM's reply is appended to the graph state.
    """
    query = state["messages"][-1].content

    # Prefer an exact register-number lookup when the query contains one.
    hits = []
    reg_match = re.search(r"\b\d{4,}\b", query)
    if reg_match is not None:
        number = reg_match.group()
        hits = [doc for doc in documents if number in doc.page_content]

    # No exact hit (or no number in the query): semantic search fallback.
    if not hits:
        hits = retriever.invoke(query)

    context = "\n\n".join(doc.page_content for doc in hits)
    system_prompt = f"""
You are a student database assistant.
Answer ONLY using the information below.
If the answer is not present, say "I don't have that information".
DATA:
{context}
"""
    answer = llm.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=query),
    ])
    return {"messages": [answer]}
# ------------------ GRAPH ------------------
# Single-node graph: START -> chat_node -> END.
graph = StateGraph(ChatState)
graph.add_node("chat_node", chat_node)
graph.add_edge(START, "chat_node")
graph.add_edge("chat_node", END)
# check_same_thread=False: the connection is shared across FastAPI worker
# threads; SqliteSaver serializes access to it internally.
conn = sqlite3.connect("a.db", check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)
# The checkpointer persists conversation state per thread_id (see chat_ws),
# so sessions survive process restarts.
workflow = graph.compile(checkpointer=checkpointer)
# ------------------ FASTAPI ------------------
app = FastAPI()

# FIX: the handler was defined but never registered with the app, so the
# endpoint was unreachable. Registered at /ws — TODO confirm the path the
# middleware/client actually connects to.
@app.websocket("/ws")
async def chat_ws(websocket: WebSocket):
    """WebSocket chat endpoint.

    Protocol: the client sends plain-text user messages; for each one the
    server streams {"type": "response", "content": ...} chunks followed by a
    {"type": "complete"} sentinel. Conversation memory is keyed by the
    `session_id` query parameter (LangGraph thread_id).
    """
    await websocket.accept()
    try:
        session_id = websocket.query_params.get("session_id", "default")
        while True:
            user_text = await websocket.receive_text()
            config = {"configurable": {"thread_id": session_id}}
            # STREAM RESPONSE: stream_mode="messages" yields
            # (message_chunk, metadata) pairs as the LLM generates tokens.
            for msg, _ in workflow.stream(
                {"messages": [HumanMessage(user_text)]},
                config=config,
                stream_mode="messages",
            ):
                if msg.content:
                    await websocket.send_json({
                        "type": "response",
                        "content": msg.content
                    })
            # SIGNAL COMPLETION
            await websocket.send_json({
                "type": "complete"
            })
    except WebSocketDisconnect:
        print("Middleware disconnected")
    except Exception as e:
        print(f"Backend error: {e}")
        # FIX: the socket may already be closed/broken when we get here;
        # a failed error-report send must not raise out of the handler.
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Backend error occurred"
            })
        except Exception:
            pass