"""Model management and text generation service.""" import hashlib import time from functools import lru_cache from typing import Any, Dict, Optional from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM from writing_studio.core.config import settings from writing_studio.core.exceptions import ModelLoadError, TextGenerationError from writing_studio.utils.logging import logger from writing_studio.utils.validation import validate_generation_params, validate_model_name class ModelService: """Service for managing language models and text generation.""" def __init__(self): """Initialize the model service.""" self._current_model: Optional[Any] = None self._current_model_name: Optional[str] = None self._task_type: str = "text2text-generation" # Default for FLAN-T5 self._cache: Dict[str, Any] = {} self._load_default_model() def _load_default_model(self) -> None: """Load the default model at initialization.""" try: logger.info(f"Loading default model: {settings.default_model}") self.load_model(settings.default_model) except Exception as e: logger.error(f"Failed to load default model: {e}") raise ModelLoadError( f"Failed to load default model: {settings.default_model}", {"error": str(e)}, ) def load_model(self, model_name: str) -> None: """ Load a language model from HuggingFace. Args: model_name: HuggingFace model identifier Raises: ModelLoadError: If model loading fails """ try: # Validate model name model_name = validate_model_name(model_name) # Check if already loaded if self._current_model_name == model_name: logger.debug(f"Model {model_name} already loaded") return logger.info(f"Loading model: {model_name}") start_time = time.time() # Detect model type and use appropriate pipeline # FLAN-T5, T5 = text2text-generation # GPT-2, GPT = text-generation if any(x in model_name.lower() for x in ['t5', 'flan']): task = "text2text-generation" logger.info(f"Detected instruction-following model, using {task} pipeline") else: task = "text-generation" logger.info(f"Detected text generation model, using {task} pipeline") # Load model with error handling self._current_model = pipeline( task, model=model_name, max_length=settings.max_model_length, ) self._current_model_name = model_name self._task_type = task load_time = time.time() - start_time logger.info(f"Model loaded successfully in {load_time:.2f}s: {model_name}") except Exception as e: logger.error(f"Failed to load model {model_name}: {e}") raise ModelLoadError( f"Failed to load model: {model_name}", {"error": str(e)} ) def generate_text( self, prompt: str, max_length: Optional[int] = None, num_sequences: Optional[int] = None, temperature: float = 1.0, use_cache: bool = True, ) -> str: """ Generate text using the loaded model. 
    def generate_text(
        self,
        prompt: str,
        max_length: Optional[int] = None,
        num_sequences: Optional[int] = None,
        temperature: float = 1.0,
        use_cache: bool = True,
    ) -> str:
        """
        Generate text using the loaded model.

        Args:
            prompt: Input prompt for generation
            max_length: Maximum generation length
            num_sequences: Number of sequences to generate
            temperature: Sampling temperature
            use_cache: Whether to use caching

        Returns:
            Generated text

        Raises:
            TextGenerationError: If generation fails
        """
        if self._current_model is None:
            raise TextGenerationError("No model loaded")

        # Fall back to configured defaults when parameters are omitted
        max_length = max_length or settings.default_max_length
        num_sequences = num_sequences or settings.default_num_sequences

        params = validate_generation_params(max_length, num_sequences, temperature)

        # Check the cache if enabled
        if use_cache and settings.enable_cache:
            cache_key = self._get_cache_key(prompt, params)
            if cache_key in self._cache:
                logger.debug("Returning cached result")
                return self._cache[cache_key]

        try:
            logger.info(f"Generating text with model: {self._current_model_name}")
            start_time = time.time()

            # Generate with parameters appropriate for the model type
            if self._task_type == "text2text-generation":
                # T5/FLAN-T5 models
                result = self._current_model(
                    prompt,
                    max_new_tokens=params["max_length"],
                    num_return_sequences=params["num_sequences"],
                    do_sample=True,
                    temperature=params["temperature"],
                    truncation=True,
                )
            else:
                # GPT-2 style models need an explicit pad token
                result = self._current_model(
                    prompt,
                    max_length=params["max_length"],
                    num_return_sequences=params["num_sequences"],
                    do_sample=True,
                    temperature=params["temperature"],
                    pad_token_id=self._current_model.tokenizer.eos_token_id,
                )
            # Both pipelines return a list of dicts with a "generated_text" field
            generated_text = result[0]["generated_text"]

            generation_time = time.time() - start_time
            logger.info(f"Text generated in {generation_time:.2f}s")

            # Cache the result if enabled
            if use_cache and settings.enable_cache:
                self._cache_result(cache_key, generated_text)

            return generated_text
        except Exception as e:
            logger.error(f"Text generation failed: {e}")
            raise TextGenerationError(
                "Text generation failed", {"error": str(e)}
            ) from e

    def _get_cache_key(self, prompt: str, params: dict) -> str:
        """
        Generate a cache key from the current model, prompt, and parameters.

        Args:
            prompt: Input prompt
            params: Generation parameters

        Returns:
            Cache key hash
        """
        # Include the model name so a result cached under one model is never
        # served after a different model has been loaded.
        key_str = (
            f"{self._current_model_name}:{prompt}:{params['max_length']}:"
            f"{params['num_sequences']}:{params['temperature']}"
        )
        return hashlib.sha256(key_str.encode()).hexdigest()

    def _cache_result(self, key: str, result: str) -> None:
        """
        Cache a generation result with a size limit.

        Args:
            key: Cache key
            result: Result to cache
        """
        if len(self._cache) >= settings.cache_max_size:
            # Evict the oldest entry (dicts preserve insertion order, so this is FIFO)
            self._cache.pop(next(iter(self._cache)))
        self._cache[key] = result

    def clear_cache(self) -> None:
        """Clear the generation cache."""
        self._cache.clear()
        logger.info("Generation cache cleared")

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about the currently loaded model.

        Returns:
            Model information dictionary
        """
        return {
            "model_name": self._current_model_name,
            "cache_size": len(self._cache),
            "cache_enabled": settings.enable_cache,
        }


# Global model service instance
@lru_cache(maxsize=1)
def get_model_service() -> ModelService:
    """Get the global model service instance."""
    return ModelService()
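

# Example usage: a minimal sketch for local experimentation, assuming the
# writing_studio package is importable and `settings.default_model` points at
# a small checkpoint (e.g. "google/flan-t5-small") so the first load is quick.
if __name__ == "__main__":
    service = get_model_service()
    print(service.get_model_info())

    prompt = "Summarize: The quick brown fox jumps over the lazy dog."
    first = service.generate_text(prompt, max_length=64, temperature=0.7)
    print(first)

    # Sampling is stochastic, so an identical second result indicates the
    # FIFO cache served it rather than the model.
    second = service.generate_text(prompt, max_length=64, temperature=0.7)
    if settings.enable_cache:
        assert first == second

    service.clear_cache()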