File size: 3,126 Bytes
3aa90e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""
AI Image Generation Module
Uses Stable Diffusion to generate images from text prompts
"""

import os
from typing import Optional

import torch
from diffusers import StableDiffusionPipeline
from PIL import Image


class ImageGenerator:
    """Text-to-image generator backed by a Stable Diffusion pipeline.

    The pipeline is loaded lazily on first use (see ``load_model``) and can
    be released explicitly with ``unload_model`` to reclaim GPU/CPU memory.
    """

    def __init__(self, model_id: str = "stabilityai/stable-diffusion-2-1-base"):
        """Remember the model id and pick a device; no weights are loaded yet.

        Args:
            model_id: Hugging Face model identifier for the pipeline.
        """
        self.model_id = model_id
        self.pipe = None  # populated lazily by load_model()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_model(self) -> None:
        """Load the Stable Diffusion pipeline if not already loaded (idempotent)."""
        if self.pipe is not None:
            return

        print(f"Loading model on {self.device}...")
        # NOTE(review): safety_checker=None disables the built-in NSFW filter.
        # Keep only if this is intentional for a trusted/offline deployment.
        self.pipe = StableDiffusionPipeline.from_pretrained(
            self.model_id,
            # fp16 halves GPU memory use; CPU kernels require fp32.
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            safety_checker=None
        )
        self.pipe = self.pipe.to(self.device)

        # Attention slicing trades a little speed for a large VRAM saving.
        if self.device == "cuda":
            self.pipe.enable_attention_slicing()

        print("Model loaded successfully!")

    def generate(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 30,
        guidance_scale: float = 7.5,
        width: int = 512,
        height: int = 512,
        seed: Optional[int] = None
    ) -> Image.Image:
        """
        Generate an image from a text prompt.

        Args:
            prompt: Text description of desired image
            negative_prompt: What to avoid in the image
            num_inference_steps: Number of denoising steps (higher = better quality, slower)
            guidance_scale: How closely to follow the prompt (7-10 recommended)
            width: Image width (must be multiple of 8)
            height: Image height (must be multiple of 8)
            seed: Random seed for reproducibility

        Returns:
            PIL Image

        Raises:
            ValueError: If ``width`` or ``height`` is not a multiple of 8.
        """
        # Fail fast with a clear message instead of an opaque model error.
        if width % 8 != 0 or height % 8 != 0:
            raise ValueError(
                f"width and height must be multiples of 8, got {width}x{height}"
            )

        self.load_model()

        # Seed a device-local generator only when reproducibility is requested;
        # otherwise let the pipeline draw from the global RNG.
        generator = None
        if seed is not None:
            generator = torch.Generator(device=self.device).manual_seed(seed)

        # inference_mode disables autograd tracking for speed and memory.
        with torch.inference_mode():
            result = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
                generator=generator
            )

        return result.images[0]

    def unload_model(self) -> None:
        """Free up memory by unloading the model (no-op if not loaded)."""
        if self.pipe is not None:
            # Dropping the reference lets Python reclaim the pipeline; the
            # CUDA cache must be emptied separately to return VRAM to the OS.
            self.pipe = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            print("Model unloaded")


# Example usage: render a single fixed-seed image and write it to disk.
if __name__ == "__main__":
    image_maker = ImageGenerator()
    picture = image_maker.generate(
        prompt="A fantasy landscape with mountains and a castle at sunset",
        seed=42,
    )
    picture.save("test_generated.png")
    print("Image saved as test_generated.png")