File size: 7,447 Bytes
bc0299d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import streamlit as st
from agent import chatbot, classification_llm
from langchain_core.messages import HumanMessage, AIMessage, BaseMessage, SystemMessage
import uuid
import asyncio
def generate_thread_id():
    """Return a fresh random (version-4) UUID used to identify a conversation thread."""
    return uuid.uuid4()
def reset_chat():
    """Start a new, empty conversation and make it the active thread.

    Generates a new thread id through the shared generate_thread_id helper
    (rather than calling uuid directly), stores it as the active thread,
    registers it in the sidebar list, and clears the visible history.
    """
    thread_id = generate_thread_id()
    st.session_state['thread_id'] = thread_id
    add_thread(thread_id)
    st.session_state['message_history'] = []
def add_thread(thread_id):
    """Register thread_id in the sidebar thread list, once.

    A newly seen id is appended to st.session_state['chat_threads'] and given
    the placeholder title "New Chat <n>", where <n> is the list length after
    the append. Ids that are already registered are left untouched.
    """
    threads = st.session_state['chat_threads']
    if thread_id in threads:
        return
    threads.append(thread_id)
    st.session_state['thread_titles'][thread_id] = f"New Chat {len(threads)}"
def load_conversation(thread_id):
    """Return the checkpointed messages for thread_id from the chatbot graph.

    Reads the 'messages' entry of the graph state for this thread and keeps
    only BaseMessage instances. Any failure is printed and yields [].
    """
    config = {'configurable': {'thread_id': thread_id}}
    try:
        snapshot = chatbot.get_state(config=config)
        if not snapshot:
            return []
        stored = snapshot.values.get('messages', [])
        return list(filter(lambda item: isinstance(item, BaseMessage), stored))
    except Exception as e:
        print(f"Error loading conversation for thread {thread_id}: {e}")
        return []
def generate_title(query):
    """Ask classification_llm for a short sidebar title summarising query.

    The prompt requests at most 5 words. Surrounding whitespace and double
    quotes are stripped from the reply. "Chat" is returned when the reply is
    empty or when the call raises (the error is printed).
    """
    print("--- Generating Title ---")
    fallback = "Chat"
    try:
        reply = classification_llm.invoke(
            f"Summarize this query into a very short title (max 5 words): {query}"
        )
        cleaned = reply.content.strip().strip('"')
    except Exception as e:
        print(f"Error generating title: {e}")
        return fallback
    return cleaned or fallback
# Seed per-session state on the first script run; later reruns keep existing values.
# Each default is produced by a factory so mutable containers are never shared.
_session_defaults = {
    'message_history': list,
    'thread_id': generate_thread_id,
    'chat_threads': list,
    'thread_titles': dict,
}
for _state_key, _make_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Make sure the active thread always appears in the sidebar list.
add_thread(st.session_state['thread_id'])
# Sidebar: "new chat" button plus one button per known thread (newest first).
st.sidebar.title("IIITDMJ Chatbot")
if st.sidebar.button("β New Chat"):
    reset_chat()
    st.rerun()
st.sidebar.header("My Conversations")
for thread_id in st.session_state['chat_threads'][::-1]:
    title = st.session_state['thread_titles'].get(thread_id, "Untitled Chat")
    if st.sidebar.button(title, key=f"thread_{thread_id}", use_container_width=True):
        st.session_state['thread_id'] = thread_id
        restored_history = []
        for msg in load_conversation(thread_id):
            # Rebuild only the turns the live chat itself records (user input and
            # assistant text). Previously every non-system, non-human message —
            # e.g. tool results — was shown as an 'assistant' bubble.
            if isinstance(msg, HumanMessage):
                role = 'user'
            elif isinstance(msg, AIMessage):
                role = 'assistant'
            else:
                continue
            # AI messages without text would otherwise render as empty bubbles.
            if not msg.content:
                continue
            restored_history.append({'role': role, 'content': msg.content})
        st.session_state['message_history'] = restored_history
        st.rerun()
st.title("IIITDMJ College Assistant")
st.caption("This bot uses a local vector store and LangGraph to answer your questions.")
# Re-render the stored conversation on every script run.
for entry in st.session_state['message_history']:
    speaker = entry['role']
    with st.chat_message(speaker):
        if speaker != 'assistant':
            st.markdown(entry['content'])
        else:
            # NOTE(review): model text is rendered with unsafe_allow_html, so any HTML it
            # contains is passed through; markdown on the same line as the opening <div>
            # is likely treated as raw HTML rather than markdown — confirm rendering.
            st.markdown(f"<div style='font-size: 15px;'>{entry['content']}</div>", unsafe_allow_html=True)
