devchavda11 commited on
Commit
7791c9a
·
verified ·
1 Parent(s): ad95b8d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +92 -90
src/streamlit_app.py CHANGED
@@ -1,5 +1,5 @@
1
  import streamlit as st
2
- from chat_langraph import system, workflow, HumanMessage, AIMessage, get_all_chat_ids, ToolMessage , SystemMessage
3
  import uuid
4
  import os
5
  import base64
@@ -10,103 +10,114 @@ st.title("College Chatbot")
10
  TEMP_DIR = "/tmp"
11
  os.makedirs(TEMP_DIR, exist_ok=True)
12
 
 
13
  # --- Helpers ---
14
 
 
 
 
 
 
 
15
  def set_config():
16
- """Returns the config for the current chat thread."""
17
  return {"configurable": {"thread_id": st.session_state.current_chat_id}}
18
 
 
19
  def load_session_state():
20
- """Initializes the session state for chats."""
21
  if "chats" not in st.session_state:
22
  st.session_state.chats = get_all_chat_ids()
23
  if "current_chat_id" not in st.session_state:
24
- if st.session_state.chats:
25
  st.session_state.current_chat_id = st.session_state.chats[-1]
26
  else:
27
- # Create a new chat if none exist
28
  new_id = str(uuid.uuid4())
29
  st.session_state.chats.append(new_id)
30
  st.session_state.current_chat_id = new_id
31
- # Initialize the new chat with the system message
32
- workflow.update_state(
33
- {"configurable": {"thread_id": new_id}},
34
- {"messages": [system]}
35
- )
36
  if "chat_dict" not in st.session_state:
37
- # Simple naming for chats for now
38
- st.session_state.chat_dict = {chat_id: f"Chat {i+1}" for i, chat_id in enumerate(st.session_state.chats)}
39
 
40
  def render_sidebar():
41
- """Displays the chat history and a new chat button in the sidebar."""
42
  with st.sidebar:
43
  st.title("Chats")
44
  if st.button("➕ New Chat"):
45
  new_id = str(uuid.uuid4())
46
  st.session_state.chats.append(new_id)
47
  st.session_state.current_chat_id = new_id
48
- st.session_state.chat_dict[new_id] = f"Chat {len(st.session_state.chats)}"
49
- # Initialize the new chat with the system message
50
- workflow.update_state(
51
- {"configurable": {"thread_id": new_id}},
52
- {"messages": [system]}
53
- )
54
- st.rerun() # Rerun to switch to the new chat
55
 
56
- st.divider()
57
  for chat_id in st.session_state.chats:
58
- if st.button(st.session_state.chat_dict.get(chat_id, "New Chat"), key=chat_id, use_container_width=True):
59
  st.session_state.current_chat_id = chat_id
60
- st.rerun()
61
 
62
- def create_download_link(file_path: str) -> str:
63
- """Generates an HTML download link for a file."""
64
- if not os.path.exists(file_path): return ""
65
- with open(file_path, "rb") as f:
66
- data = f.read()
67
- b64 = base64.b64encode(data).decode()
68
- return f'<a href="data:file/octet-stream;base64,{b64}" download="{os.path.basename(file_path)}">📥 Download {os.path.basename(file_path)}</a>'
 
 
 
 
 
 
 
 
69
 
70
  def show_file(file_path: str):
71
- """Displays a file inline with a download link."""
72
- if not os.path.exists(file_path): return
 
73
  ext = os.path.splitext(file_path)[1].lower()
74
  if ext in [".png", ".jpg", ".jpeg"]:
75
- st.image(file_path)
76
- else:
77
- try:
78
- with open(file_path, "r", encoding="utf-8") as f:
79
- st.code(f.read(), language=ext.lstrip("."))
80
- except:
81
- st.warning(f"Could not display file: {os.path.basename(file_path)}")
82
  st.markdown(create_download_link(file_path), unsafe_allow_html=True)
83
 
84
 
85
  def render_tool_message(tool_message: ToolMessage):
86
- """Renders the output of a tool call."""
 
87
  with st.chat_message("assistant"):
