# Source path: auranexus/core/ai_gateway.py
# Hosting-page header residue: "Ahmed766's picture"
# Commit message: "Upload core/ai_gateway.py with huggingface_hub"
# Commit: 0e0c403 (verified)
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional
from enum import Enum
import asyncio
import logging
import os
import requests
import json
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ModelType(Enum):
    """Kinds of model output; each member's value is its lowercase string tag."""

    # Text generation.
    TEXT_GENERATION = "text_generation"
    # Image generation.
    IMAGE_GENERATION = "image_generation"
    # Embeddings.
    EMBEDDING = "embedding"
class ModelProvider(Enum):
    """Identifiers for the model back-ends; each value is a lowercase string tag."""

    # Locally hosted text models.
    LOCAL_LLAMA = "local_llama"
    LOCAL_MISTRAL = "local_mistral"
    # Remote API back-end.
    API_OPENAI = "api_openai"
    # Locally hosted image model.
    LOCAL_STABLE_DIFFUSION = "local_stable_diffusion"
class BaseModel(ABC):
    """Abstract interface that every model back-end must implement."""

    @abstractmethod
    async def generate(self, prompt: str, **kwargs) -> str:
        """Produce the model's output for ``prompt``.

        Subclasses must override this coroutine; the base body does nothing
        and returns ``None``.
        """
class MockLlamaModel(BaseModel):
    """Demo stand-in for a Llama model; returns a canned reply."""

    def __init__(self, model_path: str):
        """Store ``model_path`` and log that the mock was created."""
        self.model_path = model_path
        logger.info(f"Initialized mock Llama model with path: {self.model_path}")

    async def generate(self, prompt: str, max_length: int = 512, **kwargs) -> str:
        """Return a fixed-format reply quoting the first 50 characters of ``prompt``.

        ``max_length`` and any extra keyword arguments are accepted but unused.
        """
        # Simulated response for demonstration; no model is invoked.
        preview = prompt[:50]
        return f"Mock response to: {preview}... [Truncated for demo]"
class MockMistralModel(BaseModel):
    """Demo stand-in for a Mistral model; returns a canned reply."""

    def __init__(self, model_path: str):
        """Store ``model_path`` and log that the mock was created."""
        self.model_path = model_path
        logger.info(f"Initialized mock Mistral model with path: {self.model_path}")

    async def generate(self, prompt: str, max_length: int = 512, **kwargs) -> str:
        """Return a fixed-format reply quoting the first 50 characters of ``prompt``.

        ``max_length`` and any extra keyword arguments are accepted but unused.
        """
        # Simulated response for demonstration; no model is invoked.
        preview = prompt[:50]
        return f"Mistral-style response to: {preview}... [Truncated for demo]"
class MockStableDiffusionModel(BaseModel):
    """Demo stand-in for a Stable Diffusion model; returns a descriptive string."""

    def __init__(self):
        """Log that the mock was created; no state is stored."""
        logger.info("Initialized mock Stable Diffusion model")

    async def generate(self, prompt: str, **kwargs) -> str:
        """Return a fixed-format message quoting the first 50 characters of ``prompt``.

        Extra keyword arguments are accepted but unused.
        """
        # Simulated image generation for demonstration; no image is produced.
        preview = prompt[:50]
        return f"Mock image generated for prompt: {preview}... [Truncated for demo]"
