# File size: 7,447 Bytes
# bc0299d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import streamlit as st
from agent import chatbot, classification_llm 
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
import uuid
import asyncio

def generate_thread_id():
    """Return a new random (version 4) UUID to identify a chat thread."""
    return uuid.uuid4()

def reset_chat():
    """Switch the session to a brand-new, empty conversation.

    A fresh UUID becomes the active ``thread_id``, is registered through
    ``add_thread`` and the on-screen message history is cleared.
    """
    session = st.session_state
    fresh_id = uuid.uuid4()
    session['thread_id'] = fresh_id
    add_thread(fresh_id)
    session['message_history'] = []

def add_thread(thread_id):
    """Register ``thread_id`` in the session's thread list if it is new.

    A newly registered thread gets the placeholder title ``"New Chat <n>"``
    where ``<n>`` is the number of threads known after the insertion.
    Already-known ids are left untouched.
    """
    known_threads = st.session_state['chat_threads']
    if thread_id in known_threads:
        return
    known_threads.append(thread_id)
    st.session_state['thread_titles'][thread_id] = f"New Chat {len(known_threads)}"

def load_conversation(thread_id):
    """Fetch the stored messages of ``thread_id`` from the chatbot graph.

    Returns the ``BaseMessage`` instances found under the ``'messages'`` key
    of the graph state, or an empty list when there is no state or when
    reading it fails (the error is printed, not raised).
    """
    config = {'configurable' : {'thread_id': thread_id}}
    try:
        snapshot = chatbot.get_state(config=config)
        if not snapshot:
            return []
        stored = snapshot.values.get('messages', [])
        # Drop anything in the state that is not a LangChain message object.
        return [item for item in stored if isinstance(item, BaseMessage)]
    except Exception as e:
        print(f"Error loading conversation for thread {thread_id}: {e}")
        return []

def generate_title(query):
    """Ask ``classification_llm`` for a short title summarising ``query``.

    The reply is stripped of surrounding whitespace and double quotes.
    Returns ``"Chat"`` when the reply is empty or the call fails (the error
    is printed, not raised).
    """
    print("--- Generating Title ---")
    fallback = "Chat"
    try:
        reply = classification_llm.invoke(
            f"Summarize this query into a very short title (max 5 words): {query}"
        )
        cleaned = reply.content.strip().strip('"')
    except Exception as e:
        print(f"Error generating title: {e}")
        return fallback
    return cleaned or fallback