88
- st.info(f"🧰 Tool Used: `{tool_message.name}`")
89
- # For file-creating tools, we can find and show the file
90
- if tool_message.name in ["write_file", "plot_graph"] and "saved at" in str(tool_message.content):
91
- file_path = str(tool_message.content).split("saved at")[-1].strip()
92
- show_file(file_path)
93
- else: # For other tools, just show the content
94
- st.markdown(f"```\n{tool_message.content}\n```")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
- def load_and_display_chats():
98
- """Loads and displays the messages for the current chat."""
99
  if "current_chat_id" not in st.session_state:
100
- return
101
-
102
- state = workflow.get_state(set_config())
103
  messages = state.values.get("messages", [])
104
-
105
  for message in messages:
106
- # Skip the system message in the UI
107
- if isinstance(message, SystemMessage):
108
- continue
109
- elif isinstance(message, HumanMessage):
110
  with st.chat_message("human"):
111
  st.write(message.content)
112
  elif isinstance(message, AIMessage):
@@ -114,41 +125,32 @@ def load_and_display_chats():
114
  st.write(message.content)
115
  elif isinstance(message, ToolMessage):
116
  render_tool_message(message)
 
 
117
 
118
  # --- Main Chat Flow ---
 
119
  load_session_state()
120
  render_sidebar()
121
 
122
- load_and_display_chats()
123
-
124
- if user_input := st.chat_input("Your message:"):
125
- with st.chat_message("human"):
126
- st.write(user_input)
127
-
128
- with st.chat_message("assistant"):
129
- response_placeholder = st.empty()
130
- full_response = ""
131
-
132
- # The input to stream is just the new message.
133
- # LangGraph loads the history from the checkpointer.
134
- input_for_graph = {"messages": [HumanMessage(content=user_input)]}
135
-
136
- # Use stream_mode='values' to get outputs from each node as they run
137
- for chunk in workflow.stream(input_for_graph, config=set_config() , stream_mode = "messages"):
138
- # Check for output from the LLM node
139
- if "llmresponse" in chunk:
140
- ai_message_chunk = chunk["llmresponse"]["messages"][-1]
141
- full_response += ai_message_chunk.content
142
- response_placeholder.markdown(full_response + "▌") # Add a cursor effect
143
-
144
- # Check for output from the tool node
145
- elif "tool_node" in chunk:
146
- # A tool was called. We can render its output.
147
- # First, clear the placeholder for the new text that will come
148
- response_placeholder.empty()
149
- tool_messages = chunk["tool_node"]["messages"]
150
- for tool_message in tool_messages:
151
- render_tool_message(tool_message)
152
-
153
- # Final update to remove the cursor
154
- response_placeholder.markdown(full_response)
 
1
  import streamlit as st
2
+ from chat_langraph import system, workflow, HumanMessage, AIMessage, get_all_chat_ids, ToolMessage
3
  import uuid
4
  import os
5
  import base64
 
10
  TEMP_DIR = "/tmp"
11
  os.makedirs(TEMP_DIR, exist_ok=True)
12
 
13
+
14
  # --- Helpers ---
15
 
16
+ def set_title(messages):
17
+ if messages:
18
+ title = "New Chat"
19
+ st.session_state.chat_dict[st.session_state.current_chat_id] = title
20
+
21
+
22
  def set_config():
 
23
  return {"configurable": {"thread_id": st.session_state.current_chat_id}}
24
 
25
+
26
  def load_session_state():
 
27
  if "chats" not in st.session_state:
28
  st.session_state.chats = get_all_chat_ids()
29
  if "current_chat_id" not in st.session_state:
30
+ if len(st.session_state.chats) > 0:
31
  st.session_state.current_chat_id = st.session_state.chats[-1]
32
  else:
 
33
  new_id = str(uuid.uuid4())
34
  st.session_state.chats.append(new_id)
35
  st.session_state.current_chat_id = new_id
 
 
 
 
 
36
  if "chat_dict" not in st.session_state:
37
+ st.session_state.chat_dict = {}
38
+
39
 
40
  def render_sidebar():
 
41
  with st.sidebar:
42
  st.title("Chats")
43
  if st.button("➕ New Chat"):
44
  new_id = str(uuid.uuid4())
45
  st.session_state.chats.append(new_id)
46
  st.session_state.current_chat_id = new_id
47
+ config = {"configurable": {"thread_id": new_id}}
48
+ workflow.update_state(config, {"messages": [system]})
49
+ st.session_state.chat_dict[new_id] = "New Chat"
 
 
 
 
50
 
 
51
  for chat_id in st.session_state.chats:
