from __future__ import annotations import os import json from ...typing import Messages, AsyncResult, MediaListType from ...errors import MissingAuthError, ModelNotFoundError from ...requests import StreamSession, FormData, raise_for_status from ...image import get_width_height, to_bytes from ...image.copy_images import save_response_media from ..template import OpenaiTemplate from ..helper import format_media_prompt class Azure(OpenaiTemplate): label = "Azure ☁️" url = "https://ai.azure.com" api_base = "https://host.g4f.dev/api/Azure" working = True needs_auth = True models_needs_auth = True active_by_default = True login_url = "https://discord.gg/qXA4Wf4Fsm" routes: dict[str, str] = {} audio_models = ["gpt-4o-mini-audio-preview"] vision_models = ["gpt-4.1", "o4-mini", "model-router", "flux.1-kontext-pro"] image_models = ["flux-1.1-pro", "flux.1-kontext-pro"] model_aliases = { "flux-kontext": "flux.1-kontext-pro" } model_extra_body = { "gpt-4o-mini-audio-preview": { "audio": { "voice": "alloy", "format": "mp3" }, "modalities": ["text", "audio"], } } api_keys: dict[str, str] = {} failed: dict[str, int] = {} @classmethod def get_models(cls, api_key: str = None, **kwargs) -> list[str]: api_keys = os.environ.get("AZURE_API_KEYS") if api_keys: try: cls.api_keys = json.loads(api_keys) except json.JSONDecodeError: raise ValueError(f"Invalid AZURE_API_KEYS environment variable") routes = os.environ.get("AZURE_ROUTES") if routes: try: routes = json.loads(routes) except json.JSONDecodeError: raise ValueError(f"Invalid AZURE_ROUTES environment variable format: {routes}") cls.routes = routes if cls.routes: if cls.live == 0 and cls.api_keys: cls.live += 1 return list(cls.routes.keys()) return super().get_models(api_key=api_key, **kwargs) @classmethod async def create_async_generator( cls, model: str, messages: Messages, stream: bool = True, media: MediaListType = None, api_key: str = None, api_endpoint: str = None, **kwargs ) -> AsyncResult: if not model: model = os.environ.get("AZURE_DEFAULT_MODEL", cls.default_model) if model in cls.model_aliases: model = cls.model_aliases[model] if not api_endpoint: if not cls.routes: cls.get_models() api_endpoint = cls.routes.get(model) if cls.routes and not api_endpoint: raise ModelNotFoundError(f"No API endpoint found for model: {model}") if not api_endpoint: api_endpoint = os.environ.get("AZURE_API_ENDPOINT") if cls.api_keys: api_key = cls.api_keys.get(model, cls.api_keys.get("default")) if not api_key: raise ValueError(f"API key is required for Azure provider. Ask for API key in the {cls.login_url} Discord server.") if api_endpoint and "/images/" in api_endpoint: prompt = format_media_prompt(messages, kwargs.get("prompt")) width, height = get_width_height(kwargs.get("aspect_ratio", "1:1"), kwargs.get("width"), kwargs.get("height")) output_format = kwargs.get("output_format", "png") form = None data = None if media: form = FormData() form.add_field("prompt", prompt) form.add_field("width", str(width)) form.add_field("height", str(height)) output_format = "png" for i in range(len(media)): if media[i][1] is None and isinstance(media[i][0], str): media[i] = media[i][0], os.path.basename(media[i][0]) media[i] = (to_bytes(media[i][0]), media[i][1]) for image, image_name in media: form.add_field(f"image", image, filename=image_name) else: api_endpoint = api_endpoint.replace("/edits", "/generations") data = { "prompt": prompt, "n": 1, "width": width, "height": height, "output_format": output_format, } async with StreamSession(proxy=kwargs.get("proxy"), headers={ "Authorization": f"Bearer {api_key}", "x-ms-model-mesh-model-name": model, }) as session: async with session.post(api_endpoint, data=form, json=data) as response: data = await response.json() await raise_for_status(response, data) async for chunk in save_response_media( data["data"][0]["b64_json"], prompt, content_type=f"image/{output_format.replace('jpg', 'jpeg')}" ): yield chunk return if model in cls.model_extra_body: for key, value in cls.model_extra_body[model].items(): kwargs.setdefault(key, value) stream = False if stream: kwargs.setdefault("stream_options", {"include_usage": True}) if cls.failed.get(api_key, 0) >= 3: raise MissingAuthError(f"API key has failed too many times.") try: async for chunk in super().create_async_generator( model=model, messages=messages, stream=stream, media=media, api_key=api_key, api_endpoint=api_endpoint, **kwargs ): yield chunk except MissingAuthError as e: cls.failed[api_key] = cls.failed.get(api_key, 0) + 1 raise MissingAuthError(f"{e}. Ask for help in the {cls.login_url} Discord server.") from e