| """Model management and text generation service.""" | |
| import hashlib | |
| import time | |
| from functools import lru_cache | |
| from typing import Any, Dict, Optional | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
| from writing_studio.core.config import settings | |
| from writing_studio.core.exceptions import ModelLoadError, TextGenerationError | |
| from writing_studio.utils.logging import logger | |
| from writing_studio.utils.validation import validate_generation_params, validate_model_name | |


class ModelService:
    """Service for managing language models and text generation."""

    def __init__(self):
        """Initialize the model service."""
        self._current_model: Optional[Any] = None
        self._current_model_name: Optional[str] = None
        self._task_type: str = "text2text-generation"  # Default for FLAN-T5
        self._cache: Dict[str, Any] = {}
        self._load_default_model()

    def _load_default_model(self) -> None:
        """Load the default model at initialization."""
        try:
            logger.info(f"Loading default model: {settings.default_model}")
            self.load_model(settings.default_model)
        except Exception as e:
            logger.error(f"Failed to load default model: {e}")
            raise ModelLoadError(
                f"Failed to load default model: {settings.default_model}",
                {"error": str(e)},
            )

    def load_model(self, model_name: str) -> None:
        """
        Load a language model from HuggingFace.

        Args:
            model_name: HuggingFace model identifier

        Raises:
            ModelLoadError: If model loading fails
        """
        try:
            # Validate model name
            model_name = validate_model_name(model_name)

            # Skip reloading if the requested model is already active
            if self._current_model_name == model_name:
                logger.debug(f"Model {model_name} already loaded")
                return

            logger.info(f"Loading model: {model_name}")
            start_time = time.time()

            # Pick the pipeline task from the model family:
            # encoder-decoder models (T5, FLAN-T5) use text2text-generation,
            # decoder-only models (GPT-2 style) use text-generation
            if any(x in model_name.lower() for x in ("t5", "flan")):
                task = "text2text-generation"
                logger.info(f"Detected instruction-following model, using {task} pipeline")
            else:
                task = "text-generation"
                logger.info(f"Detected text generation model, using {task} pipeline")

            self._current_model = pipeline(
                task,
                model=model_name,
                max_length=settings.max_model_length,
            )
            self._current_model_name = model_name
            self._task_type = task

            load_time = time.time() - start_time
            logger.info(f"Model loaded successfully in {load_time:.2f}s: {model_name}")
        except Exception as e:
            logger.error(f"Failed to load model {model_name}: {e}")
            raise ModelLoadError(
                f"Failed to load model: {model_name}", {"error": str(e)}
            )

    def generate_text(
        self,
        prompt: str,
        max_length: Optional[int] = None,
        num_sequences: Optional[int] = None,
        temperature: float = 1.0,
        use_cache: bool = True,
    ) -> str:
        """
        Generate text using the loaded model.

        Args:
            prompt: Input prompt for generation
            max_length: Maximum generation length
            num_sequences: Number of sequences to generate
            temperature: Sampling temperature
            use_cache: Whether to use caching

        Returns:
            Generated text

        Raises:
            TextGenerationError: If generation fails
        """
        if self._current_model is None:
            raise TextGenerationError("No model loaded")

        # Fall back to configured defaults when parameters are omitted
        max_length = max_length or settings.default_max_length
        num_sequences = num_sequences or settings.default_num_sequences

        # Validate parameters
        params = validate_generation_params(max_length, num_sequences, temperature)

        # Compute the cache key once so it is always defined when the
        # result is stored after generation
        cache_key = None
        if use_cache and settings.enable_cache:
            cache_key = self._get_cache_key(prompt, params)
            if cache_key in self._cache:
                logger.debug("Returning cached result")
                return self._cache[cache_key]

        try:
            logger.info(f"Generating text with model: {self._current_model_name}")
            start_time = time.time()

            # Generate text with parameters appropriate for model type
            if self._task_type == "text2text-generation":
                # T5/FLAN-T5: max_new_tokens counts generated tokens only,
                # so the prompt does not eat into the budget
                result = self._current_model(
                    prompt,
                    max_new_tokens=params["max_length"],
                    num_return_sequences=params["num_sequences"],
                    do_sample=True,
                    temperature=params["temperature"],
                    truncation=True,
                )
            else:
                # GPT-2 style models: max_length includes the prompt tokens;
                # pad_token_id silences the missing-pad-token warning
                result = self._current_model(
                    prompt,
                    max_length=params["max_length"],
                    num_return_sequences=params["num_sequences"],
                    do_sample=True,
                    temperature=params["temperature"],
                    pad_token_id=self._current_model.tokenizer.eos_token_id,
                )

            # Only the first returned sequence is used, even when
            # num_return_sequences > 1
            generated_text = result[0]["generated_text"]

            generation_time = time.time() - start_time
            logger.info(f"Text generated in {generation_time:.2f}s")

            # Cache result if enabled
            if cache_key is not None:
                self._cache_result(cache_key, generated_text)

            return generated_text
        except Exception as e:
            logger.error(f"Text generation failed: {e}")
            raise TextGenerationError("Text generation failed", {"error": str(e)})

    def _get_cache_key(self, prompt: str, params: dict) -> str:
        """
        Generate cache key for prompt and parameters.

        Args:
            prompt: Input prompt
            params: Generation parameters

        Returns:
            Cache key hash
        """
        key_str = f"{prompt}:{params['max_length']}:{params['num_sequences']}:{params['temperature']}"
        return hashlib.sha256(key_str.encode()).hexdigest()

    def _cache_result(self, key: str, result: str) -> None:
        """
        Cache generation result with size limit.

        Args:
            key: Cache key
            result: Result to cache
        """
        if len(self._cache) >= settings.cache_max_size:
            # Evict the oldest entry: dicts preserve insertion order
            # (Python 3.7+), so this is a simple FIFO
            self._cache.pop(next(iter(self._cache)))
        self._cache[key] = result

    def clear_cache(self) -> None:
        """Clear the generation cache."""
        self._cache.clear()
        logger.info("Generation cache cleared")

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the currently loaded model.

        Returns:
            Model information dictionary
        """
        return {
            "model_name": self._current_model_name,
            "task_type": self._task_type,
            "cache_size": len(self._cache),
            "cache_enabled": settings.enable_cache,
        }


# Global model service instance, cached so every caller shares one model
@lru_cache(maxsize=1)
def get_model_service() -> ModelService:
    """Get the global model service instance."""
    return ModelService()
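

# A minimal usage sketch, not part of the service API: it assumes
# `settings.default_model` points at a small checkpoint such as
# "google/flan-t5-small" and that this module is importable as part of
# the writing_studio package. The prompt and parameters below are
# illustrative only.
if __name__ == "__main__":
    service = get_model_service()
    print(service.get_model_info())

    # Defaults from settings apply when parameters are omitted
    text = service.generate_text(
        "Write a haiku about autumn leaves.",
        max_length=64,
        temperature=0.9,
    )
    print(text)

    # lru_cache makes get_model_service() return the same instance
    assert get_model_service() is service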