52
+ if st.button(st.session_state.chat_dict.get(chat_id, "New Chat"), key=chat_id):
53
  st.session_state.current_chat_id = chat_id
 
54
 
55
+
56
+ def create_download_link(file_path: str, label: str = None) -> str:
57
+ """Generate HTML download link for a file."""
58
+ if not os.path.exists(file_path):
59
+ return ""
60
+ try:
61
+ with open(file_path, "rb") as f:
62
+ data = f.read()
63
+ b64 = base64.b64encode(data).decode()
64
+ label = label or f"📥 Download {os.path.basename(file_path)}"
65
+ href = f'<a href="data:file/octet-stream;base64,{b64}" download="{os.path.basename(file_path)}">{label}</a>'
66
+ return href
67
+ except Exception as e:
68
+ return f"Error creating download link: {e}"
69
+
70
 
71
  def show_file(file_path: str):
72
+ """Show file inline + download link."""
73
+ if not os.path.exists(file_path):
74
+ return
75
  ext = os.path.splitext(file_path)[1].lower()
76
  if ext in [".png", ".jpg", ".jpeg"]:
77
+ st.image(file_path, caption=os.path.basename(file_path))
78
+ elif ext in [".txt", ".py", ".java", ".cpp", ".md"]:
79
+ with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
80
+ st.code(f.read(), language=ext.lstrip("."))
 
 
 
81
  st.markdown(create_download_link(file_path), unsafe_allow_html=True)
82
 
83
 
84
  def render_tool_message(tool_message: ToolMessage):
85
+ """Render tool execution based on tool name instead of message content."""
86
+ file_related_keywords = ["read", "write", "file", "save", "export", "create"]
87
  with st.chat_message("assistant"):
88
+ tool_name = getattr(tool_message, "name", "").lower()
89
+ st.info(f"🧰 Tool used: {tool_name or 'Unknown Tool'}")
90
+
91
+ # Check if this is a file-related tool
92
+ if any(k in tool_name for k in file_related_keywords):
93
+ # Find all files in TEMP_DIR (freshly modified ones)
94
+ created_files = sorted(
95
+ [os.path.join(TEMP_DIR, f) for f in os.listdir(TEMP_DIR)],
96
+ key=lambda x: os.path.getmtime(x),
97
+ reverse=True,
98
+ )
99
+ if created_files:
100
+ st.success("📄 File(s) created by tool:")
101
+ for file_path in created_files[:3]: # show up to 3 recent
102
+ show_file(file_path)
103
+ else:
104
+ st.warning("No new file detected in /tmp.")
105
+ else:
106
+ # Non-file tools
107
+ if isinstance(tool_message.content, str):
108
+ st.write(tool_message.content)
109
+ elif isinstance(tool_message.content, dict):
110
+ st.json(tool_message.content)
111
 
112
 
113
+ def loadchats():
 
114
  if "current_chat_id" not in st.session_state:
115
+ return []
116
+ config = {"configurable": {"thread_id": st.session_state.current_chat_id}}
117
+ state = workflow.get_state(config)
118
  messages = state.values.get("messages", [])
 
119
  for message in messages:
120
+ if isinstance(message, HumanMessage):
 
 
 
121
  with st.chat_message("human"):
122
  st.write(message.content)
123
  elif isinstance(message, AIMessage):
 
125
  st.write(message.content)
126
  elif isinstance(message, ToolMessage):
127
  render_tool_message(message)
128
+ return messages
129
+
130
 
131
  # --- Main Chat Flow ---
132
+
133
  load_session_state()
134
  render_sidebar()
135
 
136
+ if "current_chat_id" in st.session_state:
137
+ loadchats()
138
+ if user_input := st.chat_input("Your message:"):
139
+ with st.chat_message("human"):
140
+ st.write(user_input)
141
+
142
+ with st.chat_message("assistant"):
143
+ response_placeholder = st.empty()
144
+ full_response = ""
145
+
146
+ for message, metadata in workflow.stream(
147
+ {"messages": [system, HumanMessage(user_input)]},
148
+ config={"configurable": {"thread_id": st.session_state.current_chat_id}},
149
+ stream_mode="messages",
150
+ ):
151
+ if isinstance(message, AIMessage):
152
+ full_response += message.content or ""
153
+ elif isinstance(message, ToolMessage):
154
+ render_tool_message(message)
155
+
156
+ response_placeholder.markdown(full_response)