Spaces:
Build error
Build error
| # base_generator.py | |
| from abc import ABC, abstractmethod | |
| from typing import AsyncGenerator, Dict, Any, Optional, List, Tuple | |
| from dataclasses import dataclass | |
| from logging import getLogger | |
| from services.model_manager import ModelManager | |
| from services.cache import ResponseCache | |
| from services.batch_processor import BatchProcessor | |
| from services.health_check import HealthCheck | |
| from config.config import GenerationConfig, ModelConfig | |
class BaseGenerator(ABC):
    """Base class for all generator implementations.

    Construction wires up the shared collaborators (model manager, response
    cache, batch processor and health check) and resolves the generation and
    model configuration, falling back to default-constructed configs when
    none are supplied.

    The generation hooks (``generate_stream``, ``_get_generation_kwargs`` and
    ``generate``) are placeholders that return ``None``. They are not marked
    ``@abstractmethod``, so this class stays instantiable; concrete
    generators are presumably expected to override them -- TODO confirm
    against the subclasses.
    """

    def __init__(
        self,
        model_name: str,
        device: Optional[str] = None,
        default_generation_config: Optional["GenerationConfig"] = None,
        model_config: Optional["ModelConfig"] = None,
        cache_size: int = 1000,
        max_batch_size: int = 32,
    ):
        """Set up the shared services and configuration.

        Args:
            model_name: Identifier of the model this generator is built for.
                Stored on the instance as ``self.model_name``.
            device: Forwarded unchanged to ``ModelManager``.
            default_generation_config: Generation settings used as
                ``self.default_config``; a fresh ``GenerationConfig()`` is
                used when this is ``None`` (or otherwise falsy).
            model_config: Model settings used as ``self.model_config``; a
                fresh ``ModelConfig()`` is used when this is ``None`` (or
                otherwise falsy).
            cache_size: Forwarded unchanged to ``ResponseCache``.
            max_batch_size: Forwarded unchanged to ``BatchProcessor``.
        """
        self.logger = getLogger(__name__)

        # Keep the model name: it used to be accepted and then dropped, which
        # left subclasses and log messages with no way to recover it.
        self.model_name = model_name

        self.model_manager = ModelManager(device)
        self.cache = ResponseCache(cache_size)
        self.batch_processor = BatchProcessor(max_batch_size)
        self.health_check = HealthCheck()

        # NOTE: no tokenizer is initialised here. An earlier lookup through
        # ``self.model_manager.tokenizers[model_name]`` was disabled, so
        # ``self.tokenizer`` does not exist unless a subclass assigns it.

        self.default_config = default_generation_config or GenerationConfig()
        self.model_config = model_config or ModelConfig()

    async def generate_stream(
        self,
        prompt: str,
        config: Optional["GenerationConfig"] = None,
    ) -> AsyncGenerator[str, None]:
        """Placeholder for streaming generation.

        NOTE(review): this stub contains no ``yield``, so calling it yields a
        coroutine that resolves to ``None`` rather than the async generator
        the return annotation advertises. Overrides presumably implement it
        as an ``async def`` that yields text chunks -- TODO confirm.

        Args:
            prompt: Input text to generate from.
            config: Optional per-call generation settings.
        """
        return None

    def _get_generation_kwargs(self, config: "GenerationConfig") -> Dict[str, Any]:
        """Placeholder for turning ``config`` into generation keyword arguments.

        The base implementation does nothing and returns ``None``.
        """
        return None

    def generate(
        self,
        prompt: str,
        model_kwargs: Dict[str, Any],
        strategy: str = "default",
        **kwargs
    ) -> str:
        """Placeholder for non-streaming generation.

        The base implementation ignores all arguments and returns ``None``.

        Args:
            prompt: Input text to generate from.
            model_kwargs: Keyword arguments intended for the model call.
            strategy: Name of the generation strategy; defaults to
                ``"default"``.
            **kwargs: Additional options, accepted and ignored here.
        """
        return None