| """ |
| Centralized AI Manager for multiple providers. |
| Supports Gemini, Nebius Token Factory, and other OpenAI-compatible providers. |
| """ |
|
|
| import os |
| import json |
| import logging |
| from typing import Dict, Any, Optional, List |
| from enum import Enum |
| from dotenv import load_dotenv |
|
|
| |
# Load variables from a .env file into the process environment at import
# time, so the os.getenv() lookups in AIManager can see keys defined there.
load_dotenv()

# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class AIProvider(Enum):
    """
    Supported AI providers.

    Member values are the lower-case provider names. Lookup by value also
    tolerates surrounding whitespace and letter case, e.g.
    ``AIProvider(" Gemini ")`` is ``AIProvider.GEMINI``; any other unknown
    value still raises ``ValueError``.
    """

    GEMINI = "gemini"
    NEBIUS = "nebius"
    OPENAI = "openai"

    @classmethod
    def _missing_(cls, value):
        """
        Retry a failed value lookup after normalizing case and whitespace.

        Returns the matching member, or None so Enum raises its usual
        ValueError for values that still do not match.
        """
        if isinstance(value, str):
            normalized = value.strip().lower()
            for member in cls:
                if member.value == normalized:
                    return member
        return None
|
|
|
|
class AIManager:
    """
    Centralized manager for AI API calls across different providers.

    Provides a unified interface (``generate_content``) regardless of the
    underlying provider:

    * ``gemini`` - Google Gemini via the ``google-genai`` SDK.
    * ``nebius`` - Nebius Token Factory via the ``openai`` SDK pointed at
      Nebius's OpenAI-compatible endpoint.
    * ``openai`` - OpenAI via the ``openai`` SDK.

    SDKs are imported inside the provider initializers, so only the SDK for
    the selected provider has to be importable.
    """

    # Provider and model defaults; constructor args / env vars take precedence.
    DEFAULT_PROVIDER = "gemini"
    DEFAULT_GEMINI_MODEL = "gemini-2.5-flash"
    DEFAULT_NEBIUS_MODEL = "zai-org/GLM-4.5"
    DEFAULT_OPENAI_MODEL = "gpt-4"

    # Named temperature presets.
    TEMPERATURE_PRECISE = 0.0
    TEMPERATURE_LOW = 0.1
    TEMPERATURE_MEDIUM = 0.2
    TEMPERATURE_HIGH = 0.7

    # Named output-token budgets.
    MAX_OUTPUT_TOKENS_SMALL = 8192
    MAX_OUTPUT_TOKENS_MEDIUM = 16384
    MAX_OUTPUT_TOKENS_LARGE = 32768

    # NOTE(review): these retry settings are not used by any method in this
    # class; every API call is made exactly once and SDK exceptions propagate.
    MAX_RETRIES = 3
    RETRY_DELAY = 1.0

    # OpenAI-compatible endpoint for Nebius Token Factory.
    NEBIUS_BASE_URL = "https://api.tokenfactory.nebius.com/v1/"

    # Env var holding each provider's API key. Single source of truth for the
    # provider initializers and validate_config(), so they cannot drift apart.
    _API_KEY_ENV_VARS = {
        "gemini": "GEMINI_API_KEY",
        "nebius": "NEBIUS_API_KEY",
        "openai": "OPENAI_API_KEY",
    }

    def __init__(self, provider: Optional[str] = None, model: Optional[str] = None):
        """
        Initialize AI Manager.

        Args:
            provider: AI provider to use (gemini, nebius, openai), matched
                case-insensitively. If None or empty, reads the AI_PROVIDER
                env var, then DEFAULT_PROVIDER. Unknown names log a warning
                and fall back to Gemini.
            model: Model name to use. If None, reads the provider-specific
                env var (GEMINI_MODEL / NEBIUS_MODEL / OPENAI_MODEL), then the
                class default for that provider.

        Raises:
            ValueError: If the selected provider's API key env var is unset.
        """
        self.provider_name = self._resolve_provider_name(provider)
        self.provider = AIProvider(self.provider_name)

        initializers = {
            AIProvider.GEMINI: self._init_gemini,
            AIProvider.NEBIUS: self._init_nebius,
            AIProvider.OPENAI: self._init_openai,
        }
        initializers[self.provider](model)

        logger.info(
            "AIManager initialized with provider: %s, model: %s",
            self.provider_name,
            self.model_name,
        )

    @classmethod
    def _resolve_provider_name(cls, provider: Optional[str] = None) -> str:
        """
        Return the normalized name of the provider to use.

        Precedence: ``provider`` argument, then the AI_PROVIDER env var, then
        DEFAULT_PROVIDER (empty values count as unset). The name is stripped
        and lower-cased; unsupported names log a warning and resolve to
        "gemini".
        """
        name = (
            provider
            or os.getenv("AI_PROVIDER")
            or cls.DEFAULT_PROVIDER
        ).strip().lower()
        if name not in cls._API_KEY_ENV_VARS:
            logger.warning("Unknown provider '%s', falling back to Gemini", name)
            return "gemini"
        return name

    @classmethod
    def _require_api_key(cls, provider_name: str) -> str:
        """
        Return the API key for ``provider_name`` read from the environment.

        Raises:
            ValueError: If the provider's API key env var is unset or empty.
        """
        env_var = cls._API_KEY_ENV_VARS[provider_name]
        api_key = os.getenv(env_var)
        if not api_key:
            raise ValueError(
                f"{env_var} not found in environment variables. "
                "Please set it in your .env file."
            )
        return api_key

    def _init_gemini(self, model: Optional[str] = None):
        """Initialize the Gemini provider (google-genai client)."""
        from google import genai

        api_key = self._require_api_key("gemini")
        self.model_name = model or os.getenv("GEMINI_MODEL", self.DEFAULT_GEMINI_MODEL)
        self.client = genai.Client(api_key=api_key)
        self.provider_type = "gemini"

    def _init_nebius(self, model: Optional[str] = None):
        """Initialize Nebius Token Factory provider (OpenAI-compatible client)."""
        from openai import OpenAI

        api_key = self._require_api_key("nebius")
        self.model_name = model or os.getenv("NEBIUS_MODEL", self.DEFAULT_NEBIUS_MODEL)
        self.client = OpenAI(base_url=self.NEBIUS_BASE_URL, api_key=api_key)
        self.provider_type = "openai_compatible"

    def _init_openai(self, model: Optional[str] = None):
        """Initialize the OpenAI provider."""
        from openai import OpenAI

        api_key = self._require_api_key("openai")
        self.model_name = model or os.getenv("OPENAI_MODEL", self.DEFAULT_OPENAI_MODEL)
        self.client = OpenAI(api_key=api_key)
        self.provider_type = "openai_compatible"

    def generate_content(
        self,
        prompt: str,
        temperature: float = TEMPERATURE_LOW,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM,
        response_format: Optional[str] = None,
        response_schema: Optional[Dict[str, Any]] = None,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Generate content using the configured AI provider.

        Args:
            prompt: The prompt to send to the AI.
            temperature: Temperature setting (0.0-1.0).
            max_tokens: Maximum output tokens.
            response_format: Response format ("json" or None).
            response_schema: JSON schema for structured responses (Gemini
                format). Only Gemini uses it; OpenAI-compatible providers
                ignore it (use ``response_format="json"`` there).
            system_prompt: Optional system prompt. Sent as a system message to
                OpenAI-compatible providers and as ``system_instruction`` to
                Gemini.

        Returns:
            Generated text content ("" if the provider returned no text).
        """
        if self.provider_type == "gemini":
            return self._generate_gemini(
                prompt, temperature, max_tokens,
                response_format, response_schema, system_prompt,
            )
        return self._generate_openai_compatible(
            prompt, temperature, max_tokens, response_format, system_prompt,
        )

    def _generate_gemini(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        response_format: Optional[str],
        response_schema: Optional[Dict[str, Any]],
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Generate content using the Gemini API.

        A ``response_schema`` implies JSON output; ``response_format="json"``
        alone requests JSON without a schema.
        """
        config: Dict[str, Any] = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "top_p": 0.95,
        }
        # Previously dropped for Gemini; forward it so the interface is uniform.
        if system_prompt:
            config["system_instruction"] = system_prompt

        if response_schema:
            config["response_mime_type"] = "application/json"
            config["response_schema"] = response_schema
        elif response_format == "json":
            config["response_mime_type"] = "application/json"

        response = self.client.models.generate_content(
            model=self.model_name,
            contents=prompt,
            config=config,
        )
        return self._text_or_empty(response.text)

    def _generate_openai_compatible(
        self,
        prompt: str,
        temperature: float,
        max_tokens: int,
        response_format: Optional[str],
        system_prompt: Optional[str],
    ) -> str:
        """Generate content using an OpenAI-compatible chat completions API."""
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        kwargs: Dict[str, Any] = {
            "model": self.model_name,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        if response_format == "json":
            kwargs["response_format"] = {"type": "json_object"}

        response = self.client.chat.completions.create(**kwargs)
        return self._text_or_empty(response.choices[0].message.content)

    def _text_or_empty(self, text: Optional[str]) -> str:
        """Normalize a missing (None) completion to "" so callers always get a str."""
        if text is None:
            logger.warning(
                "Provider '%s' returned no text content (model: %s)",
                self.provider_type,
                self.model_name,
            )
            return ""
        return text

    def get_base_config(
        self,
        temperature: float = TEMPERATURE_LOW,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM,
    ) -> Dict[str, Any]:
        """
        Get base configuration for AI calls.

        Args:
            temperature: Temperature setting (0.0-1.0).
            max_tokens: Maximum output tokens.

        Returns:
            Dict with "temperature" and "max_tokens" keys.
        """
        return {
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

    def get_json_config(
        self,
        schema: Optional[Dict[str, Any]] = None,
        temperature: float = TEMPERATURE_PRECISE,
        max_tokens: int = MAX_OUTPUT_TOKENS_MEDIUM,
    ) -> Dict[str, Any]:
        """
        Get configuration for JSON responses.

        Args:
            schema: JSON schema dictionary (Gemini format). Included only when
                the provider is Gemini; other providers get plain JSON mode.
            temperature: Temperature setting (default: 0.0 for precision).
            max_tokens: Maximum output tokens.

        Returns:
            Configuration dictionary with ``response_format="json"``.
        """
        config = self.get_base_config(temperature, max_tokens)
        config["response_format"] = "json"

        if schema and self.provider_type == "gemini":
            config["response_schema"] = schema

        return config

    @classmethod
    def validate_config(cls, provider: Optional[str] = None) -> bool:
        """
        Validate that the API key for the configured provider is present.

        The provider is resolved exactly as the constructor resolves it
        (``provider`` argument, then AI_PROVIDER, then DEFAULT_PROVIDER;
        unknown names fall back to Gemini), so a passing check means
        construction will not fail for a missing API key.

        Args:
            provider: Optional provider name to validate instead of AI_PROVIDER.

        Returns:
            True if configuration is valid.

        Raises:
            ValueError: If the provider's API key env var is unset or empty.
        """
        cls._require_api_key(cls._resolve_provider_name(provider))
        return True
|
|