import os
import json
import base64
import platformdirs
from typing import List, Union, Dict, Any
from openai import AzureOpenAI
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
from .base import EngineLM, CachedEngine
from .engine_utils import get_image_type_from_bytes
def validate_structured_output_model(model_string: str) -> bool:
"""Check if the model supports structured outputs."""
# Azure OpenAI models that support structured outputs
return any(x in model_string.lower() for x in ["gpt-4"])
def validate_chat_model(model_string: str) -> bool:
"""Check if the model is a chat model."""
return any(x in model_string.lower() for x in ["gpt"])
def validate_reasoning_model(model_string: str) -> bool:
"""Check if the model is a reasoning model."""
# Azure OpenAI doesn't have specific reasoning models like OpenAI
return False
def validate_pro_reasoning_model(model_string: str) -> bool:
"""Check if the model is a pro reasoning model."""
# Azure OpenAI doesn't have pro reasoning models
return False
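# Illustrative checks for the validators above (informal sketch; plain substring
# matching, exactly as written in the functions):
#   validate_structured_output_model("gpt-4o")  -> True  ("gpt-4" is a substring)
#   validate_chat_model("gpt-35-turbo")         -> True
#   validate_reasoning_model("o1-preview")      -> False (always False for Azure here)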
class ChatAzureOpenAI(EngineLM, CachedEngine):
"""
Azure OpenAI API implementation of the EngineLM interface.
"""
DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
def __init__(
self,
model_string: str = "gpt-4",
use_cache: bool = False,
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
is_multimodal: bool = False,
**kwargs
):
"""
Initialize the Azure OpenAI engine.
Args:
model_string: The name of the Azure OpenAI deployment
use_cache: Whether to use caching
system_prompt: The system prompt to use
is_multimodal: Whether to enable multimodal capabilities
**kwargs: Additional arguments to pass to the AzureOpenAI client
"""
self.model_string = model_string
self.use_cache = use_cache
self.system_prompt = system_prompt
self.is_multimodal = is_multimodal
# Set model capabilities
self.support_structured_output = validate_structured_output_model(self.model_string)
self.is_chat_model = validate_chat_model(self.model_string)
self.is_reasoning_model = validate_reasoning_model(self.model_string)
self.is_pro_reasoning_model = validate_pro_reasoning_model(self.model_string)
# Set up caching if enabled
if self.use_cache:
root = platformdirs.user_cache_dir("agentflow")
cache_path = os.path.join(root, f"cache_azure_openai_{model_string}.db")
self.image_cache_dir = os.path.join(root, "image_cache")
os.makedirs(self.image_cache_dir, exist_ok=True)
super().__init__(cache_path=cache_path)
# Validate required environment variables
if not os.getenv("AZURE_OPENAI_API_KEY"):
raise ValueError("Please set the AZURE_OPENAI_API_KEY environment variable.")
if not os.getenv("AZURE_OPENAI_ENDPOINT"):
raise ValueError("Please set the AZURE_OPENAI_ENDPOINT environment variable.")
# Initialize Azure OpenAI client
self.client = AzureOpenAI(
api_key=os.getenv("AZURE_OPENAI_API_KEY"),
api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2024-12-01-preview"),
azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
)
        # Store any remaining kwargs as defaults (note: not currently forwarded to API calls)
        self.default_kwargs = kwargs
def __call__(self, prompt, **kwargs):
"""
Handle direct calls to the instance (e.g., model(prompt)).
Forwards the call to the generate method.
"""
return self.generate(prompt, **kwargs)
    def _format_content(self, content: List[Union[str, bytes, Dict[str, Any]]]) -> List[Dict[str, Any]]:
"""Format content for the OpenAI API."""
formatted_content = []
for item in content:
if isinstance(item, str):
formatted_content.append({"type": "text", "text": item})
elif isinstance(item, bytes):
                # For images, encode as base64 and embed as a data URL
                image_type = get_image_type_from_bytes(item)
                if not image_type:
                    raise ValueError("Could not determine the image type from the provided bytes.")
                base64_image = base64.b64encode(item).decode('utf-8')
                formatted_content.append({
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/{image_type};base64,{base64_image}",
                        "detail": "auto"
                    }
                })
            elif isinstance(item, dict) and "type" in item:
                # Already formatted content; pass through unchanged
                formatted_content.append(item)
            else:
                raise TypeError(f"Unsupported content item type: {type(item).__name__}")
        return formatted_content
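    # Example of the payload _format_content builds for ["describe this image", <png bytes>]
    # (shape grounded in the code above; the base64 payload is elided):
    #   [{"type": "text", "text": "describe this image"},
    #    {"type": "image_url", "image_url": {"url": "data:image/png;base64,...", "detail": "auto"}}]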
    def generate(self, content: Union[str, List[Union[str, bytes, Dict[str, Any]]]], system_prompt=None, **kwargs):
        """
        Generate a response for text or multimodal content.

        Transient API errors are retried inside the helper methods; anything that
        still fails after the retries is caught here and returned as an error dict.
        """
        try:
            if isinstance(content, str):
                return self._generate_text(content, system_prompt=system_prompt, **kwargs)
            elif isinstance(content, list):
                if not self.is_multimodal:
                    raise NotImplementedError(f"Multimodal generation is not supported for {self.model_string}.")
                return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs)
            else:
                raise TypeError(f"Unsupported content type: {type(content).__name__}")
        except Exception as e:
            print(f"Error in generate method: {str(e)}")
            print(f"Error type: {type(e).__name__}")
            print(f"Error details: {e.args}")
            return {
                "error": type(e).__name__,
                "message": str(e),
                "details": getattr(e, 'args', None)
            }
    # Retries transient API failures with exponential backoff; the retry lives here
    # (rather than on generate) so it fires before generate's except clause swallows
    # the exception, and reraise=True surfaces the original error afterwards.
    @retry(
        wait=wait_random_exponential(min=1, max=5),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def _generate_text(
        self,
        prompt: str,
        system_prompt: str = None,
        temperature: float = 0,
        max_tokens: int = 4000,
        top_p: float = 0.99,
        response_format: dict = None,
        **kwargs,
    ) -> str:
"""
Generate a response from the Azure OpenAI API.
"""
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
if self.use_cache:
cache_key = sys_prompt_arg + prompt
cache_or_none = self._check_cache(cache_key)
if cache_or_none is not None:
return cache_or_none
# Chat models with structured output format
if self.is_chat_model and self.support_structured_output and response_format is not None:
response = self.client.beta.chat.completions.parse(
model=self.model_string,
messages=[
{"role": "system", "content": sys_prompt_arg},
{"role": "user", "content": prompt},
],
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
response_format=response_format,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
response = response.choices[0].message.parsed
# Chat models without structured outputs
elif self.is_chat_model and (not self.support_structured_output or response_format is None):
response = self.client.chat.completions.create(
model=self.model_string,
messages=[
{"role": "system", "content": sys_prompt_arg},
{"role": "user", "content": prompt},
],
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
response = response.choices[0].message.content
        # Reasoning models: currently only the base response is supported
elif self.is_reasoning_model:
print(f"Using reasoning model: {self.model_string}")
response = self.client.chat.completions.create(
model=self.model_string,
messages=[
{"role": "user", "content": prompt},
],
max_completion_tokens=max_tokens,
reasoning_effort="medium",
frequency_penalty=0,
presence_penalty=0,
stop=None
)
# Workaround for handling length finish reason
if hasattr(response.choices[0], 'finish_reason') and response.choices[0].finish_reason == "length":
response = "Token limit exceeded"
else:
response = response.choices[0].message.content
# Fallback for other model types
else:
response = self.client.chat.completions.create(
model=self.model_string,
messages=[
{"role": "system", "content": sys_prompt_arg},
{"role": "user", "content": prompt},
],
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
response = response.choices[0].message.content
# Cache the response if caching is enabled
if self.use_cache:
self._add_to_cache(cache_key, response)
return response
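    # Hypothetical structured-output usage (assumes a user-defined Pydantic model,
    # which is not part of this module):
    #   class Answer(BaseModel):
    #       reasoning: str
    #       value: int
    #   engine.generate("What is 2 + 2?", response_format=Answer)  # -> parsed Answer instance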
    # Same retry policy as _generate_text; see the note there.
    @retry(
        wait=wait_random_exponential(min=1, max=5),
        stop=stop_after_attempt(5),
        reraise=True,
    )
    def _generate_multimodal(
        self,
        content: List[Union[str, bytes, Dict[str, Any]]],
        system_prompt: str = None,
        temperature: float = 0,
        max_tokens: int = 4000,
        top_p: float = 0.99,
        response_format: dict = None,
        **kwargs,
    ) -> str:
"""
Generate a response from multiple input types (text and images).
"""
if not self.is_multimodal:
raise ValueError("Multimodal input is not supported by this model.")
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
formatted_content = self._format_content(content)
if self.use_cache:
cache_key = sys_prompt_arg + json.dumps(formatted_content)
cache_or_none = self._check_cache(cache_key)
if cache_or_none is not None:
return cache_or_none
messages = [
{"role": "system", "content": sys_prompt_arg},
{"role": "user", "content": formatted_content},
]
# Chat models with structured output format
if self.is_chat_model and self.support_structured_output and response_format is not None:
response = self.client.beta.chat.completions.parse(
model=self.model_string,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
response_format=response_format,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
response_content = response.choices[0].message.parsed
# Standard chat completion
elif self.is_chat_model and (not self.support_structured_output or response_format is None):
response = self.client.chat.completions.create(
model=self.model_string,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
response_content = response.choices[0].message.content
        # Reasoning models: currently only the base response is supported
elif self.is_reasoning_model:
print(f"Using reasoning model: {self.model_string}")
response = self.client.chat.completions.create(
model=self.model_string,
messages=[
{"role": "user", "content": formatted_content},
],
max_completion_tokens=max_tokens,
reasoning_effort="medium",
frequency_penalty=0,
presence_penalty=0,
stop=None
)
# Workaround for handling length finish reason
if hasattr(response.choices[0], 'finish_reason') and response.choices[0].finish_reason == "length":
response_content = "Token limit exceeded"
else:
response_content = response.choices[0].message.content
# Fallback for other model types
else:
response = self.client.chat.completions.create(
model=self.model_string,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
frequency_penalty=0,
presence_penalty=0,
stop=None
)
response_content = response.choices[0].message.content
# Cache the response if caching is enabled
if self.use_cache:
self._add_to_cache(cache_key, response_content)
return response_content
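# Minimal usage sketch (assumes AZURE_OPENAI_API_KEY and AZURE_OPENAI_ENDPOINT are set,
# and that a deployment named "gpt-4" exists on the endpoint):
if __name__ == "__main__":
    engine = ChatAzureOpenAI(model_string="gpt-4", use_cache=False)
    # __call__ forwards to generate(), so the instance is directly callable
    print(engine("Summarize the benefits of response caching in one sentence."))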