import os
import logging
import json
from openai import OpenAI
from dotenv import load_dotenv
from data_service import DataService
# Load environment variables from a local .env file; override=True lets .env
# values take precedence over variables already set in the process environment.
load_dotenv(override=True)
# Configure root logging once at import time so all module loggers inherit it
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
class OpenAIService:
    """Service wrapper around the OpenAI Assistants API.

    Manages a single conversation thread, streams assistant replies chunk by
    chunk, and executes the assistant's tool calls against a local DataService.
    """

    # Dispatch table: tool name requested by the assistant -> backing JSON file.
    # Replaces the previous copy-pasted if/elif chain; adding a tool is now a
    # one-line change.
    _TOOL_DATA_FILES = {
        "get_real_time_commissions": "GetRealTimeCommissions.json",
        "get_volumes": "GetVolumes.json",
        "get_customer": "GetCustomers.json",
    }

    def __init__(self, api_key=None, assistant_id=None, data_dir="data"):
        """
        Initialize OpenAI service with Assistant API.

        Args:
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            assistant_id: OpenAI Assistant ID (defaults to ASSISTANT_ID env var)
            data_dir: Path to data directory for DataService (default: "data")

        Raises:
            ValueError: If the API key or assistant ID cannot be resolved from
                arguments or environment variables.
        """
        logger.info("Initializing OpenAI service...")
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.assistant_id = assistant_id or os.getenv("ASSISTANT_ID")
        if not self.api_key:
            logger.error("OpenAI API key not found in environment variables")
            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable.")
        if not self.assistant_id:
            logger.error("Assistant ID not found in environment variables")
            raise ValueError("Assistant ID is required. Set ASSISTANT_ID environment variable.")
        # Log only the key length, never the key itself.
        logger.info(f"API key loaded (length: {len(self.api_key)})")
        logger.info(f"Assistant ID: {self.assistant_id}")
        self.client = OpenAI(api_key=self.api_key)
        self.thread = None  # lazily created on first message
        self.data_service = DataService(data_dir)
        logger.info("OpenAI service initialized successfully")

    def create_thread(self):
        """Create a new conversation thread and return its id."""
        logger.info("Creating new conversation thread...")
        self.thread = self.client.beta.threads.create()
        logger.info(f"Thread created successfully: {self.thread.id}")
        return self.thread.id

    def get_or_create_thread(self):
        """Return the current thread id, creating a thread if none exists."""
        if not self.thread:
            logger.info("No existing thread found, creating new thread")
            self.create_thread()
        else:
            logger.info(f"Using existing thread: {self.thread.id}")
        return self.thread.id

    def execute_tool_call(self, tool_name, tool_arguments):
        """
        Execute a tool call and return the result.

        Args:
            tool_name: Name of the tool to execute
            tool_arguments: Dictionary of arguments for the tool

        Returns:
            str: JSON string with the tool execution result in format
                { success: bool, result/error: data }
        """
        logger.info(f"Executing tool: {tool_name} with arguments: {tool_arguments}")
        try:
            data_file = self._TOOL_DATA_FILES.get(tool_name)
            if data_file is None:
                logger.warning(f"Unknown tool: {tool_name}")
                return json.dumps({"success": False, "error": f"Unknown tool: {tool_name}"})
            result = self.data_service.get_data(data_file)
            if result is None:
                return json.dumps({"success": False, "error": f"Failed to load data from {data_file}"})
            return json.dumps({"success": True, "result": result})
        except Exception as e:
            logger.error(f"Error executing tool {tool_name}: {str(e)}", exc_info=True)
            return json.dumps({"success": False, "error": str(e)})

    def _yield_text_deltas(self, event, counters):
        """Yield plain-text chunks from one thread.message.delta event.

        Annotated text (e.g. file citations) is skipped so raw citation markup
        never reaches the client. ``counters`` is a dict with 'chunks' and
        'skipped' keys, mutated in place so both streams in generate_stream
        share the same totals.
        """
        for content in event.data.delta.content:
            if hasattr(content, 'text') and hasattr(content.text, 'value'):
                if hasattr(content.text, 'annotations') and content.text.annotations:
                    counters['skipped'] += 1
                    continue
                counters['chunks'] += 1
                yield content.text.value

    def generate_stream(self, message):
        """
        Generate response from assistant with streaming and tool handling.

        Args:
            message: User's message string

        Yields:
            str: Chunks of the response as they arrive (or an "Error: ..."
                string if anything fails).
        """
        try:
            logger.info(f"Processing message: {message[:50]}...")
            thread_id = self.get_or_create_thread()
            # Add user message to thread
            logger.info(f"Adding message to thread {thread_id}")
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=message
            )
            logger.info("Message added successfully")
            # Stream the assistant's response
            logger.info("Starting assistant response stream...")
            counters = {'chunks': 0, 'skipped': 0}
            with self.client.beta.threads.runs.stream(
                thread_id=thread_id,
                assistant_id=self.assistant_id
            ) as stream:
                for event in stream:
                    # Plain text deltas: forward to the caller.
                    if event.event == "thread.message.delta":
                        yield from self._yield_text_deltas(event, counters)
                    # Tool calls: execute locally, submit outputs, and keep
                    # streaming the continuation of the run.
                    elif event.event == "thread.run.requires_action":
                        logger.info("Assistant requires action (tool calls)")
                        run_id = event.data.id
                        tool_calls = event.data.required_action.submit_tool_outputs.tool_calls
                        tool_outputs = []
                        for tool_call in tool_calls:
                            logger.info(f"Processing tool call: {tool_call.function.name}")
                            tool_arguments = json.loads(tool_call.function.arguments)
                            tool_output = self.execute_tool_call(
                                tool_call.function.name,
                                tool_arguments
                            )
                            tool_outputs.append({
                                "tool_call_id": tool_call.id,
                                "output": tool_output
                            })
                        # Submit tool outputs and continue streaming
                        logger.info(f"Submitting {len(tool_outputs)} tool outputs")
                        with self.client.beta.threads.runs.submit_tool_outputs_stream(
                            thread_id=thread_id,
                            run_id=run_id,
                            tool_outputs=tool_outputs
                        ) as tool_stream:
                            for tool_event in tool_stream:
                                if tool_event.event == "thread.message.delta":
                                    yield from self._yield_text_deltas(tool_event, counters)
            logger.info(f"Stream completed. Chunks received: {counters['chunks']}, Annotations skipped: {counters['skipped']}")
        except Exception as e:
            logger.error(f"Error in generate_stream: {str(e)}", exc_info=True)
            yield f"Error: {str(e)}"

    def clear_thread(self):
        """Clear the current thread; a new one is created on the next message."""
        if self.thread:
            logger.info(f"Clearing thread: {self.thread.id}")
        else:
            logger.info("No active thread to clear")
        self.thread = None
        logger.info("Thread cleared. New thread will be created on next message")
|