# openai_service.py — OpenAI Assistant API service with streaming and tool dispatch.
import os
import logging
import json

from openai import OpenAI
from dotenv import load_dotenv

from data_service import DataService

# Load environment variables from .env file; override=True lets .env values
# win over variables already present in the process environment.
load_dotenv(override=True)

# Configure module-wide logging with timestamped, leveled records.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class OpenAIService:
    """Wrapper around the OpenAI Assistants API with streaming and tool dispatch.

    Manages a single conversation thread per instance and resolves the
    assistant's tool calls against local JSON fixtures served by ``DataService``.
    """

    # Maps the assistant's tool names to the DataService JSON file backing them.
    # All three tools are file-backed reads; their arguments are currently unused.
    TOOL_DATA_FILES = {
        "get_real_time_commissions": "GetRealTimeCommissions.json",
        "get_volumes": "GetVolumes.json",
        "get_customer": "GetCustomers.json",
    }

    def __init__(self, api_key=None, assistant_id=None, data_dir="data"):
        """
        Initialize OpenAI service with Assistant API.

        Args:
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            assistant_id: OpenAI Assistant ID (defaults to ASSISTANT_ID env var)
            data_dir: Path to data directory for DataService (default: "data")

        Raises:
            ValueError: If the API key or assistant ID cannot be resolved.
        """
        logger.info("Initializing OpenAI service...")
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.assistant_id = assistant_id or os.getenv("ASSISTANT_ID")
        if not self.api_key:
            logger.error("OpenAI API key not found in environment variables")
            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable.")
        if not self.assistant_id:
            logger.error("Assistant ID not found in environment variables")
            raise ValueError("Assistant ID is required. Set ASSISTANT_ID environment variable.")
        # Log only the key length, never the key itself.
        logger.info("API key loaded (length: %d)", len(self.api_key))
        logger.info("Assistant ID: %s", self.assistant_id)
        self.client = OpenAI(api_key=self.api_key)
        self.thread = None  # Lazily created on the first message.
        self.data_service = DataService(data_dir)
        logger.info("OpenAI service initialized successfully")

    def create_thread(self):
        """Create a new conversation thread and return its id."""
        logger.info("Creating new conversation thread...")
        self.thread = self.client.beta.threads.create()
        logger.info("Thread created successfully: %s", self.thread.id)
        return self.thread.id

    def get_or_create_thread(self):
        """Return the current thread id, creating a new thread if none exists."""
        if not self.thread:
            logger.info("No existing thread found, creating new thread")
            self.create_thread()
        else:
            logger.info("Using existing thread: %s", self.thread.id)
        return self.thread.id

    def execute_tool_call(self, tool_name, tool_arguments):
        """
        Execute a tool call and return the result.

        Args:
            tool_name: Name of the tool to execute (see TOOL_DATA_FILES)
            tool_arguments: Dictionary of arguments for the tool (currently
                unused by the file-backed tools; kept for interface parity)

        Returns:
            str: JSON string with the tool execution result in format
                 { success: bool, result/error: data }
        """
        logger.info("Executing tool: %s with arguments: %s", tool_name, tool_arguments)
        try:
            filename = self.TOOL_DATA_FILES.get(tool_name)
            if filename is None:
                logger.warning("Unknown tool: %s", tool_name)
                return json.dumps({"success": False, "error": f"Unknown tool: {tool_name}"})
            result = self.data_service.get_data(filename)
            if result is None:
                return json.dumps({"success": False, "error": f"Failed to load data from {filename}"})
            return json.dumps({"success": True, "result": result})
        except Exception as e:
            logger.error("Error executing tool %s: %s", tool_name, str(e), exc_info=True)
            return json.dumps({"success": False, "error": str(e)})

    @staticmethod
    def _iter_text_deltas(event):
        """Yield (is_annotated, text) for each text part of a message-delta event.

        Annotated parts (e.g. file citations) are flagged so the caller can
        skip them instead of streaming their raw text to the user.
        """
        for content in event.data.delta.content:
            if hasattr(content, 'text') and hasattr(content.text, 'value'):
                annotated = bool(getattr(content.text, 'annotations', None))
                yield annotated, content.text.value

    def generate_stream(self, message):
        """
        Generate response from assistant with streaming and tool handling.

        Args:
            message: User's message string

        Yields:
            str: Chunks of the response as they arrive (or a single
                 "Error: ..." string if anything fails)
        """
        try:
            logger.info("Processing message: %s...", message[:50])
            thread_id = self.get_or_create_thread()

            # Add user message to thread.
            logger.info("Adding message to thread %s", thread_id)
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=message
            )
            logger.info("Message added successfully")

            # Stream the assistant's response.
            logger.info("Starting assistant response stream...")
            chunk_count = 0
            skipped_annotations = 0
            with self.client.beta.threads.runs.stream(
                thread_id=thread_id,
                assistant_id=self.assistant_id
            ) as stream:
                for event in stream:
                    # Plain text chunks: forward unannotated text to the caller.
                    if event.event == "thread.message.delta":
                        for annotated, text in self._iter_text_deltas(event):
                            if annotated:
                                skipped_annotations += 1
                            else:
                                chunk_count += 1
                                yield text
                    # Tool calls: execute locally, submit outputs, keep streaming.
                    elif event.event == "thread.run.requires_action":
                        logger.info("Assistant requires action (tool calls)")
                        run_id = event.data.id
                        tool_calls = event.data.required_action.submit_tool_outputs.tool_calls
                        tool_outputs = []
                        for tool_call in tool_calls:
                            logger.info("Processing tool call: %s", tool_call.function.name)
                            tool_arguments = json.loads(tool_call.function.arguments)
                            tool_output = self.execute_tool_call(
                                tool_call.function.name,
                                tool_arguments
                            )
                            tool_outputs.append({
                                "tool_call_id": tool_call.id,
                                "output": tool_output
                            })
                        # Submit tool outputs and continue streaming the reply.
                        logger.info("Submitting %d tool outputs", len(tool_outputs))
                        with self.client.beta.threads.runs.submit_tool_outputs_stream(
                            thread_id=thread_id,
                            run_id=run_id,
                            tool_outputs=tool_outputs
                        ) as tool_stream:
                            for tool_event in tool_stream:
                                if tool_event.event == "thread.message.delta":
                                    for annotated, text in self._iter_text_deltas(tool_event):
                                        if annotated:
                                            skipped_annotations += 1
                                        else:
                                            chunk_count += 1
                                            yield text
            logger.info(
                "Stream completed. Chunks received: %d, Annotations skipped: %d",
                chunk_count, skipped_annotations
            )
        except Exception as e:
            logger.error("Error in generate_stream: %s", str(e), exc_info=True)
            yield f"Error: {str(e)}"

    def clear_thread(self):
        """Clear the current thread; a new one is created on the next message."""
        if self.thread:
            logger.info("Clearing thread: %s", self.thread.id)
        else:
            logger.info("No active thread to clear")
        self.thread = None
        logger.info("Thread cleared. New thread will be created on next message")