"""
Model initialization and inference logic for image generation.

This module handles loading the diffusion model and provides functions
for generating images from text prompts with error handling.
"""

import logging
import random
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
from diffusers import DiffusionPipeline
from PIL import Image

from config import MODEL_REPO_ID, MAX_SEED

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

class ModelManager:
    """Manages the diffusion model for image generation.

    Construction only selects the device/dtype; the (potentially slow)
    model download and load is deferred to :meth:`load_model`.
    """

    def __init__(self):
        """Initialize the ModelManager; call load_model() before generating."""
        # Query CUDA availability once so device and dtype cannot disagree.
        use_cuda = torch.cuda.is_available()
        self.device = "cuda" if use_cuda else "cpu"
        # fp16 halves GPU memory use; CPU kernels generally need fp32.
        self.torch_dtype = torch.float16 if use_cuda else torch.float32
        self.pipe = None

    def load_model(self) -> None:
        """
        Load the diffusion model from MODEL_REPO_ID onto self.device.

        Raises:
            RuntimeError: If the model cannot be downloaded or initialized.
                The original exception is chained as __cause__.
        """
        try:
            # Lazy %-style args: formatting is skipped if INFO is disabled.
            logger.info(
                "Loading model %s on %s with %s",
                MODEL_REPO_ID, self.device, self.torch_dtype,
            )
            self.pipe = DiffusionPipeline.from_pretrained(
                MODEL_REPO_ID,
                torch_dtype=self.torch_dtype,
            )
            self.pipe = self.pipe.to(self.device)
            logger.info("Model loaded successfully")
        except Exception as e:
            # logger.exception records the full traceback, not just str(e).
            logger.exception("Error loading model")
            # Chain the cause so callers can see the underlying failure.
            raise RuntimeError(f"Failed to load model: {str(e)}") from e

    def generate_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        seed: int = 0,
        randomize_seed: bool = True,
        width: int = 1024,
        height: int = 1024,
        guidance_scale: float = 0.0,
        num_inference_steps: int = 2,
        progress_callback: Optional[Callable] = None
    ) -> Tuple[Union[Image.Image, None], int]:
        """
        Generate an image based on the provided prompt and parameters.

        Args:
            prompt: Text description of the desired image. Empty/whitespace
                prompts are replaced with a default rather than rejected.
            negative_prompt: Text description of what to avoid in the image.
            seed: Random seed for reproducibility (ignored when
                randomize_seed is True).
            randomize_seed: Whether to draw a fresh seed in [0, MAX_SEED].
            width: Width of the generated image in pixels.
            height: Height of the generated image in pixels.
            guidance_scale: How closely to follow the prompt.
            num_inference_steps: Number of denoising steps.
            progress_callback: Optional callback for progress updates,
                passed through to the pipeline's `callback` argument.

        Returns:
            Tuple of (generated image, seed used). The image is None when
            the model is not loaded or generation fails; errors are logged
            rather than raised so callers can degrade gracefully.
        """
        # Guard clause: generation is impossible without a loaded pipeline.
        if self.pipe is None:
            logger.error("Model not loaded. Call load_model() first.")
            return None, seed

        # Fall back to a harmless default instead of failing on blank input.
        if not prompt or prompt.strip() == "":
            logger.warning("Empty prompt provided, using default")
            prompt = "A beautiful landscape"

        # A fresh seed per call unless the caller wants reproducibility.
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)

        # Dedicated generator so the seed fully determines the output.
        generator = torch.Generator(device=self.device).manual_seed(seed)

        try:
            logger.info("Generating image with prompt: '%s'", prompt)

            result = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
                callback=progress_callback
            )

            image = result.images[0]
            logger.info("Image generated successfully with seed %d", seed)
            return image, seed

        except Exception as e:
            # Best-effort API: log with traceback and signal failure via None.
            logger.exception("Error generating image: %s", str(e))
            return None, seed