# IIITDMJ_Chatbot / app.py
# PrinceDeepakSiddharth12's picture
# Upload 5 files
# bc0299d verified
import streamlit as st
from agent import chatbot, classification_llm
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
import uuid
import asyncio
def generate_thread_id():
    """Create an identifier for a new conversation thread.

    Returns:
        uuid.UUID: a freshly generated random (version 4) UUID.
    """
    return uuid.uuid4()
def reset_chat():
    """Start a new conversation and make it the active one.

    A new thread id is produced with ``generate_thread_id`` so that ids are
    created in a single place, stored in ``st.session_state['thread_id']``,
    registered through ``add_thread`` and the displayed message history is
    emptied.
    """
    new_thread_id = generate_thread_id()
    st.session_state['thread_id'] = new_thread_id
    add_thread(new_thread_id)
    st.session_state['message_history'] = []
def add_thread(thread_id):
    """Register ``thread_id`` among the session's known conversations.

    Does nothing when the id is already present in
    ``st.session_state['chat_threads']``. Otherwise the id is appended and a
    placeholder title "New Chat <n>" is stored in
    ``st.session_state['thread_titles']``, where ``n`` is the number of
    threads after the append.
    """
    known_threads = st.session_state['chat_threads']
    if thread_id in known_threads:
        return
    known_threads.append(thread_id)
    st.session_state['thread_titles'][thread_id] = f"New Chat {len(known_threads)}"
def load_conversation(thread_id):
    """Return the messages the chatbot has stored for ``thread_id``.

    The state is read with ``chatbot.get_state``; only entries of its
    ``'messages'`` value that are ``BaseMessage`` instances are returned.

    Returns:
        list: the stored messages, or an empty list when there is no state,
        no ``'messages'`` entry, or any exception occurs (the error is
        printed rather than raised).
    """
    lookup_config = {'configurable': {'thread_id': thread_id}}
    try:
        snapshot = chatbot.get_state(config=lookup_config)
        if snapshot:
            stored = snapshot.values.get('messages', [])
        else:
            stored = []
        kept = []
        for entry in stored:
            if isinstance(entry, BaseMessage):
                kept.append(entry)
        return kept
    except Exception as e:
        # Best effort: a thread that cannot be loaded is shown as empty.
        print(f"Error loading conversation for thread {thread_id}: {e}")
        return []
def generate_title(query):
    """Ask ``classification_llm`` for a short title describing ``query``.

    The reply's ``content`` is stripped of surrounding whitespace and then of
    surrounding double quotes.

    Returns:
        str: the cleaned title, or ``"Chat"`` when the cleaned title is empty
        or the call raises (the error is printed rather than raised).
    """
    print("--- Generating Title ---")
    try:
        title_prompt = f"Summarize this query into a very short title (max 5 words): {query}"
        reply = classification_llm.invoke(title_prompt)
        cleaned = reply.content.strip().strip('"')
    except Exception as e:
        print(f"Error generating title: {e}")
        return "Chat"
    else:
        return cleaned or "Chat"
# Session bootstrap: make sure every key the app reads exists. Each factory is
# only called when its key is missing, so existing values survive reruns.
_SESSION_DEFAULTS = (
    ('message_history', list),
    ('thread_id', generate_thread_id),
    ('chat_threads', list),
    ('thread_titles', dict),
)
for _state_key, _make_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()

# The active thread must always be listed among the known threads.
add_thread(st.session_state['thread_id'])
# --- Sidebar: "new chat" button and the list of known conversations ---
st.sidebar.title("IIITDMJ Chatbot")

# The label's leading symbol was stored mis-encoded; it is the heavy plus
# sign emoji.
if st.sidebar.button("➕ New Chat"):
    reset_chat()
    st.rerun()

st.sidebar.header("My Conversations")

# Iterate the thread list in reverse so the most recently added comes first.
for thread_id in st.session_state['chat_threads'][::-1]:
    title = st.session_state['thread_titles'].get(thread_id, "Untitled Chat")
    if st.sidebar.button(title, key=f"thread_{thread_id}", use_container_width=True):
        # Switch to the clicked thread and rebuild the displayed history from
        # the stored messages. System messages are not shown; human messages
        # are labelled 'user' and every other message 'assistant'.
        st.session_state['thread_id'] = thread_id
        st.session_state['message_history'] = [
            {
                'role': 'user' if isinstance(msg, HumanMessage) else 'assistant',
                'content': msg.content,
            }
            for msg in load_conversation(thread_id)
            if not isinstance(msg, SystemMessage)
        ]
        st.rerun()
# --- Main pane: page heading followed by the conversation so far ---
st.title("IIITDMJ College Assistant")
st.caption("This bot uses a local vector store and LangGraph to answer your questions.")

for entry in st.session_state['message_history']:
    speaker = entry['role']
    body = entry['content']
    with st.chat_message(speaker):
        if speaker != 'assistant':
            st.markdown(body)
            continue
        # Assistant text is wrapped in a div to set the font size.
        # NOTE(review): the content is inserted unescaped with
        # unsafe_allow_html=True -- confirm that is acceptable for model output.
        st.markdown(f"<div style='font-size: 15px;'>{body}</div>", unsafe_allow_html=True)
