"""Model management and text generation service."""
import hashlib
import time
from functools import lru_cache
from typing import Any, Dict, Optional
from transformers import pipeline
from writing_studio.core.config import settings
from writing_studio.core.exceptions import ModelLoadError, TextGenerationError
from writing_studio.utils.logging import logger
from writing_studio.utils.validation import validate_generation_params, validate_model_name
class ModelService:
"""Service for managing language models and text generation."""
def __init__(self):
"""Initialize the model service."""
self._current_model: Optional[Any] = None
self._current_model_name: Optional[str] = None
self._task_type: str = "text2text-generation" # Default for FLAN-T5
self._cache: Dict[str, Any] = {}
self._load_default_model()
def _load_default_model(self) -> None:
"""Load the default model at initialization."""
try:
logger.info(f"Loading default model: {settings.default_model}")
self.load_model(settings.default_model)
except Exception as e:
logger.error(f"Failed to load default model: {e}")
raise ModelLoadError(
f"Failed to load default model: {settings.default_model}",
{"error": str(e)},
)
def load_model(self, model_name: str) -> None:
"""
Load a language model from HuggingFace.
Args:
model_name: HuggingFace model identifier
Raises:
ModelLoadError: If model loading fails
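
        Example (illustrative; downloads weights on first use):
            >>> get_model_service().load_model("google/flan-t5-base")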
"""
try:
# Validate model name
model_name = validate_model_name(model_name)
# Check if already loaded
if self._current_model_name == model_name:
logger.debug(f"Model {model_name} already loaded")
return
logger.info(f"Loading model: {model_name}")
start_time = time.time()
            # Pick the pipeline task from the model name (a simple heuristic):
            # T5/FLAN-T5 -> text2text-generation, GPT-style -> text-generation
            if any(x in model_name.lower() for x in ['t5', 'flan']):
                task = "text2text-generation"
                logger.info(f"Detected seq2seq model, using {task} pipeline")
            else:
                task = "text-generation"
                logger.info(f"Detected causal LM, using {task} pipeline")
# Load model with error handling
self._current_model = pipeline(
task,
model=model_name,
max_length=settings.max_model_length,
)
self._current_model_name = model_name
self._task_type = task
load_time = time.time() - start_time
logger.info(f"Model loaded successfully in {load_time:.2f}s: {model_name}")
except Exception as e:
logger.error(f"Failed to load model {model_name}: {e}")
raise ModelLoadError(
f"Failed to load model: {model_name}", {"error": str(e)}
)
def generate_text(
self,
prompt: str,
max_length: Optional[int] = None,
num_sequences: Optional[int] = None,
temperature: float = 1.0,
use_cache: bool = True,
) -> str:
"""
Generate text using the loaded model.
Args:
prompt: Input prompt for generation
max_length: Maximum generation length
num_sequences: Number of sequences to generate
temperature: Sampling temperature
use_cache: Whether to use caching
        Returns:
            Generated text (only the first sequence is returned, even if
            num_sequences > 1)
Raises:
TextGenerationError: If generation fails
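
        Example (illustrative; output varies with the loaded model):
            >>> service = get_model_service()
            >>> service.generate_text("Rewrite formally: hey, what's up?", max_length=64)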
"""
if self._current_model is None:
raise TextGenerationError("No model loaded")
# Use defaults if not provided
max_length = max_length or settings.default_max_length
num_sequences = num_sequences or settings.default_num_sequences
# Validate parameters
params = validate_generation_params(max_length, num_sequences, temperature)
# Check cache if enabled
if use_cache and settings.enable_cache:
cache_key = self._get_cache_key(prompt, params)
if cache_key in self._cache:
logger.debug("Returning cached result")
return self._cache[cache_key]
try:
logger.info(f"Generating text with model: {self._current_model_name}")
start_time = time.time()
# Generate text with parameters appropriate for model type
if self._task_type == "text2text-generation":
# T5/FLAN-T5 models
result = self._current_model(
prompt,
max_new_tokens=params["max_length"],
num_return_sequences=params["num_sequences"],
do_sample=True,
temperature=params["temperature"],
truncation=True,
)
                # text2text pipelines return only the newly generated text
                generated_text = result[0]["generated_text"]
            else:
                # Causal models (GPT-2 style); max_length here counts the
                # prompt tokens as well, unlike max_new_tokens above
                result = self._current_model(
                    prompt,
                    max_length=params["max_length"],
                    num_return_sequences=params["num_sequences"],
                    do_sample=True,
                    temperature=params["temperature"],
                    pad_token_id=self._current_model.tokenizer.eos_token_id,
                )
                # text-generation pipelines echo the prompt at the start
                # of the output text
                generated_text = result[0]["generated_text"]
generation_time = time.time() - start_time
logger.info(f"Text generated in {generation_time:.2f}s")
# Cache result if enabled
if use_cache and settings.enable_cache:
self._cache_result(cache_key, generated_text)
return generated_text
except Exception as e:
logger.error(f"Text generation failed: {e}")
raise TextGenerationError("Text generation failed", {"error": str(e)})
def _get_cache_key(self, prompt: str, params: dict) -> str:
"""
Generate cache key for prompt and parameters.
Args:
prompt: Input prompt
params: Generation parameters
Returns:
Cache key hash
"""
        # Use "\x1f" as separator so prompts containing ":" cannot collide
        # with the parameter fields
        parts = [prompt, str(params["max_length"]), str(params["num_sequences"]), str(params["temperature"])]
        return hashlib.sha256("\x1f".join(parts).encode()).hexdigest()
def _cache_result(self, key: str, result: str) -> None:
"""
Cache generation result with size limit.
Args:
key: Cache key
result: Result to cache
"""
if len(self._cache) >= settings.cache_max_size:
            # dicts preserve insertion order (Python 3.7+), so the first key
            # is the oldest entry; evict it (simple FIFO)
            self._cache.pop(next(iter(self._cache)))
self._cache[key] = result
def clear_cache(self) -> None:
"""Clear the generation cache."""
self._cache.clear()
logger.info("Generation cache cleared")
def get_model_info(self) -> Dict[str, Any]:
"""
Get information about the currently loaded model.
Returns:
Model information dictionary
"""
return {
"model_name": self._current_model_name,
"cache_size": len(self._cache),
"cache_enabled": settings.enable_cache,
}
# Global model service instance
@lru_cache(maxsize=1)
def get_model_service() -> ModelService:
"""Get the global model service instance."""
return ModelService()
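

# Minimal manual smoke test (a sketch, not part of the service). Assumes
# settings.default_model points at a small checkpoint such as
# google/flan-t5-small; the first run downloads the weights.
if __name__ == "__main__":
    service = get_model_service()
    print(service.get_model_info())
    # A moderate temperature keeps sampling close to greedy decoding
    print(
        service.generate_text(
            "Rewrite this sentence in a formal tone: hey, what's up?",
            max_length=64,
            temperature=0.7,
        )
    )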