import os
import re
import io
import time
import json
import queue
import logging
from typing import Any, Generator, Optional, List, Dict, Tuple
from dataclasses import dataclass

import streamlit as st
from dotenv import load_dotenv
from PIL import Image
import openai
from langsmith.wrappers import wrap_openai
from langsmith import traceable

# ------------------------
# Configuration and Types
# ------------------------

@dataclass
class AppConfig:
    """Application configuration settings."""
    page_title: str = "Is it a Match?"
    page_icon: str = "👀"
    layout: str = "centered"


@dataclass
class Message:
    """Chat message structure."""
    role: str
    content: str


class StreamingError(Exception):
    """Custom exception for streaming-related errors."""
    pass

# ------------------------
# Logging Configuration
# ------------------------

def setup_logging() -> logging.Logger:
    """Configure and return the application logger."""
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)+8s: %(message)s",
        level=logging.INFO,
    )
    return logging.getLogger(__name__)


logger = setup_logging()

# ------------------------
# Environment Setup
# ------------------------

class EnvironmentManager:
    """Manages environment variables and configuration."""

    @staticmethod
    def load_environment() -> Tuple[str, str]:
        """Load and validate environment variables."""
        load_dotenv(override=True)
        api_key = os.getenv("OPENAI_API_KEY")
        assistant_id = os.getenv("ASSISTANT_ID_SOLUTION_SPECIFIER_A")
        if not api_key or not assistant_id:
            raise RuntimeError(
                "Missing required environment variables. Please set "
                "OPENAI_API_KEY and ASSISTANT_ID_SOLUTION_SPECIFIER_A"
            )
        return api_key, assistant_id
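
# load_environment() expects a .env file next to this script. A minimal
# example (values are placeholders, not real credentials):
#
#   OPENAI_API_KEY=sk-...
#   ASSISTANT_ID_SOLUTION_SPECIFIER_A=asst_...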

# ------------------------
# State Management
# ------------------------

class StateManager:
    """Manages Streamlit session state."""

    @staticmethod
    def initialize_state() -> None:
        """Initialize session state variables."""
        if "messages" not in st.session_state:
            st.session_state.messages = []
        if "thread" not in st.session_state:
            st.session_state.thread = None
        if "tool_requests" not in st.session_state:
            st.session_state.tool_requests = queue.Queue()
        if "run_stream" not in st.session_state:
            st.session_state.run_stream = None

    @staticmethod
    def add_message(role: str, content: str) -> None:
        """Add a message to the conversation history."""
        st.session_state.messages.append(Message(role=role, content=content))

# ------------------------
# Text Processing
# ------------------------

class TextProcessor:
    """Handles text processing and formatting."""

    @staticmethod
    def remove_citations(text: str) -> str:
        """Replace citation markers (e.g. 【12†source】) with a book emoji."""
        pattern = r"【\d+†\w+】"
        return re.sub(pattern, "📚", text)
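
# For example, an annotated chunk such as
#   "See the design notes.【12†source】"
# is rendered as
#   "See the design notes.📚"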

# ------------------------
# Streaming Handler
# ------------------------

class StreamHandler:
    """Handles streaming of assistant responses."""

    def __init__(self, client: Any):
        self.client = client
        self.text_processor = TextProcessor()
        self.complete_response = []

    def stream_data(self) -> Generator[Any, None, None]:
        """Stream data from the assistant run."""
        st.toast("Thinking...", icon="🤔")
        content_produced = False
        self.complete_response = []  # Reset for new stream
        try:
            for event in st.session_state.run_stream:
                match event.event:
                    case "thread.message.delta":
                        yield from self._handle_message_delta(event)
                        content_produced = True  # Placeholder no longer needed
                    case "thread.run.requires_action":
                        yield from self._handle_action_request(event, content_produced)
                    case "thread.run.failed":
                        logger.error(f"Run failed: {event}")
                        raise StreamingError(f"Assistant run failed: {event}")
            st.toast("Completed", icon="✅")
            # Return the complete response for storage
            return "".join(self.complete_response)
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            st.error(f"An error occurred while streaming: {str(e)}")
            raise

    def _handle_message_delta(self, event: Any) -> Generator[Any, None, None]:
        """Handle message delta events."""
        content = event.data.delta.content[0]
        match content.type:
            case "text":
                processed_text = self.text_processor.remove_citations(content.text.value)
                self.complete_response.append(processed_text)  # Store the chunk
                yield processed_text
            case "image_file":
                image_content = io.BytesIO(self.client.files.content(content.image_file.file_id).read())
                yield Image.open(image_content)

    def _handle_action_request(self, event: Any, content_produced: bool) -> Generator[str, None, None]:
        """Handle action request events."""
        logger.info(f"[Tool Request] {event}")
        st.session_state.tool_requests.put(event)
        if not content_produced:
            # Show a placeholder only when nothing has been streamed yet
            yield "[Processing function call...]"

# ------------------------
# Tool Request Handler
# ------------------------

class ToolRequestHandler:
    """Handles tool requests from the assistant."""

    @staticmethod
    def handle_request(event: Any) -> Tuple[List[Dict[str, str]], str, str]:
        """Process tool requests and return outputs."""
        st.toast("Processing function call...", icon="⚙️")
        tool_outputs = []
        data = event.data
        for tool_call in data.required_action.submit_tool_outputs.tool_calls:
            output = ToolRequestHandler._process_tool_call(tool_call)
            tool_outputs.append(output)
        return tool_outputs, data.thread_id, data.id

    @staticmethod
    def _process_tool_call(tool_call: Any) -> Dict[str, str]:
        """Process an individual tool call."""
        function_args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
        match tool_call.function.name:
            case "hello_world":
                name = function_args.get("name", "anonymous")
                output_val = f"Hello, {name}! This was from a local function."
            case _:
                output_val = json.dumps({"status": "error", "message": "Unknown function request."})
        return {"tool_call_id": tool_call.id, "output": output_val}

# ------------------------
# Assistant Manager
# ------------------------

class AssistantManager:
    """Manages interactions with the OpenAI Assistant."""

    def __init__(self, client: Any, assistant_id: str):
        self.client = client
        self.assistant_id = assistant_id
        self.stream_handler = StreamHandler(client)
        self.tool_handler = ToolRequestHandler()

    def generate_reply(self, user_input: str) -> str:
        """Generate and stream the assistant's reply."""
        # Ensure thread exists
        if not st.session_state.thread:
            st.session_state.thread = self.client.beta.threads.create()
        # Add user message
        self.client.beta.threads.messages.create(
            thread_id=st.session_state.thread.id,
            role="user",
            content=user_input,
        )
        # Stream initial response
        with self.client.beta.threads.runs.stream(
            thread_id=st.session_state.thread.id,
            assistant_id=self.assistant_id,
        ) as run_stream:
            complete_response = self._display_stream(run_stream)
        # Handle any tool requests; text streamed after the tool outputs are
        # submitted is part of the reply, so append it before returning
        complete_response += self._process_tool_requests()
        return complete_response

    def _display_stream(self, run_stream: Any, create_context: bool = True) -> str:
        """Display streaming content."""
        st.session_state.run_stream = run_stream
        if create_context:
            with st.chat_message("assistant"):
                return st.write_stream(self.stream_handler.stream_data)
        else:
            return st.write_stream(self.stream_handler.stream_data)

    def _process_tool_requests(self) -> str:
        """Process any pending tool requests and return the text they stream."""
        follow_up_text = ""
        while not st.session_state.tool_requests.empty():
            event = st.session_state.tool_requests.get()
            tool_outputs, thread_id, run_id = self.tool_handler.handle_request(event)
            with self.client.beta.threads.runs.submit_tool_outputs_stream(
                thread_id=thread_id,
                run_id=run_id,
                tool_outputs=tool_outputs,
            ) as next_stream:
                follow_up_text += self._display_stream(next_stream, create_context=False)
        return follow_up_text
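
# Note on the flow above: a run that needs a local function pauses in the
# "requires_action" state. stream_data() parks that event in the
# tool_requests queue, and _process_tool_requests() resumes the run by
# submitting the results through submit_tool_outputs_stream, which opens a
# second stream carrying the assistant's final answer.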

# ------------------------
# Main Application
# ------------------------

class ChatApplication:
    """Main chat application class."""

    def __init__(self):
        self.config = AppConfig()
        api_key, assistant_id = EnvironmentManager.load_environment()
        # Initialize OpenAI client
        openai_client = openai.Client(api_key=api_key)
        self.client = wrap_openai(openai_client)
        # Initialize components
        self.state_manager = StateManager()
        self.assistant_manager = AssistantManager(self.client, assistant_id)

    def setup_page(self) -> None:
        """Configure the Streamlit page."""
        st.set_page_config(
            page_title=self.config.page_title,
            page_icon=self.config.page_icon,
            layout=self.config.layout,
        )
        st.title(self.config.page_title)

    def display_chat_history(self) -> None:
        """Display the chat history."""
        for msg in st.session_state.messages:
            with st.chat_message(msg.role):
                st.write(msg.content)

    def run(self) -> None:
        """Run the chat application."""
        self.setup_page()
        self.state_manager.initialize_state()
        self.display_chat_history()
        user_input = st.chat_input("Type your message here...")
        if user_input:
            # Display and store user message
            with st.chat_message("user"):
                st.write(user_input)
            self.state_manager.add_message("user", user_input)
            # Generate and display assistant reply
            try:
                complete_response = self.assistant_manager.generate_reply(user_input)
                self.state_manager.add_message("assistant", complete_response)
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                logger.exception("Error in assistant reply generation")

def main():
    """Application entry point."""
    try:
        app = ChatApplication()
        app.run()
    except Exception as e:
        st.error(f"Application error: {str(e)}")
        logger.exception("Fatal application error")


if __name__ == "__main__":
    main()
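
# Usage (assuming this file is saved as app.py, with the .env described
# above alongside it):
#
#   pip install streamlit openai python-dotenv pillow langsmith
#   streamlit run app.py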