File size: 3,423 Bytes
c9ad3d0
10ac869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40f3d83
10ac869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62c19d5
10ac869
 
 
 
 
 
62c19d5
10ac869
 
 
 
 
 
 
 
efec5af
10ac869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62c19d5
10ac869
 
 
 
 
 
c9ad3d0
 
 
 
10ac869
62c19d5
c9ad3d0
 
 
10ac869
 
 
c9ad3d0
 
 
 
 
efec5af
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import json
import logging
import os
import re
import sqlite3
from typing import TypedDict, Annotated

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.concurrency import iterate_in_threadpool
from dotenv import load_dotenv

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import (
    HumanMessage,
    BaseMessage,
    SystemMessage,
)
from langgraph.graph.message import add_messages
from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.sqlite import SqliteSaver

# RAG imports
from rag import load_student_documents, create_vectorstore, get_retriever

# Load .env; override=True lets values in .env replace variables already set
# in the process environment.
load_dotenv(override=True)

# ------------------ RAG SETUP ------------------
# Built once at import time. The raw `documents` list is kept because
# chat_node scans it directly for exact register-number matches; the vector
# store backs the semantic `retriever` fallback.
# The dataset path can be overridden with STUDENT_DATASET_PATH (read after
# load_dotenv so a .env entry is honoured); default is unchanged.
STUDENT_DATASET_PATH = os.getenv("STUDENT_DATASET_PATH", "studentDataset.csv")
documents = load_student_documents(STUDENT_DATASET_PATH)
vectorstore = create_vectorstore(documents)
retriever = get_retriever(vectorstore)

# ------------------ LLM ------------------
# Model name can be overridden with GEMINI_MODEL; default is unchanged.
llm = ChatGoogleGenerativeAI(
    model=os.getenv("GEMINI_MODEL", "models/gemini-2.5-flash"),
    temperature=0.3,
)

# ------------------ STATE ------------------
# Graph state schema: a single "messages" channel holding the conversation.
# The add_messages reducer merges the messages a node returns into the
# existing list rather than replacing it.
ChatState = TypedDict(
    "ChatState",
    {"messages": Annotated[list[BaseMessage], add_messages]},
)

# ------------------ GRAPH NODE ------------------
def chat_node(state: ChatState):
    """Answer the latest user message using student-dataset context.

    Retrieval has two stages:

    1. If the message contains a word-delimited number of 4+ digits (taken as
       a register number; only the first one is used), every document whose
       text contains that exact number -- not as part of a longer digit run --
       becomes the context.
    2. If no number is present or nothing matched, the vector-store retriever
       is queried with the whole message.

    Args:
        state: Graph state; only the last entry of ``state["messages"]`` is read.

    Returns:
        ``{"messages": [response]}`` -- a partial state update that the
        ``add_messages`` reducer appends to the conversation.
    """
    user_msg = state["messages"][-1].content

    # Exact register number search. The digit lookarounds stop "1234" from
    # selecting documents that only contain e.g. "12345" or "51234", which a
    # plain substring test would (feeding the wrong students to the LLM).
    docs = []
    match = re.search(r"\b\d{4,}\b", user_msg)
    if match:
        reg_pattern = re.compile(rf"(?<!\d){re.escape(match.group())}(?!\d)")
        docs = [d for d in documents if reg_pattern.search(d.page_content)]

    # Vector fallback
    if not docs:
        docs = retriever.invoke(user_msg)

    context = "\n\n".join(d.page_content for d in docs)

    system_prompt = f"""
    You are a student database assistant.
    Answer ONLY using the information below.
    If the answer is not present, say "I don't have that information".
    DATA:
    {context}
    """

    # NOTE(review): only the latest user message is sent to the model; earlier
    # turns persisted by the checkpointer are not included as chat history.
    response = llm.invoke([
        SystemMessage(content=system_prompt),
        HumanMessage(content=user_msg),
    ])

    return {"messages": [response]}

# ------------------ GRAPH ------------------
# Single-node graph: START -> chat_node -> END.
graph = StateGraph(ChatState)
graph.add_node("chat_node", chat_node)
graph.add_edge(START, "chat_node")
graph.add_edge("chat_node", END)

# Conversation checkpoints are stored in SQLite per thread_id (the WebSocket
# handler uses the session_id as thread_id). The file defaults to "a.db" and
# can be overridden with CHECKPOINT_DB_PATH.
# check_same_thread=False allows this single module-level connection to be
# used from threads other than the one that created it.
CHECKPOINT_DB_PATH = os.getenv("CHECKPOINT_DB_PATH", "a.db")
conn = sqlite3.connect(CHECKPOINT_DB_PATH, check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)

workflow = graph.compile(checkpointer=checkpointer)

# ------------------ FASTAPI ------------------
app = FastAPI()

logger = logging.getLogger(__name__)


@app.websocket("/chat")
async def chat_ws(websocket: WebSocket):
    """Serve one chat session over a WebSocket.

    The optional ``session_id`` query parameter (default ``"default"``) is used
    as the LangGraph ``thread_id``, so the checkpointer keeps one history per
    session. For each text frame received, the graph is run and every
    non-empty streamed message chunk is sent as
    ``{"type": "response", "content": ...}``, followed by
    ``{"type": "complete"}``.

    On an unexpected error the exception is logged, an
    ``{"type": "error", ...}`` frame is sent if the socket is still usable,
    and the socket is closed with code 1011 (internal error).
    """
    await websocket.accept()

    session_id = websocket.query_params.get("session_id", "default")
    config = {"configurable": {"thread_id": session_id}}

    try:
        while True:
            user_text = await websocket.receive_text()

            # workflow.stream is a synchronous generator (the graph uses the
            # synchronous SqliteSaver). Iterating it directly here would block
            # the event loop -- and every other connected session -- for the
            # whole LLM call, so each step is pulled in a worker thread.
            stream = workflow.stream(
                {"messages": [HumanMessage(user_text)]},
                config=config,
                stream_mode="messages",
            )

            # STREAM RESPONSE
            async for msg, _ in iterate_in_threadpool(stream):
                if msg.content:
                    await websocket.send_json({
                        "type": "response",
                        "content": msg.content
                    })

            # SIGNAL COMPLETION
            await websocket.send_json({
                "type": "complete"
            })

    except WebSocketDisconnect:
        print("Middleware disconnected")
    except Exception as e:
        logger.exception("Backend error: %s", e)
        # Best effort: the failure may itself have come from a dead socket,
        # in which case sending/closing raises again and there is nothing
        # left to notify.
        try:
            await websocket.send_json({
                "type": "error",
                "message": "Backend error occurred"
            })
            await websocket.close(code=1011)
        except Exception:
            logger.debug("Could not deliver error frame; socket already closed")