Spaces:
Sleeping
Sleeping
| """ | |
| Model initialization and inference logic for image generation. | |
| This module handles loading the diffusion model and provides functions | |
| for generating images from text prompts with error handling. | |
| """ | |
import logging
import random
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DiffusionPipeline
from PIL import Image

from config import MODEL_REPO_ID, MAX_SEED
# Logging setup: emit INFO-and-above records as "timestamp - level - message".
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-scoped logger used by everything defined in this file.
logger = logging.getLogger(__name__)
class ModelManager:
    """Manages the diffusion model for image generation.

    Usage is two-step: construct the manager (cheap, no model I/O), call
    ``load_model()`` once, then call ``generate_image()`` as often as needed.
    """

    def __init__(self):
        """Pick the device and dtype; the model itself is NOT loaded here.

        CUDA with float16 is used when a GPU is available, otherwise CPU with
        float32. ``self.pipe`` stays ``None`` until ``load_model()`` succeeds.
        """
        cuda_available = torch.cuda.is_available()
        self.device = "cuda" if cuda_available else "cpu"
        self.torch_dtype = torch.float16 if cuda_available else torch.float32
        self.pipe = None

    def load_model(self) -> None:
        """Load the diffusion pipeline from ``MODEL_REPO_ID`` onto the device.

        Raises:
            RuntimeError: If downloading, constructing, or moving the pipeline
                fails. The original exception is chained as ``__cause__``.
        """
        try:
            logger.info(
                "Loading model %s on %s with %s",
                MODEL_REPO_ID,
                self.device,
                self.torch_dtype,
            )
            # Build into a local first so a failure in .to() cannot leave a
            # half-initialized pipeline attached to the manager.
            pipe = DiffusionPipeline.from_pretrained(
                MODEL_REPO_ID,
                torch_dtype=self.torch_dtype,
            )
            self.pipe = pipe.to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            logger.error("Error loading model: %s", e)
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        seed: int = 0,
        randomize_seed: bool = True,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 0.0,
        num_inference_steps: int = 2,
        progress_callback: Optional[Callable[..., object]] = None,
    ) -> Tuple[Optional[Image.Image], int]:
        """Generate an image based on the provided prompt and parameters.

        Failures are reported by returning ``None`` as the image rather than
        by raising, so callers must check the first element of the result.

        Args:
            prompt: Text description of the desired image. An empty or
                whitespace-only prompt is replaced with a default one.
            negative_prompt: Text description of what to avoid in the image.
            seed: Random seed for reproducibility (ignored when
                ``randomize_seed`` is true).
            randomize_seed: Whether to draw a fresh seed in ``[0, MAX_SEED]``.
            width: Width of the generated image.
            height: Height of the generated image.
            guidance_scale: How closely to follow the prompt.
            num_inference_steps: Number of denoising steps.
            progress_callback: Optional callback forwarded to the pipeline's
                ``callback`` keyword for progress updates.

        Returns:
            Tuple of ``(image, seed)``: the generated image (or ``None`` if
            the model is not loaded or generation failed) and the seed that
            was actually used.
        """
        if self.pipe is None:
            logger.error("Model not loaded. Call load_model() first.")
            return None, seed

        # Validate inputs
        if not prompt or not prompt.strip():
            logger.warning("Empty prompt provided, using default")
            prompt = "A beautiful landscape"

        # Handle seed randomization
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        pipe_kwargs = {
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "num_inference_steps": num_inference_steps,
            "width": width,
            "height": height,
        }
        # Only forward the step callback when the caller supplied one: the
        # `callback` keyword is the legacy diffusers API, and pipelines that
        # no longer accept it would otherwise reject every call, even those
        # that never asked for progress updates.
        if progress_callback is not None:
            pipe_kwargs["callback"] = progress_callback

        try:
            # Seeding lives inside the try so an invalid seed or device is
            # reported through the same (None, seed) path as other failures.
            generator = torch.Generator(device=self.device).manual_seed(seed)
            logger.info("Generating image with prompt: '%s'", prompt)
            result = self.pipe(generator=generator, **pipe_kwargs)
            image = result.images[0]
            logger.info("Image generated successfully with seed %s", seed)
            return image, seed
        except Exception as e:
            # Deliberate best-effort: swallow the error for the caller, but
            # keep the traceback in the log so the failure is diagnosable.
            logger.exception("Error generating image: %s", e)
            return None, seed