File size: 9,140 Bytes
a0d7d94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad02e51
a0d7d94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
import os
import logging
import json
from openai import OpenAI
from dotenv import load_dotenv
from data_service import DataService

# Read configuration from a local .env file; override=True lets values in the
# file take precedence over variables already set in the process environment.
load_dotenv(override=True)

# Root logging setup: INFO and above, each record prefixed with a timestamp,
# the logger name and the level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

logger = logging.getLogger(__name__)


class OpenAIService:
    """Chat front-end for an OpenAI Assistant backed by local JSON data.

    Keeps a single conversation thread at a time, streams assistant replies
    as text chunks, and answers the assistant's function-tool calls by loading
    files through ``DataService``.
    """

    # Tool name -> data file that DataService loads to answer that tool call.
    _TOOL_DATA_FILES = {
        "get_real_time_commissions": "GetRealTimeCommissions.json",
        "get_volumes": "GetVolumes.json",
        "get_customer": "GetCustomers.json",
    }

    def __init__(self, api_key=None, assistant_id=None, data_dir="data"):
        """
        Initialize OpenAI service with Assistant API.

        Args:
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            assistant_id: OpenAI Assistant ID (defaults to ASSISTANT_ID env var)
            data_dir: Path to data directory for DataService (default: "data")

        Raises:
            ValueError: if no API key or assistant ID is passed in or found
                in the environment.
        """
        logger.info("Initializing OpenAI service...")

        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.assistant_id = assistant_id or os.getenv("ASSISTANT_ID")

        if not self.api_key:
            logger.error("OpenAI API key not found in environment variables")
            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable.")
        if not self.assistant_id:
            logger.error("Assistant ID not found in environment variables")
            raise ValueError("Assistant ID is required. Set ASSISTANT_ID environment variable.")

        # Only the key's length is logged, never the key itself.
        logger.info(f"API key loaded (length: {len(self.api_key)})")
        logger.info(f"Assistant ID: {self.assistant_id}")

        self.client = OpenAI(api_key=self.api_key)
        self.thread = None  # created lazily by get_or_create_thread()
        self.data_service = DataService(data_dir)

        logger.info("OpenAI service initialized successfully")

    def create_thread(self):
        """Create a new conversation thread, make it current, and return its id."""
        logger.info("Creating new conversation thread...")
        self.thread = self.client.beta.threads.create()
        logger.info(f"Thread created successfully: {self.thread.id}")
        return self.thread.id

    def get_or_create_thread(self):
        """Return the current thread's id, creating a thread first if none exists."""
        if not self.thread:
            logger.info("No existing thread found, creating new thread")
            self.create_thread()
        else:
            logger.info(f"Using existing thread: {self.thread.id}")
        return self.thread.id

    def execute_tool_call(self, tool_name, tool_arguments):
        """
        Execute a tool call and return the result.

        Args:
            tool_name: Name of the tool to execute (a key of _TOOL_DATA_FILES)
            tool_arguments: Dictionary of arguments for the tool. It is only
                logged; none of the supported tools reads it.

        Returns:
            str: JSON string with the tool execution result in format { success: bool, result/error: data }
        """
        logger.info(f"Executing tool: {tool_name} with arguments: {tool_arguments}")

        try:
            data_file = self._TOOL_DATA_FILES.get(tool_name)
            if data_file is None:
                logger.warning(f"Unknown tool: {tool_name}")
                return json.dumps({"success": False, "error": f"Unknown tool: {tool_name}"})

            result = self.data_service.get_data(data_file)
            if result is None:
                return json.dumps({"success": False, "error": f"Failed to load data from {data_file}"})

            return json.dumps({"success": True, "result": result})

        except Exception as e:
            # Tool failures are reported back to the assistant rather than
            # raised, so the run can still produce a reply.
            logger.error(f"Error executing tool {tool_name}: {str(e)}", exc_info=True)
            return json.dumps({"success": False, "error": str(e)})

    def generate_stream(self, message):
        """
        Generate response from assistant with streaming and tool handling.

        Tool calls are serviced for as many rounds as the run requests, so
        the reply is streamed until the run stops emitting events.

        Args:
            message: User's message string

        Yields:
            str: Chunks of the response as they arrive. On failure, a single
            "Error: ..." chunk is yielded instead of raising.
        """
        try:
            logger.info(f"Processing message: {message[:50]}...")
            thread_id = self.get_or_create_thread()

            # Add user message to thread
            logger.info(f"Adding message to thread {thread_id}")
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=message
            )
            logger.info("Message added successfully")

            # Stream the assistant's response
            logger.info("Starting assistant response stream...")
            stats = {"chunks": 0, "annotations": 0}

            with self.client.beta.threads.runs.stream(
                thread_id=thread_id,
                assistant_id=self.assistant_id
            ) as stream:
                yield from self._stream_text(stream, thread_id, stats)

            logger.info(
                f"Stream completed. Chunks received: {stats['chunks']}, "
                f"Annotations skipped: {stats['annotations']}"
            )

        except Exception as e:
            logger.error(f"Error in generate_stream: {str(e)}", exc_info=True)
            yield f"Error: {str(e)}"

    def _stream_text(self, stream, thread_id, stats):
        """Yield text from an assistant event stream, servicing tool calls.

        Whenever the run pauses for tool calls, the outputs are submitted and
        the continuation stream is consumed by a recursive call. Runs that
        call tools over several rounds are therefore driven to completion
        instead of being left waiting for outputs.
        """
        for event in stream:
            if event.event == "thread.message.delta":
                yield from self._text_from_delta(event.data.delta, stats)

            elif event.event == "thread.run.requires_action":
                logger.info("Assistant requires action (tool calls)")
                run_id = event.data.id
                tool_calls = event.data.required_action.submit_tool_outputs.tool_calls
                tool_outputs = self._run_tool_calls(tool_calls)

                # Submit tool outputs and continue streaming
                logger.info(f"Submitting {len(tool_outputs)} tool outputs")
                with self.client.beta.threads.runs.submit_tool_outputs_stream(
                    thread_id=thread_id,
                    run_id=run_id,
                    tool_outputs=tool_outputs
                ) as tool_stream:
                    yield from self._stream_text(tool_stream, thread_id, stats)

            elif event.event == "thread.run.failed":
                # Surface the failure in the logs; the caller just sees the
                # stream end.
                logger.warning(f"Assistant run failed: {getattr(event.data, 'last_error', None)}")

    @staticmethod
    def _text_from_delta(delta, stats):
        """Yield the plain-text values in a message delta and update *stats*.

        Text parts carrying annotations are skipped and counted in
        stats["annotations"]. Parts without text (or whose value is None)
        are ignored. Each yielded part increments stats["chunks"].
        """
        # delta.content may be None for deltas that carry no content parts.
        for content in delta.content or []:
            text = getattr(content, "text", None)
            if text is None or not hasattr(text, "value"):
                continue
            if getattr(text, "annotations", None):
                stats["annotations"] += 1
                continue
            if text.value is None:
                continue
            stats["chunks"] += 1
            yield text.value

    def _run_tool_calls(self, tool_calls):
        """Execute each requested tool call and build the outputs to submit.

        Malformed argument JSON produces an error output for that call rather
        than an exception. An exception here would leave the run waiting for
        outputs that never arrive.
        """
        tool_outputs = []
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            logger.info(f"Processing tool call: {tool_name}")

            try:
                # An empty/None arguments string is treated as "no arguments".
                tool_arguments = json.loads(tool_call.function.arguments or "{}")
            except json.JSONDecodeError as exc:
                logger.warning(f"Invalid arguments for tool {tool_name}: {exc}")
                tool_output = json.dumps({"success": False, "error": f"Invalid tool arguments: {exc}"})
            else:
                tool_output = self.execute_tool_call(tool_name, tool_arguments)

            tool_outputs.append({
                "tool_call_id": tool_call.id,
                "output": tool_output
            })
        return tool_outputs

    def clear_thread(self):
        """Forget the current thread; a new one is created on the next message."""
        if self.thread:
            logger.info(f"Clearing thread: {self.thread.id}")
        else:
            logger.info("No active thread to clear")
        self.thread = None
        logger.info("Thread cleared. New thread will be created on next message")