class AIGateway:
    """Registry of model back-ends with text/image generation and fallback routing.

    ``self.models`` maps a ``ModelProvider`` member to the model instance
    registered for it.
    """

    def __init__(self):
        """Create an empty registry and register the available models."""
        self.models = {}
        self._initialize_models()

    def _initialize_models(self):
        """Initialize available models.

        Registration is best-effort: any exception is logged and swallowed,
        so ``self.models`` may be left partially populated.
        """
        try:
            # In a real implementation, we would load actual models
            # For this demo, we'll use mock implementations
            self.models[ModelProvider.LOCAL_LLAMA] = MockLlamaModel("llama-model-path")
            self.models[ModelProvider.LOCAL_MISTRAL] = MockMistralModel("mistral-model-path")
            self.models[ModelProvider.LOCAL_STABLE_DIFFUSION] = MockStableDiffusionModel()
            logger.info("AI Gateway initialized with mock models")
        except Exception as e:
            logger.error(f"Error initializing models: {e}")

    async def generate_text(
        self,
        prompt: str,
        provider: ModelProvider = ModelProvider.LOCAL_LLAMA,
        **kwargs
    ) -> str:
        """Generate text using the specified provider.

        Args:
            prompt: Text passed to the model's ``generate``.
            provider: Key into ``self.models`` selecting the back-end.
            **kwargs: Forwarded unchanged to the model's ``generate``.

        Returns:
            The string returned by the model.

        Raises:
            ValueError: If ``provider`` is not registered in ``self.models``.
            Exception: Any error raised by the model is logged and re-raised.
        """
        if provider not in self.models:
            raise ValueError(f"Model provider {provider} not available")
        model = self.models[provider]
        logger.info(f"Generating text using {provider.value}")
        try:
            result = await model.generate(prompt, **kwargs)
            logger.info(f"Generated {len(result)} characters")
            return result
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            raise

    async def generate_image(
        self,
        prompt: str,
        **kwargs
    ) -> str:
        """Generate an image using the image generation model.

        Args:
            prompt: Text passed to the model's ``generate``.
            **kwargs: Forwarded unchanged to the model's ``generate``.

        Returns:
            The string returned by the model.

        Raises:
            KeyError: If ``ModelProvider.LOCAL_STABLE_DIFFUSION`` is not
                registered in ``self.models``.
            Exception: Any error raised by the model is logged and re-raised.
        """
        model = self.models[ModelProvider.LOCAL_STABLE_DIFFUSION]
        logger.info("Generating image")
        try:
            result = await model.generate(prompt, **kwargs)
            logger.info("Image generated successfully")
            return result
        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    async def route_request(
        self,
        prompt: str,
        preferred_provider: Optional[ModelProvider] = None,
        fallback_providers: Optional[list] = None,
        **kwargs
    ) -> str:
        """Route a text request through providers until one succeeds.

        Providers are tried in order: ``preferred_provider`` (if given), then
        ``fallback_providers`` exactly as supplied, or, when no fallbacks are
        supplied, the default order LOCAL_LLAMA then LOCAL_MISTRAL.

        Args:
            prompt: Text passed to ``generate_text``.
            preferred_provider: Provider to try first, if any.
            fallback_providers: Providers to try after the preferred one.
                ``None`` or an empty list selects the default order.
            **kwargs: Forwarded unchanged to ``generate_text``.

        Returns:
            The first successful result from ``generate_text``.

        Raises:
            RuntimeError: If every provider fails; the last provider's error
                is attached as ``__cause__``.
        """
        providers_to_try = []
        if preferred_provider is not None:
            providers_to_try.append(preferred_provider)
        if fallback_providers:
            # An explicit list is honoured verbatim, including repeats.
            providers_to_try.extend(fallback_providers)
        else:
            # Default fallback order. The preferred provider is skipped here
            # so it is not immediately retried after it has just failed.
            default_order = (
                ModelProvider.LOCAL_LLAMA,
                ModelProvider.LOCAL_MISTRAL,
            )
            providers_to_try.extend(
                candidate
                for candidate in default_order
                if candidate != preferred_provider
            )

        last_error = None
        for provider in providers_to_try:
            try:
                return await self.generate_text(prompt, provider, **kwargs)
            except Exception as e:
                last_error = e
                # getattr keeps a non-enum entry from raising AttributeError
                # here, which would abort the remaining fallbacks.
                provider_name = getattr(provider, "value", provider)
                logger.warning(f"Provider {provider_name} failed: {e}")
        raise RuntimeError("All providers failed") from last_error