Innovideo / model.py
Bossmarc747's picture
oj
32cd713
"""
Model initialization and inference logic for image generation.
This module provides ModelManager, which loads the diffusion model and
generates images from text prompts with error handling.
"""
import logging
import random
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DiffusionPipeline
from PIL import Image

from config import MAX_SEED, MODEL_REPO_ID
# Configure logging: timestamped, INFO-level records on the root logger
# (basicConfig does nothing if the root logger already has handlers), plus a
# module-scoped logger used by the code below.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class ModelManager:
    """Manage a diffusers pipeline used for text-to-image generation.

    Construction only selects the device and dtype; the model weights are not
    loaded until :meth:`load_model` is called.
    """

    def __init__(self):
        """Select the device and dtype for the pipeline without loading it.

        Uses CUDA with float16 when a GPU is available, otherwise CPU with
        float32. ``self.pipe`` stays ``None`` until :meth:`load_model` succeeds.
        """
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        self.torch_dtype = torch.float16 if use_cuda else torch.float32
        self.pipe = None

    def load_model(self) -> None:
        """Load the diffusion pipeline from ``MODEL_REPO_ID`` onto ``self.device``.

        ``self.pipe`` is only assigned once the pipeline has been loaded AND
        moved to the target device, so a failure never leaves a half-initialized
        pipeline behind for :meth:`generate_image` to use.

        Raises:
            RuntimeError: If loading the pipeline or moving it to the device
                fails. The original exception is chained as ``__cause__``.
        """
        logger.info(
            "Loading model %s on %s with %s",
            MODEL_REPO_ID, self.device, self.torch_dtype,
        )
        try:
            pipe = DiffusionPipeline.from_pretrained(
                MODEL_REPO_ID,
                torch_dtype=self.torch_dtype,
            )
            pipe = pipe.to(self.device)
        except Exception as e:
            # logger.exception keeps the traceback, which str(e) alone loses.
            logger.exception("Error loading model: %s", e)
            raise RuntimeError(f"Failed to load model: {e}") from e
        self.pipe = pipe
        logger.info("Model loaded successfully")

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        seed: int = 0,
        randomize_seed: bool = True,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 0.0,
        num_inference_steps: int = 2,
        progress_callback: Optional[Callable[..., None]] = None,
    ) -> Tuple[Optional[Image.Image], int]:
        """Generate an image from a text prompt.

        Errors are not raised: if the model is not loaded or generation fails,
        the problem is logged and ``(None, seed)`` is returned.

        Args:
            prompt: Text description of the desired image. An empty or
                whitespace-only prompt is replaced by "A beautiful landscape".
            negative_prompt: Text description of what to avoid in the image.
            seed: Random seed for reproducibility (ignored when
                ``randomize_seed`` is True). Coerced to ``int``.
            randomize_seed: Whether to draw a fresh seed in ``[0, MAX_SEED]``.
            width: Width of the generated image in pixels.
            height: Height of the generated image in pixels.
            guidance_scale: How closely to follow the prompt.
            num_inference_steps: Number of denoising steps.
            progress_callback: Optional progress callback, forwarded to the
                pipeline as its ``callback`` argument only when provided.
                NOTE(review): presumably the legacy diffusers
                ``callback(step, timestep, latents)`` signature — confirm
                against the pipeline behind MODEL_REPO_ID.

        Returns:
            Tuple of (generated image or None on failure, seed actually used).
        """
        if self.pipe is None:
            logger.error("Model not loaded. Call load_model() first.")
            return None, seed

        if not prompt or not prompt.strip():
            logger.warning("Empty prompt provided, using default")
            prompt = "A beautiful landscape"

        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        try:
            # UI widgets may hand over floats; Generator.manual_seed needs an int.
            seed = int(seed)
            generator = torch.Generator(device=self.device).manual_seed(seed)
            pipe_kwargs = {
                "prompt": prompt,
                "negative_prompt": negative_prompt,
                "guidance_scale": guidance_scale,
                "num_inference_steps": num_inference_steps,
                "width": width,
                "height": height,
                "generator": generator,
            }
            # ``callback`` is a legacy diffusers argument that newer pipelines
            # may not accept; passing ``callback=None`` unconditionally would
            # make every generation fail on such pipelines.
            if progress_callback is not None:
                pipe_kwargs["callback"] = progress_callback

            logger.info("Generating image with prompt: '%s'", prompt)
            result = self.pipe(**pipe_kwargs)
            image = result.images[0]
        except Exception as e:
            # Deliberate best-effort: callers receive None instead of an error,
            # but the full traceback is logged.
            logger.exception("Error generating image: %s", e)
            return None, seed

        logger.info("Image generated successfully with seed %s", seed)
        return image, seed