| |
| from abc import ABC, abstractmethod |
| from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple |
| from dataclasses import dataclass |
| from logging import getLogger |
| from services.model_manager import ModelManager |
| from services.cache import ResponseCache |
| from services.batch_processor import BatchProcessor |
| from services.health_check import HealthCheck |
|
|
| from config.config import GenerationConfig, ModelConfig |
|
|
class BaseGenerator(ABC):
    """Abstract base class shared by every generator implementation.

    The initializer wires up the collaborating services (model manager,
    response cache, batch processor, health check) and resolves the default
    configuration objects. Concrete subclasses must implement
    ``generate_stream``, ``_get_generation_kwargs`` and ``generate``.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional[GenerationConfig] = None,
        model_config: Optional[ModelConfig] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32
    ):
        """Create the shared services and fall back to default configs.

        Args:
            model_name: Identifier of the model. NOTE(review): this base
                initializer neither reads nor stores it; subclasses appear to
                be responsible for using it -- confirm.
            device: Passed positionally to ``ModelManager``.
            default_generation_config: Stored as ``self.default_config``;
                replaced by a fresh ``GenerationConfig()`` when falsy
                (normally when it is None).
            model_config: Stored as ``self.model_config``; replaced by a
                fresh ``ModelConfig()`` when falsy.
            cache_size: Passed positionally to ``ResponseCache``.
            max_batch_size: Passed positionally to ``BatchProcessor``.
        """
        self.logger = getLogger(__name__)

        # Construction order matches the original in case any of these
        # service constructors have order-dependent side effects.
        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()

        # Truthiness-based fallback: any falsy argument is swapped for a
        # default-constructed config object.
        self.default_config = (
            default_generation_config
            if default_generation_config
            else GenerationConfig()
        )
        self.model_config = model_config if model_config else ModelConfig()

    @abstractmethod
    async def generate_stream(
        self,
        prompt: str,
        config: Optional[GenerationConfig] = None
    ) -> AsyncGenerator[str, None]:
        """Asynchronously produce generated text for ``prompt`` as str chunks.

        The return annotation is ``AsyncGenerator[str, None]``, so
        implementations are expected to be async generators that ``yield``
        strings.

        Args:
            prompt: Input text to generate from.
            config: Optional per-call generation settings. NOTE(review):
                presumably implementations fall back to
                ``self.default_config`` when this is None -- confirm in
                subclasses.
        """

    @abstractmethod
    def _get_generation_kwargs(self, config: GenerationConfig) -> Dict[str, Any]:
        """Return a dict of keyword arguments derived from ``config``.

        NOTE(review): presumably these kwargs are forwarded to the underlying
        model's generation call -- confirm against implementations.
        """

    @abstractmethod
    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """Synchronously generate and return the complete text for ``prompt``.

        Args:
            prompt: Input text to generate from.
            model_kwargs: Model-specific keyword arguments supplied by the
                caller.
            strategy: Name of the generation strategy; ``"default"`` unless
                overridden. The set of valid names is defined by subclasses.
            **kwargs: Extra options whose meaning is implementation-defined.
        """