devchavda11 commited on
Commit
e4d90af
·
verified ·
1 Parent(s): aa57f37

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +89 -100
src/streamlit_app.py CHANGED
@@ -10,113 +10,103 @@ st.title("College Chatbot")
10
  TEMP_DIR = "/tmp"
11
  os.makedirs(TEMP_DIR, exist_ok=True)
12
 
13
-
14
  # --- Helpers ---
15
- def set_title(messages):
16
- if messages:
17
- title = "New Chat"
18
- st.session_state.chat_dict[st.session_state.current_chat_id] = title
19
-
20
 
21
  def set_config():
 
22
  return {"configurable": {"thread_id": st.session_state.current_chat_id}}
23
 
24
-
25
  def load_session_state():
 
26
  if "chats" not in st.session_state:
27
  st.session_state.chats = get_all_chat_ids()
28
  if "current_chat_id" not in st.session_state:
29
- if len(st.session_state.chats) > 0:
30
  st.session_state.current_chat_id = st.session_state.chats[-1]
31
  else:
 
32
  new_id = str(uuid.uuid4())
33
  st.session_state.chats.append(new_id)
34
  st.session_state.current_chat_id = new_id
 
 
 
 
 
35
  if "chat_dict" not in st.session_state:
36
- st.session_state.chat_dict = {}
37
-
38
 
39
  def render_sidebar():
 
40
  with st.sidebar:
41
  st.title("Chats")
42
  if st.button("➕ New Chat"):
43
  new_id = str(uuid.uuid4())
44
  st.session_state.chats.append(new_id)
45
  st.session_state.current_chat_id = new_id
46
- config = {"configurable": {"thread_id": new_id}}
47
- workflow.update_state(config, {"messages": [system]})
48
- st.session_state.chat_dict[new_id] = "New Chat"
 
 
 
 
49
 
 
50
  for chat_id in st.session_state.chats:
51
- if st.button(st.session_state.chat_dict.get(chat_id, "New Chat"), key=chat_id):
52
  st.session_state.current_chat_id = chat_id
 
53
 
54
-
55
- def create_download_link(file_path: str, label: str = None) -> str:
56
- """Generate HTML download link for a file."""
57
- if not os.path.exists(file_path):
58
- return ""
59
- try:
60
- with open(file_path, "rb") as f:
61
- data = f.read()
62
- b64 = base64.b64encode(data).decode()
63
- label = label or f"📥 Download {os.path.basename(file_path)}"
64
- href = f'<a href="data:file/octet-stream;base64,{b64}" download="{os.path.basename(file_path)}">{label}</a>'
65
- return href
66
- except Exception as e:
67
- return f"Error creating download link: {e}"
68
-
69
 
70
  def show_file(file_path: str):
71
- """Show file inline + download link."""
72
- if not os.path.exists(file_path):
73
- return
74
  ext = os.path.splitext(file_path)[1].lower()
75
  if ext in [".png", ".jpg", ".jpeg"]:
76
- st.image(file_path, caption=os.path.basename(file_path))
77
- elif ext in [".txt", ".py", ".java", ".cpp", ".md"]:
78
- with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
79
- st.code(f.read(), language=ext.lstrip("."))
 
 
 
80
  st.markdown(create_download_link(file_path), unsafe_allow_html=True)
81
 
82
 
83
  def render_tool_message(tool_message: ToolMessage):
84
- """Render tool execution based on tool name instead of message content."""
85
- file_related_keywords = ["read", "write", "file", "save", "export", "create"]
86
  with st.chat_message("assistant"):
87
- tool_name = getattr(tool_message, "name", "").lower()
88
- st.info(f"🧰 Tool used: {tool_name or 'Unknown Tool'}")
89
-
90
- # Check if this is a file-related tool
91
- if any(k in tool_name for k in file_related_keywords):
92
- # Find all files in TEMP_DIR (freshly modified ones)
93
- created_files = sorted(
94
- [os.path.join(TEMP_DIR, f) for f in os.listdir(TEMP_DIR)],
95
- key=lambda x: os.path.getmtime(x),
96
- reverse=True,
97
- )
98
- if created_files:
99
- st.success("📄 File(s) created by tool:")
100
- for file_path in created_files[:3]: # show up to 3 recent
101
- show_file(file_path)
102
- else:
103
- st.warning("No new file detected in /tmp.")
104
- else:
105
- # Non-file tools
106
- if isinstance(tool_message.content, str):
107
- st.write(tool_message.content)
108
- elif isinstance(tool_message.content, dict):
109
- st.json(tool_message.content)
110
 