# Seed the per-session state on the first run only; later reruns keep
# whatever is already stored. Factories are called lazily so a thread id is
# generated only when one is actually missing.
_SESSION_DEFAULTS = (
    ('message_history', list),
    ('thread_id', generate_thread_id),
    ('chat_threads', list),
    ('thread_titles', dict),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# Make sure the active thread is listed (and titled) in the sidebar.
add_thread(st.session_state['thread_id'])

# ---- Sidebar: start a new chat or reopen an earlier one ----
st.sidebar.title("IIITDMJ Chatbot")
# The label previously contained a mis-decoded byte sequence; it is the
# "heavy plus sign" emoji (U+2795).
if st.sidebar.button("➕ New Chat"):
    reset_chat()
    st.rerun()

st.sidebar.header("My Conversations")
# Newest thread first.
for thread_id in reversed(st.session_state['chat_threads']):
    title = st.session_state['thread_titles'].get(thread_id, "Untitled Chat")
    if st.sidebar.button(title, key=f"thread_{thread_id}", use_container_width=True):
        st.session_state['thread_id'] = thread_id
        # Rebuild the on-screen transcript from the graph's stored state.
        # System messages are hidden; human messages are shown as 'user' and
        # every other message type as 'assistant'.
        # NOTE(review): tool or other non-AI messages, if the graph stores
        # any, are also shown as assistant turns - confirm this is intended.
        st.session_state['message_history'] = [
            {
                'role': 'user' if isinstance(msg, HumanMessage) else 'assistant',
                'content': msg.content,
            }
            for msg in load_conversation(thread_id)
            if not isinstance(msg, SystemMessage)
        ]
        st.rerun()

# ---- Main pane: page header and replay of the current transcript ----
st.title("IIITDMJ College Assistant")
st.caption("This bot uses a local vector store and LangGraph to answer your questions.")

for entry in st.session_state['message_history']:
    speaker = entry['role']
    with st.chat_message(speaker):
        if speaker != 'assistant':
            st.markdown(entry['content'])
            continue
        # Assistant turns are wrapped in a styled <div>, so raw HTML is enabled.
        # NOTE(review): this renders model output as HTML - confirm that the
        # content is trusted.
        st.markdown(f"<div style='font-size: 15px;'>{entry['content']}</div>", unsafe_allow_html=True)

# ---- Chat turn: the block below only runs when the input box yields text ----
user_input = st.chat_input("Ask about IIITDMJ...")

if user_input:

    # Run configuration handed to the graph; carries the active thread id.
    CONFIG = {'configurable' : {'thread_id': st.session_state['thread_id']}}

    st.session_state['message_history'].append({'role':'user','content':user_input})
    with st.chat_message('user'):
        st.markdown(user_input)

    with st.chat_message('assistant'):
        placeholder = st.empty()
        ai_message_content = ""

        try:
            print(f"\n--- Streaming response for Thread ID: {st.session_state['thread_id']} ---")

            async def stream_agent_events(stream_placeholder):
                """Run the graph for this turn, rendering answer tokens live.

                Args:
                    stream_placeholder: Streamlit element that is overwritten
                        with the partial answer after every received chunk.

                Returns:
                    Tuple ``(streamed_text, final_output, final_node_name)``:
                    the concatenated streamed text (``""`` if nothing was
                    streamed), the last dict output captured from an answer
                    node (or ``None``), and that node's name (or ``""``).
                """
                # Graph nodes whose output is the user-facing answer.
                answer_nodes = ("generate_answer", "generate_synthesized_answer", "handle_chat")
                streamed_text = ""
                final_node_output = None
                final_node_name = ""

                async for event in chatbot.astream_events(
                    {'messages': [HumanMessage(content=user_input)]},
                    config=CONFIG,
                    version="v1"
                ):
                    kind = event["event"]
                    name = event["name"]

                    if kind == "on_chat_model_stream":
                        # For model events ``name`` is the model's run name,
                        # which only equals a node name if the agent sets it
                        # that way. The emitting graph node is reported in
                        # the event metadata, so accept either match.
                        source_node = (event.get("metadata") or {}).get("langgraph_node")
                        if name in answer_nodes or source_node in answer_nodes:
                            chunk_content = event["data"]["chunk"].content
                            # Skip empty chunks and non-text content, which
                            # cannot be appended to the running string.
                            if chunk_content and isinstance(chunk_content, str):
                                streamed_text += chunk_content
                                # Trailing block character acts as a typing cursor.
                                stream_placeholder.markdown(f"<div style='font-size: 15px;'>{streamed_text}▌</div>", unsafe_allow_html=True)

                    elif kind == "on_chain_end" and name in answer_nodes:
                        # Keep the node's final dict output as a fallback in
                        # case no tokens were streamed.
                        node_output = event.get("data", {}).get("output")
                        if isinstance(node_output, dict):
                            final_node_output = node_output
                            final_node_name = name
                            print(f"--- Captured final output from node: {name} ---")

                return streamed_text, final_node_output, final_node_name


            streamed_content, final_output, final_name = asyncio.run(stream_agent_events(placeholder))

            if streamed_content:
                # Normal path: redraw the full answer without the cursor.
                ai_message_content = streamed_content
                placeholder.markdown(f"<div style='font-size: 15px;'>{ai_message_content}</div>", unsafe_allow_html=True)
            elif final_output:
                print(f"--- Using fallback: No stream content captured. Using final output from {final_name}. ---")
                if "messages" in final_output and final_output["messages"]:
                    ai_message_content = final_output["messages"][-1].content
                    placeholder.markdown(f"<div style='font-size: 15px;'>{ai_message_content}</div>", unsafe_allow_html=True)
                else:
                    print(f"--- Fallback failed: Final output from {final_name} had unexpected format: {final_output} ---")
                    ai_message_content = "Sorry, I couldn't generate a response (fallback error)."
                    placeholder.markdown(ai_message_content)
            else:
                print("--- Fallback failed: No stream content and no final output captured. ---")
                ai_message_content = "Sorry, I couldn't generate a response (capture error)."
                placeholder.markdown(ai_message_content)

        except Exception as e:
            # Top-level UI boundary: report the failure and keep the app alive.
            st.error(f"An error occurred: {e}")
            print(f"ERROR DURING STREAM/FALLBACK: {e}")
            ai_message_content = "Sorry, I encountered an error during execution."
            placeholder.markdown(ai_message_content)

    if not ai_message_content:
        ai_message_content = "Sorry, I couldn't generate a response."

    st.session_state['message_history'].append({'role':'assistant','content':ai_message_content})

    # After the first exchange of a still-untitled thread, replace the
    # placeholder title with a generated summary and rerun to refresh the
    # sidebar.
    current_id = st.session_state['thread_id']
    current_title = st.session_state['thread_titles'].get(current_id, "New Chat")
    if current_title.startswith("New Chat") and len(st.session_state['message_history']) <= 2:
        summarized_title = generate_title(user_input)
        st.session_state['thread_titles'][current_id] = summarized_title
        st.rerun()