AbenzaFran committed
Commit b3ee356 · verified · 1 Parent(s): e3a58bf

Now with streaming

Files changed (1):
  1. app.py +313 -107
app.py CHANGED
@@ -1,116 +1,322 @@
 import os
 import re
+import io
+import time
+import json
+import queue
+import logging
+from typing import Any, Generator, Optional, List, Dict, Tuple
+from dataclasses import dataclass
+
 import streamlit as st
-import openai
 from dotenv import load_dotenv
-from langchain.agents.openai_assistant import OpenAIAssistantRunnable
-
-# Load environment variables
-load_dotenv()
-api_key = os.getenv("OPENAI_API_KEY")
-extractor_agent = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
-
-# Create the assistant
-extractor_llm = OpenAIAssistantRunnable(
-    assistant_id=extractor_agent,
-    api_key=api_key,
-    as_agent=True
-)
-
-def remove_citation(text: str) -> str:
-    pattern = r"【\d+†\w+】"
-    return re.sub(pattern, "📚", text)
-
-# Initialize session state
-if "messages" not in st.session_state:
-    st.session_state["messages"] = []
-if "thread_id" not in st.session_state:
-    st.session_state["thread_id"] = None
-# A flag to indicate if a request is in progress
-if "is_in_request" not in st.session_state:
-    st.session_state["is_in_request"] = False
-
-st.title("inteliventa Argumentador Inversionistas")
-
-def predict(user_input: str) -> str:
-    """
-    This function calls our OpenAIAssistantRunnable to get a response.
-    If st.session_state["thread_id"] is None, we start a new thread.
-    Otherwise, we continue the existing thread.
-
-    If a concurrency error occurs ("Can't add messages to thread..."), we reset
-    the thread_id and try again once on a fresh thread.
-    """
-    try:
-        if st.session_state["thread_id"] is None:
-            # Start a new thread
-            response = extractor_llm.invoke({"content": user_input})
-            st.session_state["thread_id"] = response.thread_id
-        else:
-            # Continue existing thread
-            response = extractor_llm.invoke(
-                {"content": user_input, "thread_id": st.session_state["thread_id"]}
+from PIL import Image
+import openai
+from langsmith.wrappers import wrap_openai
+from langsmith import traceable
+
+# ------------------------
+# Configuration and Types
+# ------------------------
+@dataclass
+class AppConfig:
+    """Application configuration settings."""
+    pass
+
+@dataclass
+class Message:
+    """Chat message structure."""
+    role: str
+    content: str
+
+class StreamingError(Exception):
+    """Custom exception for streaming-related errors."""
+    pass
+
+# ------------------------
+# Logging Configuration
+# ------------------------
+def setup_logging() -> logging.Logger:
+    """Configure and return the application logger."""
+    logging.basicConfig(
+        format="[%(asctime)s] %(levelname)+8s: %(message)s",
+        level=logging.INFO,
+    )
+    return logging.getLogger(__name__)
+
+logger = setup_logging()
+
+# ------------------------
+# Environment Setup
+# ------------------------
+class EnvironmentManager:
+    """Manages environment variables and configuration."""
+
+    @staticmethod
+    def load_environment() -> Tuple[str, str]:
+        """Load and validate environment variables."""
+        load_dotenv(override=True)
+        api_key = os.getenv("OPENAI_API_KEY")
+        assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
+
+        if not api_key or not assistant_id:
+            raise RuntimeError(
+                "Missing required environment variables. Please set "
+                "OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A"
             )
+
+        return api_key, assistant_id

-        output = response.return_values["output"]
-        return remove_citation(output)
+# ------------------------
+# State Management
+# ------------------------
+class StateManager:
+    """Manages Streamlit session state."""
+
+    @staticmethod
+    def initialize_state() -> None:
+        """Initialize session state variables."""
+        if "messages" not in st.session_state:
+            st.session_state.messages = []
+        if "thread" not in st.session_state:
+            st.session_state.thread = None
+        if "tool_requests" not in st.session_state:
+            st.session_state.tool_requests = queue.Queue()
+        if "run_stream" not in st.session_state:
+            st.session_state.run_stream = None

-    except openai.error.BadRequestError as e:
-        # If we get the specific concurrency error, reset thread and try once more
-        if "while a run" in str(e):
-            st.session_state["thread_id"] = None
-            # Now create a new thread for the same user input
-            try:
-                response = extractor_llm.invoke({"content": user_input})
-                st.session_state["thread_id"] = response.thread_id
-                output = response.return_values["output"]
-                return remove_citation(output)
-            except Exception as e2:
-                st.error(f"Error after resetting thread: {e2}")
-                return ""
+    @staticmethod
+    def add_message(role: str, content: str) -> None:
+        """Add a message to the conversation history."""
+        st.session_state.messages.append(Message(role=role, content=content))
+
+# ------------------------
+# Text Processing
+# ------------------------
+class TextProcessor:
+    """Handles text processing and formatting."""
+
+    @staticmethod
+    def remove_citations(text: str) -> str:
+        """Remove citation markers from text."""
+        pattern = r"【\d+†\w+】"
+        return re.sub(pattern, "📚", text)
+
+# ------------------------
+# Streaming Handler
+# ------------------------
+class StreamHandler:
+    """Handles streaming of assistant responses."""
+
+    def __init__(self, client: Any):
+        self.client = client
+        self.text_processor = TextProcessor()
+        self.complete_response = []
+
+    def stream_data(self) -> Generator[Any, None, None]:
+        """Stream data from the assistant run."""
+        st.toast("Thinking...", icon="🤔")
+        content_produced = False
+        self.complete_response = []  # Reset for new stream
+
+        try:
+            for event in st.session_state.run_stream:
+                match event.event:
+                    case "thread.message.delta":
+                        yield from self._handle_message_delta(event, content_produced)
+                    case "thread.run.requires_action":
+                        yield from self._handle_action_request(event, content_produced)
+                    case "thread.run.failed":
+                        logger.error(f"Run failed: {event}")
+                        raise StreamingError(f"Assistant run failed: {event}")
+
+            st.toast("Completed", icon="✅")
+            # Return the complete response for storage
+            return "".join(self.complete_response)
+        except Exception as e:
+            logger.error(f"Streaming error: {e}")
+            st.error(f"An error occurred while streaming: {str(e)}")
+            raise
+
+    def _handle_message_delta(self, event: Any, content_produced: bool) -> Generator[Any, None, None]:
+        """Handle message delta events."""
+        content = event.data.delta.content[0]
+        match content.type:
+            case "text":
+                processed_text = self.text_processor.remove_citations(content.text.value)
+                self.complete_response.append(processed_text)  # Store the chunk
+                yield processed_text
+            case "image_file":
+                image_content = io.BytesIO(self.client.files.content(content.image_file.file_id).read())
+                yield Image.open(image_content)
+
+    def _handle_action_request(self, event: Any, content_produced: bool) -> Generator[str, None, None]:
+        """Handle action request events."""
+        logger.info(f"[Tool Request] {event}")
+        st.session_state.tool_requests.put(event)
+        if not content_produced:
+            yield "[Processing function call...]"
+
+# ------------------------
+# Tool Request Handler
+# ------------------------
+class ToolRequestHandler:
+    """Handles tool requests from the assistant."""
+
+    @staticmethod
+    def handle_request(event: Any) -> Tuple[List[Dict[str, str]], str, str]:
+        """Process tool requests and return outputs."""
+        st.toast("Processing function call...", icon="⚙️")
+        tool_outputs = []
+        data = event.data
+
+        for tool_call in data.required_action.submit_tool_outputs.tool_calls:
+            output = ToolRequestHandler._process_tool_call(tool_call)
+            tool_outputs.append(output)
+
+        return tool_outputs, data.thread_id, data.id
+
+    @staticmethod
+    def _process_tool_call(tool_call: Any) -> Dict[str, str]:
+        """Process individual tool calls."""
+        function_args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
+
+        match tool_call.function.name:
+            case "hello_world":
+                name = function_args.get("name", "anonymous")
+                output_val = f"Hello, {name}! This was from a local function."
+            case _:
+                output_val = json.dumps({"status": "error", "message": "Unknown function request."})
+
+        return {"tool_call_id": tool_call.id, "output": output_val}
+
+# ------------------------
+# Assistant Manager
+# ------------------------
+class AssistantManager:
+    """Manages interactions with the OpenAI Assistant."""
+
+    def __init__(self, client: Any, assistant_id: str):
+        self.client = client
+        self.assistant_id = assistant_id
+        self.stream_handler = StreamHandler(client)
+        self.tool_handler = ToolRequestHandler()
+
+    @traceable
+    def generate_reply(self, user_input: str) -> str:
+        """Generate and stream assistant's reply."""
+        # Ensure thread exists
+        if not st.session_state.thread:
+            st.session_state.thread = self.client.beta.threads.create()
+
+        # Add user message
+        self.client.beta.threads.messages.create(
+            thread_id=st.session_state.thread.id,
+            role="user",
+            content=user_input
+        )
+
+        complete_response = ""
+
+        # Stream initial response
+        with self.client.beta.threads.runs.stream(
+            thread_id=st.session_state.thread.id,
+            assistant_id=self.assistant_id,
+        ) as run_stream:
+            complete_response = self._display_stream(run_stream)
+
+        # Handle any tool requests
+        self._process_tool_requests()
+
+        return complete_response
+
+    def _display_stream(self, run_stream: Any, create_context: bool = True) -> str:
+        """Display streaming content."""
+        st.session_state.run_stream = run_stream
+        if create_context:
+            with st.chat_message("assistant"):
+                return st.write_stream(self.stream_handler.stream_data)
         else:
-            # Some other 400 error
-            st.error(str(e))
-            return ""
+            return st.write_stream(self.stream_handler.stream_data)
+
+    def _process_tool_requests(self) -> None:
+        """Process any pending tool requests."""
+        while not st.session_state.tool_requests.empty():
+            event = st.session_state.tool_requests.get()
+            tool_outputs, thread_id, run_id = self.tool_handler.handle_request(event)
+
+            with self.client.beta.threads.runs.submit_tool_outputs_stream(
+                thread_id=thread_id,
+                run_id=run_id,
+                tool_outputs=tool_outputs
+            ) as next_stream:
+                self._display_stream(next_stream, create_context=False)
+
+# ------------------------
+# Main Application
+# ------------------------
+class ChatApplication:
+    """Main chat application class."""
+
+    def __init__(self):
+        self.config = AppConfig()
+        api_key, assistant_id = EnvironmentManager.load_environment()
+
+        # Initialize OpenAI client
+        openai_client = openai.Client(api_key=api_key)
+        self.client = wrap_openai(openai_client)
+
+        # Initialize components
+        self.state_manager = StateManager()
+        self.assistant_manager = AssistantManager(self.client, assistant_id)
+
+    def setup_page(self) -> None:
+        """Configure the Streamlit page."""
+        st.set_page_config(
+            page_title=self.config.page_title,
+            page_icon=self.config.page_icon,
+            layout=self.config.layout
+        )
+        st.title(self.config.page_title)
+
+    def display_chat_history(self) -> None:
+        """Display the chat history."""
+        for msg in st.session_state.messages:
+            with st.chat_message(msg.role):
+                st.write(msg.content)
+
+    def run(self) -> None:
+        """Run the chat application."""
+        self.setup_page()
+        self.state_manager.initialize_state()
+        self.display_chat_history()
+
+        user_input = st.chat_input("Type your message here...")
+        if user_input:
+            # Display and store user message
+            with st.chat_message("user"):
+                st.write(user_input)
+            self.state_manager.add_message("user", user_input)
+
+            # Generate and display assistant reply
+            try:
+                complete_response = self.assistant_manager.generate_reply(user_input)
+                self.state_manager.add_message(
+                    "assistant",
+                    complete_response
+                )
+            except Exception as e:
+                st.error(f"Error generating response: {str(e)}")
+                logger.exception("Error in assistant reply generation")
+
+def main():
+    """Application entry point."""
+    try:
+        app = ChatApplication()
+        app.run()
     except Exception as e:
-        st.error(str(e))
-        return ""
-
-# Display any existing messages
-for msg in st.session_state["messages"]:
-    if msg["role"] == "user":
-        with st.chat_message("user"):
-            st.write(msg["content"])
-    else:
-        with st.chat_message("assistant"):
-            st.write(msg["content"])
-
-# Chat input at the bottom of the page
-user_input = st.chat_input("Type your message here...")
-
-# Process the user input only if:
-# 1) There is some text, and
-# 2) We are not already handling a request (is_in_request == False)
-if user_input and not st.session_state["is_in_request"]:
-    # Lock to prevent duplicate requests
-    st.session_state["is_in_request"] = True
-
-    # Add the user message to session state
-    st.session_state["messages"].append({"role": "user", "content": user_input})
-
-    # Display the user's message
-    with st.chat_message("user"):
-        st.write(user_input)
-
-    # Get assistant response
-    response_text = predict(user_input)
-
-    # Add assistant response to session state
-    st.session_state["messages"].append({"role": "assistant", "content": response_text})
-
-    # Display assistant response
-    with st.chat_message("assistant"):
-        st.write(response_text)
-
-    # Release the lock
-    st.session_state["is_in_request"] = False
+        st.error(f"Application error: {str(e)}")
+        logger.exception("Fatal application error")
+
+if __name__ == "__main__":
+    main()
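
Review note: as committed, AppConfig declares no fields, yet setup_page() reads self.config.page_title, self.config.page_icon, and self.config.layout, so the app will raise AttributeError at startup. A minimal sketch of the missing fields (the field names are dictated by setup_page(); the title is taken from the st.title call in the previous revision, while the icon and layout defaults are assumptions):

@dataclass
class AppConfig:
    """Application configuration settings."""
    # Field names required by setup_page(); default values are assumed.
    page_title: str = "inteliventa Argumentador Inversionistas"  # title from the previous revision
    page_icon: str = "💬"  # assumed icon
    layout: str = "centered"  # assumed; Streamlit accepts "centered" or "wide"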
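
The hello_world branch in _process_tool_call only runs if the Assistant referenced by ASSISTANT_ID_SOLUTION_SPECIFIER_A actually exposes a matching function tool; that registration lives in the Assistant's configuration on the OpenAI side, not in this file. A hypothetical one-time setup could look like this (the schema is an assumption — only the function name and its optional "name" argument appear in app.py):

import os
import openai

client = openai.Client(api_key=os.getenv("OPENAI_API_KEY"))

# Register the function tool the new code expects. Note that update()
# replaces the Assistant's entire tools list, so include any other tools
# (e.g. file_search) the Assistant already uses.
client.beta.assistants.update(
    os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A"),
    tools=[
        {
            "type": "function",
            "function": {
                "name": "hello_world",  # must match the case arm in _process_tool_call
                "description": "Greets the caller from a local function.",  # assumed wording
                "parameters": {
                    "type": "object",
                    "properties": {"name": {"type": "string"}},
                    "required": [],
                },
            },
        },
    ],
)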
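
For reference, the citation scrubbing (remove_citation before, TextProcessor.remove_citations after) targets the 【n†source】-style markers that Assistants API responses embed for file citations, replacing each with 📚. A quick check using the exact pattern from app.py:

import re

pattern = r"【\d+†\w+】"
print(re.sub(pattern, "📚", "Revenue grew 12%【4†source】 last year."))
# -> Revenue grew 12%📚 last year.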