| from abc import ABC, abstractmethod |
| from typing import Dict, Any, Optional |
| from enum import Enum |
| import asyncio |
| import logging |
| import os |
| import requests |
| import json |
|
|
| logger = logging.getLogger(__name__) |
|
|
class ModelType(Enum):
    """Categories of AI workload the gateway can serve."""

    TEXT_GENERATION = "text_generation"
    IMAGE_GENERATION = "image_generation"
    EMBEDDING = "embedding"
|
|
class ModelProvider(Enum):
    """Concrete model backends the gateway can route requests to."""

    LOCAL_LLAMA = "local_llama"
    LOCAL_MISTRAL = "local_mistral"
    API_OPENAI = "api_openai"
    LOCAL_STABLE_DIFFUSION = "local_stable_diffusion"
|
|
class BaseModel(ABC):
    """Abstract interface shared by every model the gateway hosts."""

    @abstractmethod
    async def generate(self, prompt: str, **kwargs) -> str:
        """Produce a response string for *prompt*; subclasses must implement."""
        ...
|
|
class MockLlamaModel(BaseModel):
    """Stand-in for a locally hosted Llama model; returns canned text."""

    def __init__(self, model_path: str):
        """Record the checkpoint path (unused by the mock) and log init."""
        self.model_path = model_path
        logger.info(f"Initialized mock Llama model with path: {model_path}")

    async def generate(self, prompt: str, max_length: int = 512, **kwargs) -> str:
        """Return a canned reply echoing the first 50 characters of *prompt*.

        *max_length* and extra kwargs are accepted for interface parity but
        ignored by this mock.
        """
        snippet = prompt[:50]
        return f"Mock response to: {snippet}... [Truncated for demo]"
|
|
class MockMistralModel(BaseModel):
    """Stand-in for a locally hosted Mistral model; returns canned text."""

    def __init__(self, model_path: str):
        """Record the checkpoint path (unused by the mock) and log init."""
        self.model_path = model_path
        logger.info(f"Initialized mock Mistral model with path: {model_path}")

    async def generate(self, prompt: str, max_length: int = 512, **kwargs) -> str:
        """Return a canned reply echoing the first 50 characters of *prompt*.

        *max_length* and extra kwargs are accepted for interface parity but
        ignored by this mock.
        """
        snippet = prompt[:50]
        return f"Mistral-style response to: {snippet}... [Truncated for demo]"
|
|
class MockStableDiffusionModel(BaseModel):
    """Stand-in for a local Stable Diffusion model; returns a canned string."""

    def __init__(self):
        """Log initialization; the mock holds no state."""
        logger.info("Initialized mock Stable Diffusion model")

    async def generate(self, prompt: str, **kwargs) -> str:
        """Return a canned 'image' description for the first 50 chars of *prompt*."""
        snippet = prompt[:50]
        return f"Mock image generated for prompt: {snippet}... [Truncated for demo]"
|
|
class AIGateway:
    """Routes generation requests to locally hosted (mock) models.

    Provides text generation, image generation, and a fallback router that
    tries providers in order until one succeeds.
    """

    def __init__(self):
        # Maps each available ModelProvider to its initialized model instance.
        self.models: Dict[ModelProvider, BaseModel] = {}
        self._initialize_models()

    def _initialize_models(self):
        """Initialize available models"""
        try:
            self.models[ModelProvider.LOCAL_LLAMA] = MockLlamaModel("llama-model-path")
            self.models[ModelProvider.LOCAL_MISTRAL] = MockMistralModel("mistral-model-path")
            self.models[ModelProvider.LOCAL_STABLE_DIFFUSION] = MockStableDiffusionModel()
            logger.info("AI Gateway initialized with mock models")
        except Exception as e:
            # Best-effort startup: on failure self.models may be partially
            # populated; per-request availability checks surface the gaps.
            logger.error(f"Error initializing models: {e}")

    def _get_model(self, provider: ModelProvider) -> BaseModel:
        """Return the model for *provider*.

        Raises:
            ValueError: if the provider was never initialized.
        """
        if provider not in self.models:
            raise ValueError(f"Model provider {provider} not available")
        return self.models[provider]

    async def generate_text(
        self,
        prompt: str,
        provider: ModelProvider = ModelProvider.LOCAL_LLAMA,
        **kwargs
    ) -> str:
        """
        Generate text using the specified provider.

        Raises:
            ValueError: if *provider* is not available.
        """
        model = self._get_model(provider)
        logger.info(f"Generating text using {provider.value}")

        try:
            result = await model.generate(prompt, **kwargs)
            logger.info(f"Generated {len(result)} characters")
            return result
        except Exception as e:
            logger.error(f"Error generating text: {e}")
            raise

    async def generate_image(
        self,
        prompt: str,
        **kwargs
    ) -> str:
        """
        Generate image using the image generation model.

        Raises:
            ValueError: if the image model is not available (previously this
            surfaced as a raw KeyError; now consistent with generate_text).
        """
        model = self._get_model(ModelProvider.LOCAL_STABLE_DIFFUSION)
        logger.info("Generating image")

        try:
            result = await model.generate(prompt, **kwargs)
            logger.info("Image generated successfully")
            return result
        except Exception as e:
            logger.error(f"Error generating image: {e}")
            raise

    async def route_request(
        self,
        prompt: str,
        preferred_provider: Optional[ModelProvider] = None,
        fallback_providers: Optional[list] = None
    ) -> str:
        """
        Route request with fallback mechanism.

        Tries *preferred_provider* first (if given), then *fallback_providers*
        or the default text providers. Each provider is attempted at most once.

        Raises:
            RuntimeError: if every candidate provider fails; chained to the
            last underlying error for debuggability.
        """
        providers_to_try = []

        if preferred_provider:
            providers_to_try.append(preferred_provider)

        if fallback_providers:
            providers_to_try.extend(fallback_providers)
        else:
            providers_to_try.extend([
                ModelProvider.LOCAL_LLAMA,
                ModelProvider.LOCAL_MISTRAL
            ])

        # Deduplicate while preserving order: a failed preferred provider
        # must not be retried when it also appears in the fallback list.
        last_error: Optional[Exception] = None
        for provider in dict.fromkeys(providers_to_try):
            try:
                return await self.generate_text(prompt, provider=provider)
            except Exception as e:
                last_error = e
                logger.warning(f"Provider {provider.value} failed: {e}")

        raise RuntimeError("All providers failed") from last_error