ThreeAI-Demo / generate_service.py
VolodymyrHula's picture
Fixed get customer tool call. Added rest of the files
ad02e51
import os
import logging
import json
from openai import OpenAI
from dotenv import load_dotenv
from data_service import DataService
# Read settings such as OPENAI_API_KEY and ASSISTANT_ID from a local .env
# file; override=True makes values from the file take precedence over
# variables already present in the process environment.
load_dotenv(override=True)

# Process-wide logging setup: INFO level with timestamped, named records.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Logger used throughout this module.
logger = logging.getLogger(__name__)
class OpenAIService:
    """Chat front-end for one OpenAI Assistant, answering its tool calls from local data.

    Each instance holds at most one conversation thread (``self.thread``),
    streams assistant replies as text fragments, and answers the assistant's
    tool calls with data loaded through ``DataService``.
    """

    # Assistant tool name -> data file name passed to DataService.get_data().
    _TOOL_DATA_FILES = {
        "get_real_time_commissions": "GetRealTimeCommissions.json",
        "get_volumes": "GetVolumes.json",
        "get_customer": "GetCustomers.json",
    }

    # Streaming events that signal the run ended without completing normally.
    _RUN_FAILURE_EVENTS = (
        "thread.run.failed",
        "thread.run.cancelled",
        "thread.run.expired",
        "thread.run.incomplete",
    )

    def __init__(self, api_key=None, assistant_id=None, data_dir="data"):
        """
        Initialize OpenAI service with Assistant API.

        Args:
            api_key: OpenAI API key (defaults to OPENAI_API_KEY env var)
            assistant_id: OpenAI Assistant ID (defaults to ASSISTANT_ID env var)
            data_dir: Path to data directory for DataService (default: "data")

        Raises:
            ValueError: If no API key or no assistant ID is available.
        """
        logger.info("Initializing OpenAI service...")
        self.api_key = api_key or os.getenv("OPENAI_API_KEY")
        self.assistant_id = assistant_id or os.getenv("ASSISTANT_ID")
        if not self.api_key:
            logger.error("OpenAI API key not found in environment variables")
            raise ValueError("OpenAI API key is required. Set OPENAI_API_KEY environment variable.")
        if not self.assistant_id:
            logger.error("Assistant ID not found in environment variables")
            raise ValueError("Assistant ID is required. Set ASSISTANT_ID environment variable.")
        # Log only the key's length, never the key itself.
        logger.info(f"API key loaded (length: {len(self.api_key)})")
        logger.info(f"Assistant ID: {self.assistant_id}")
        self.client = OpenAI(api_key=self.api_key)
        self.thread = None  # created lazily by get_or_create_thread()
        self.data_service = DataService(data_dir)
        logger.info("OpenAI service initialized successfully")

    def create_thread(self):
        """Create a new conversation thread, store it, and return its id."""
        logger.info("Creating new conversation thread...")
        self.thread = self.client.beta.threads.create()
        logger.info(f"Thread created successfully: {self.thread.id}")
        return self.thread.id

    def get_or_create_thread(self):
        """Return the current thread's id, creating a thread first if there is none."""
        if not self.thread:
            logger.info("No existing thread found, creating new thread")
            self.create_thread()
        else:
            logger.info(f"Using existing thread: {self.thread.id}")
        return self.thread.id

    def execute_tool_call(self, tool_name, tool_arguments):
        """
        Execute a tool call and return the result.

        Args:
            tool_name: Name of the tool to execute
            tool_arguments: Dictionary of arguments for the tool (currently
                only logged; none of the tools take parameters)

        Returns:
            str: JSON string with the tool execution result in format { success: bool, result/error: data }
        """
        logger.info(f"Executing tool: {tool_name} with arguments: {tool_arguments}")
        try:
            data_file = self._TOOL_DATA_FILES.get(tool_name)
            if data_file is None:
                logger.warning(f"Unknown tool: {tool_name}")
                return json.dumps({"success": False, "error": f"Unknown tool: {tool_name}"})
            result = self.data_service.get_data(data_file)
            if result is None:
                return json.dumps({"success": False, "error": f"Failed to load data from {data_file}"})
            return json.dumps({"success": True, "result": result})
        except Exception as e:
            # Report failures to the assistant as a tool result instead of
            # aborting the run.
            logger.error(f"Error executing tool {tool_name}: {str(e)}", exc_info=True)
            return json.dumps({"success": False, "error": str(e)})

    @staticmethod
    def _parse_tool_arguments(raw_arguments):
        """Decode a tool call's JSON argument string.

        An empty or missing string is treated as "no arguments" and yields ``{}``.

        Raises:
            ValueError: If the string is not valid JSON (``json.JSONDecodeError``
                is a ``ValueError`` subclass).
        """
        if not raw_arguments or not raw_arguments.strip():
            return {}
        return json.loads(raw_arguments)

    @staticmethod
    def _delta_texts(delta, stats):
        """Yield the non-empty text fragments carried by a message delta.

        Updates ``stats["chunks"]`` for every fragment yielded and
        ``stats["annotations"]`` for every fragment skipped because it carries
        annotations.
        """
        # ``delta.content`` may be None, e.g. for deltas that carry no content blocks.
        for block in delta.content or ():
            text = getattr(block, "text", None)  # non-text blocks have no .text
            value = getattr(text, "value", None)
            if not value:
                continue
            if getattr(text, "annotations", None):
                # NOTE(review): the whole fragment is dropped, not just the
                # annotation marker; kept from the original behavior.
                stats["annotations"] += 1
                continue
            stats["chunks"] += 1
            yield value

    def _run_tool_calls(self, tool_calls):
        """Run each requested tool and return the ``tool_outputs`` payload to submit."""
        tool_outputs = []
        for tool_call in tool_calls:
            tool_name = tool_call.function.name
            logger.info(f"Processing tool call: {tool_name}")
            try:
                tool_arguments = self._parse_tool_arguments(tool_call.function.arguments)
            except ValueError as exc:
                # Answer with an error output: an unanswered tool call would
                # leave the run stuck waiting for outputs.
                logger.warning(f"Invalid arguments for tool {tool_name}: {exc}")
                tool_output = json.dumps({"success": False, "error": f"Invalid tool arguments: {exc}"})
            else:
                tool_output = self.execute_tool_call(tool_name, tool_arguments)
            tool_outputs.append({
                "tool_call_id": tool_call.id,
                "output": tool_output
            })
        return tool_outputs

    def _consume_stream(self, stream, thread_id, stats):
        """Yield reply text from a run event stream, answering tool calls as they arrive.

        Tool outputs are submitted on a new stream that is consumed
        recursively, so the assistant may chain several rounds of tool calls
        within a single run.
        """
        for event in stream:
            if event.event == "thread.message.delta":
                yield from self._delta_texts(event.data.delta, stats)
            elif event.event == "thread.run.requires_action":
                logger.info("Assistant requires action (tool calls)")
                tool_calls = event.data.required_action.submit_tool_outputs.tool_calls
                tool_outputs = self._run_tool_calls(tool_calls)
                logger.info(f"Submitting {len(tool_outputs)} tool outputs")
                with self.client.beta.threads.runs.submit_tool_outputs_stream(
                    thread_id=thread_id,
                    run_id=event.data.id,
                    tool_outputs=tool_outputs
                ) as tool_stream:
                    yield from self._consume_stream(tool_stream, thread_id, stats)
            elif event.event in self._RUN_FAILURE_EVENTS:
                logger.warning(f"Run ended with {event.event}: {getattr(event.data, 'last_error', None)}")

    def generate_stream(self, message):
        """
        Generate response from assistant with streaming and tool handling.

        Args:
            message: User's message string

        Yields:
            str: Chunks of the response as they arrive; on failure a single
            "Error: ..." string is yielded instead of raising.
        """
        stats = {"chunks": 0, "annotations": 0}
        try:
            logger.info(f"Processing message: {message[:50]}...")
            thread_id = self.get_or_create_thread()
            # Add user message to thread
            logger.info(f"Adding message to thread {thread_id}")
            self.client.beta.threads.messages.create(
                thread_id=thread_id,
                role="user",
                content=message
            )
            logger.info("Message added successfully")
            # Stream the assistant's response (including any tool-call rounds)
            logger.info("Starting assistant response stream...")
            with self.client.beta.threads.runs.stream(
                thread_id=thread_id,
                assistant_id=self.assistant_id
            ) as stream:
                yield from self._consume_stream(stream, thread_id, stats)
            logger.info(f"Stream completed. Chunks received: {stats['chunks']}, Annotations skipped: {stats['annotations']}")
        except Exception as e:
            logger.error(f"Error in generate_stream: {str(e)}", exc_info=True)
            yield f"Error: {str(e)}"

    def clear_thread(self):
        """Forget the current thread; a new one is created on the next message."""
        if self.thread:
            logger.info(f"Clearing thread: {self.thread.id}")
        else:
            logger.info("No active thread to clear")
        self.thread = None
        logger.info("Thread cleared. New thread will be created on next message")