# Event names whose streamed tokens / final output are treated as the answer.
_ANSWER_NODES = ("generate_answer", "generate_synthesized_answer", "handle_chat")


async def _stream_agent_events(user_text, config, stream_placeholder):
    """Run the chatbot on ``user_text`` and live-render streamed tokens.

    Args:
        user_text: the user's message, sent as a single ``HumanMessage``.
        config: config passed to ``chatbot.astream_events``.
        stream_placeholder: object whose ``markdown`` method is called with
            the text accumulated so far while tokens arrive.

    Returns:
        tuple: ``(streamed_text, final_output, final_name)`` where
        ``streamed_text`` is the concatenation of all streamed chunks,
        ``final_output`` is the last dict output seen on an ``on_chain_end``
        event from one of ``_ANSWER_NODES`` (or ``None``), and ``final_name``
        is that event's name (or ``""``).
    """
    streamed_text = ""
    final_output = None
    final_name = ""
    async for event in chatbot.astream_events(
        {'messages': [HumanMessage(content=user_text)]},
        config=config,
        version="v1",
    ):
        kind = event["event"]
        name = event["name"]
        # NOTE(review): for "on_chat_model_stream" events the name may be the
        # model's run name rather than the graph node name -- confirm against
        # agent.py; if so, token streaming never matches and only the
        # fallback path below is used.
        if name not in _ANSWER_NODES:
            continue
        if kind == "on_chat_model_stream":
            chunk_content = event["data"]["chunk"].content
            if chunk_content:
                streamed_text += chunk_content
                # Trailing block character acts as a typing cursor (it was
                # stored mis-encoded in the original source).
                stream_placeholder.markdown(f"<div style='font-size: 15px;'>{streamed_text}▌</div>", unsafe_allow_html=True)
        elif kind == "on_chain_end":
            node_output = event.get("data", {}).get("output")
            if isinstance(node_output, dict):
                final_output = node_output
                final_name = name
                print(f"--- Captured final output from node: {name} ---")
    return streamed_text, final_output, final_name


# --- Handle one chat turn: show the question, stream the answer, store both ---
user_input = st.chat_input("Ask about IIITDMJ...")
if user_input:
    CONFIG = {'configurable': {'thread_id': st.session_state['thread_id']}}
    st.session_state['message_history'].append({'role': 'user', 'content': user_input})
    with st.chat_message('user'):
        st.markdown(user_input)

    with st.chat_message('assistant'):
        placeholder = st.empty()
        ai_message_content = ""
        try:
            print(f"\n--- Streaming response for Thread ID: {st.session_state['thread_id']} ---")
            streamed_content, final_output, final_name = asyncio.run(
                _stream_agent_events(user_input, CONFIG, placeholder)
            )
            if streamed_content:
                # Normal path: re-render the streamed text without the cursor.
                ai_message_content = streamed_content
                placeholder.markdown(f"<div style='font-size: 15px;'>{ai_message_content}</div>", unsafe_allow_html=True)
            elif final_output:
                # Nothing was streamed: fall back to the node's final output.
                print(f"--- Using fallback: No stream content captured. Using final output from {final_name}. ---")
                if "messages" in final_output and final_output["messages"]:
                    ai_message_content = final_output["messages"][-1].content
                    placeholder.markdown(f"<div style='font-size: 15px;'>{ai_message_content}</div>", unsafe_allow_html=True)
                else:
                    print(f"--- Fallback failed: Final output from {final_name} had unexpected format: {final_output} ---")
                    ai_message_content = "Sorry, I couldn't generate a response (fallback error)."
                    placeholder.markdown(ai_message_content)
            else:
                print("--- Fallback failed: No stream content and no final output captured. ---")
                ai_message_content = "Sorry, I couldn't generate a response (capture error)."
                placeholder.markdown(ai_message_content)
        except Exception as e:
            # Top-level boundary for the turn: report the failure in the UI
            # and the log, and keep a visible apology in the history.
            st.error(f"An error occurred: {e}")
            print(f"ERROR DURING STREAM/FALLBACK: {e}")
            ai_message_content = "Sorry, I encountered an error during execution."
            placeholder.markdown(ai_message_content)

    if not ai_message_content:
        ai_message_content = "Sorry, I couldn't generate a response."
    st.session_state['message_history'].append({'role': 'assistant', 'content': ai_message_content})

    # Replace the placeholder title after the first exchange of a new thread.
    current_id = st.session_state['thread_id']
    current_title = st.session_state['thread_titles'].get(current_id, "New Chat")
    if current_title.startswith("New Chat") and len(st.session_state['message_history']) <= 2:
        summarized_title = generate_title(user_input)
        st.session_state['thread_titles'][current_id] = summarized_title
        # The sidebar was already drawn with the old title; rerun to refresh it.
        st.rerun()