111
 
112
- def loadchats():
 
113
  if "current_chat_id" not in st.session_state:
114
- return []
115
- config = {"configurable": {"thread_id": st.session_state.current_chat_id}}
116
- state = workflow.get_state(config)
117
  messages = state.values.get("messages", [])
 
118
  for message in messages:
119
- if isinstance(message, HumanMessage):
 
 
 
120
  with st.chat_message("human"):
121
  st.write(message.content)
122
  elif isinstance(message, AIMessage):
@@ -124,42 +114,41 @@ def loadchats():
124
  st.write(message.content)
125
  elif isinstance(message, ToolMessage):
126
  render_tool_message(message)
127
- return messages
128
-
129
 
130
  # --- Main Chat Flow ---
131
  load_session_state()
132
  render_sidebar()
133
 
134
- if "current_chat_id" in st.session_state:
135
- loadchats()
136
- user_input = st.chat_input("Your message:")
137
- if user_input:
138
- with st.chat_message("human"):
139
- st.write(user_input)
140
-
141
- with st.chat_message("assistant"):
142
- # ... (inside with st.chat_message("assistant"))
143
- response_placeholder = st.empty()
144
- full_response = ""
145
-
146
- # Use stream_mode='values' for a cleaner loop
147
- for chunk in workflow.stream(
148
- {"messages": [HumanMessage(user_input)]}, # System message should already be in state
149
- config=set_config()):
150
- # The output of a streaming node is the key of the node itself
151
- if "llmresponse" in chunk:
152
- ai_message = chunk["llmresponse"]["messages"][-1]
153
- full_response += ai_message.content
154
- response_placeholder.markdown(full_response + "▌") # Add a cursor effect
155
-
156
- # Handle tool calls if they appear in the stream
157
- elif "tool_node" in chunk:
158
- # Since we are streaming, render the tool message right away
159
- # And clear the response placeholder for the next text chunk
160
- response_placeholder.empty()
161
- tool_messages = chunk["tool_node"]["messages"]
162
- for tool_message in tool_messages:
163
- render_tool_message(tool_message)
164
-
165
- response_placeholder.markdown(full_response) # Final update without cursor
 
 
10
  TEMP_DIR = "/tmp"
11
  os.makedirs(TEMP_DIR, exist_ok=True)
12
 
 
13
  # --- Helpers ---
 
 
 
 
 
14
 
15
  def set_config():
16
+ """Returns the config for the current chat thread."""
17
  return {"configurable": {"thread_id": st.session_state.current_chat_id}}
18
 
 
19
  def load_session_state():
20
+ """Initializes the session state for chats."""
21
  if "chats" not in st.session_state:
22
  st.session_state.chats = get_all_chat_ids()
23
  if "current_chat_id" not in st.session_state:
24
+ if st.session_state.chats:
25
  st.session_state.current_chat_id = st.session_state.chats[-1]
26
  else:
27
+ # Create a new chat if none exist
28
  new_id = str(uuid.uuid4())
29
  st.session_state.chats.append(new_id)
30
  st.session_state.current_chat_id = new_id
31
+ # Initialize the new chat with the system message
32
+ workflow.update_state(
33
+ {"configurable": {"thread_id": new_id}},
34
+ {"messages": [system]}
35
+ )
36
  if "chat_dict" not in st.session_state:
37
+ # Simple naming for chats for now
38
+ st.session_state.chat_dict = {chat_id: f"Chat {i+1}" for i, chat_id in enumerate(st.session_state.chats)}
39
 
40
  def render_sidebar():
41
+ """Displays the chat history and a new chat button in the sidebar."""
42
  with st.sidebar:
43
  st.title("Chats")
44
  if st.button("➕ New Chat"):
45
  new_id = str(uuid.uuid4())
46
  st.session_state.chats.append(new_id)
47
  st.session_state.current_chat_id = new_id
48
+ st.session_state.chat_dict[new_id] = f"Chat {len(st.session_state.chats)}"
49
+ # Initialize the new chat with the system message
50
+ workflow.update_state(
51
+ {"configurable": {"thread_id": new_id}},
52
+ {"messages": [system]}
53
+ )
54
+ st.rerun() # Rerun to switch to the new chat
55
 