user_input = st.chat_input("Ask about IIITDMJ...")
if user_input:
    CONFIG = {'configurable': {'thread_id': st.session_state['thread_id']}}
    # Graph nodes whose LLM output is the user-facing answer; output from any
    # other node is neither streamed nor used as the fallback answer.
    ANSWER_NODES = ("generate_answer", "generate_synthesized_answer", "handle_chat")
    st.session_state['message_history'].append({'role': 'user', 'content': user_input})
    with st.chat_message('user'):
        st.markdown(user_input)
    with st.chat_message('assistant'):
        placeholder = st.empty()
        ai_message_content = ""
        try:
            print(f"\n--- Streaming response for Thread ID: {st.session_state['thread_id']} ---")

            async def stream_agent_events(stream_placeholder):
                """Run the graph on user_input, live-rendering answer tokens.

                Returns (streamed_text, final_node_output, final_node_name): the
                concatenated streamed text, plus the dict output and name of the
                last answer node that finished (used when nothing was streamed).
                """
                streamed_text = ""
                final_node_output = None
                final_node_name = ""
                async for event in chatbot.astream_events(
                    {'messages': [HumanMessage(content=user_input)]},
                    config=CONFIG,
                    version="v1",
                ):
                    kind = event["event"]
                    name = event["name"]
                    if kind == "on_chat_model_stream":
                        # For chat-model events `name` is the model run's name, not
                        # the graph node's. Per the LangGraph streaming docs, the node
                        # that made the call is in metadata["langgraph_node"], so
                        # accept a match on either.
                        node = (event.get("metadata") or {}).get("langgraph_node")
                        if name in ANSWER_NODES or node in ANSWER_NODES:
                            chunk_content = event["data"]["chunk"].content
                            # Only plain-text chunks are appended. Other shapes (e.g. a
                            # list of content blocks) would raise TypeError on +=, so
                            # they are left to the final-output fallback below.
                            if isinstance(chunk_content, str) and chunk_content:
                                streamed_text += chunk_content
                                # NOTE(review): the trailing 'β' looks like a typing-cursor
                                # glyph that may have been mis-encoded — confirm.
                                stream_placeholder.markdown(f"<div style='font-size: 15px;'>{streamed_text}β</div>", unsafe_allow_html=True)
                    elif kind == "on_chain_end" and name in ANSWER_NODES:
                        output = event.get("data", {}).get("output")
                        if isinstance(output, dict):
                            final_node_output = output
                            final_node_name = name
                            print(f"--- Captured final output from node: {name} ---")
                return streamed_text, final_node_output, final_node_name

            streamed_content, final_output, final_name = asyncio.run(stream_agent_events(placeholder))
            # NOTE(review): model output is rendered with unsafe_allow_html, so any HTML it
            # contains is passed through to the page — consider plain markdown instead.
            if not streamed_content and final_output:
                print(f"--- Using fallback: No stream content captured. Using final output from {final_name}. ---")
                if "messages" in final_output and final_output["messages"]:
                    ai_message_content = final_output["messages"][-1].content
                    placeholder.markdown(f"<div style='font-size: 15px;'>{ai_message_content}</div>", unsafe_allow_html=True)
                else:
                    print(f"--- Fallback failed: Final output from {final_name} had unexpected format: {final_output} ---")
                    ai_message_content = "Sorry, I couldn't generate a response (fallback error)."
                    placeholder.markdown(ai_message_content)
            elif streamed_content:
                # Re-render once without the cursor glyph.
                ai_message_content = streamed_content
                placeholder.markdown(f"<div style='font-size: 15px;'>{ai_message_content}</div>", unsafe_allow_html=True)
            else:
                print("--- Fallback failed: No stream content and no final output captured. ---")
                ai_message_content = "Sorry, I couldn't generate a response (capture error)."
                placeholder.markdown(ai_message_content)
        except Exception as e:
            st.error(f"An error occurred: {e}")
            print(f"ERROR DURING STREAM/FALLBACK: {e}")
            ai_message_content = "Sorry, I encountered an error during execution."
            placeholder.markdown(ai_message_content)
    if not ai_message_content:
        ai_message_content = "Sorry, I couldn't generate a response."
    st.session_state['message_history'].append({'role': 'assistant', 'content': ai_message_content})
    current_id = st.session_state['thread_id']
    current_title = st.session_state['thread_titles'].get(current_id, "New Chat")
    # Auto-title only threads still carrying a placeholder name, on their first exchange.
    if current_title.startswith("New Chat") and len(st.session_state['message_history']) <= 2:
        summarized_title = generate_title(user_input)
        st.session_state['thread_titles'][current_id] = summarized_title
    st.rerun()