56
+ st.divider()
57
  for chat_id in st.session_state.chats:
58
+ if st.button(st.session_state.chat_dict.get(chat_id, "New Chat"), key=chat_id, use_container_width=True):
59
  st.session_state.current_chat_id = chat_id
60
+ st.rerun()
61
 
62
+ def create_download_link(file_path: str) -> str:
63
+ """Generates an HTML download link for a file."""
64
+ if not os.path.exists(file_path): return ""
65
+ with open(file_path, "rb") as f:
66
+ data = f.read()
67
+ b64 = base64.b64encode(data).decode()
68
+ return f'<a href="data:file/octet-stream;base64,{b64}" download="{os.path.basename(file_path)}">📥 Download {os.path.basename(file_path)}</a>'
 
 
 
 
 
 
 
 
69
 
70
  def show_file(file_path: str):
71
+ """Displays a file inline with a download link."""
72
+ if not os.path.exists(file_path): return
 
73
  ext = os.path.splitext(file_path)[1].lower()
74
  if ext in [".png", ".jpg", ".jpeg"]:
75
+ st.image(file_path)
76
+ else:
77
+ try:
78
+ with open(file_path, "r", encoding="utf-8") as f:
79
+ st.code(f.read(), language=ext.lstrip("."))
80
+ except:
81
+ st.warning(f"Could not display file: {os.path.basename(file_path)}")
82
  st.markdown(create_download_link(file_path), unsafe_allow_html=True)
83
 
84
 
85
  def render_tool_message(tool_message: ToolMessage):
86
+ """Renders the output of a tool call."""
 
87
  with st.chat_message("assistant"):
88
+ st.info(f"🧰 Tool Used: `{tool_message.name}`")
89
+ # For file-creating tools, we can find and show the file
90
+ if tool_message.name in ["write_file", "plot_graph"] and "saved at" in str(tool_message.content):
91
+ file_path = str(tool_message.content).split("saved at")[-1].strip()
92
+ show_file(file_path)
93
+ else: # For other tools, just show the content
94
+ st.markdown(f"```\n{tool_message.content}\n```")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
+ def load_and_display_chats():
98
+ """Loads and displays the messages for the current chat."""
99
  if "current_chat_id" not in st.session_state:
100
+ return
101
+
102
+ state = workflow.get_state(set_config())
103
  messages = state.values.get("messages", [])
104
+
105
  for message in messages:
106
+ # Skip the system message in the UI
107
+ if isinstance(message, SystemMessage):
108
+ continue
109
+ elif isinstance(message, HumanMessage):
110
  with st.chat_message("human"):
111
  st.write(message.content)
112
  elif isinstance(message, AIMessage):
 
114
  st.write(message.content)
115
  elif isinstance(message, ToolMessage):
116
  render_tool_message(message)
 
 
117
 
118
  # --- Main Chat Flow ---
119
  load_session_state()
120
  render_sidebar()
121
 
122
+ load_and_display_chats()
123
+
124
+ if user_input := st.chat_input("Your message:"):
125
+ with st.chat_message("human"):
126
+ st.write(user_input)
127
+
128
+ with st.chat_message("assistant"):
129
+ response_placeholder = st.empty()
130
+ full_response = ""
131
+
132
+ # The input to stream is just the new message.
133
+ # LangGraph loads the history from the checkpointer.
134
+ input_for_graph = {"messages": [HumanMessage(content=user_input)]}
135
+
136
+ # Use stream_mode='values' to get outputs from each node as they run
137
+ for chunk in workflow.stream(input_for_graph, config=set_config()):
138
+ # Check for output from the LLM node
139
+ if "llmresponse" in chunk:
140
+ ai_message_chunk = chunk["llmresponse"]["messages"][-1]
141
+ full_response += ai_message_chunk.content
142
+ response_placeholder.markdown(full_response + "▌") # Add a cursor effect
143
+
144
+ # Check for output from the tool node
145
+ elif "tool_node" in chunk:
146
+ # A tool was called. We can render its output.
147
+ # First, clear the placeholder for the new text that will come
148
+ response_placeholder.empty()
149
+ tool_messages = chunk["tool_node"]["messages"]
150
+ for tool_message in tool_messages:
151
+ render_tool_message(tool_message)
152
+
153
+ # Final update to remove the cursor
154
+ response_placeholder.markdown(full_response)