SanskarModi commited on
Commit
2b4ba87
·
1 Parent(s): 2a72dcc

added image to image generation

Browse files
Files changed (1) hide show
  1. app/img2img.py +175 -1
app/img2img.py CHANGED
@@ -1 +1,175 @@
1
- """Auto-generated placeholder module for Stable Diffusion Image Generator."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Image-to-image generation using Stable Diffusion.
2
+
3
+ This module provides:
4
+ - prepare_img2img_pipeline: build an Img2Img pipeline from an existing txt2img pipe.
5
+ - generate_img2img: run image-to-image generation and return (PIL.Image, metadata).
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import time
11
+ from pathlib import Path
12
+ from typing import Any, Dict, Optional, Union
13
+
14
+ import torch
15
+ from diffusers import StableDiffusionImg2ImgPipeline
16
+ from PIL import Image
17
+
18
+ from app.utils.logger import get_logger
19
+
20
+ logger = get_logger(__name__)
21
+
22
+
23
+ def _validate_resolution(width: int, height: int) -> tuple[int, int]:
24
+ """Clamp resolution to a safe range and snap to multiples of 64."""
25
+ width = max(256, min(width, 768))
26
+ height = max(256, min(height, 768))
27
+ width = (width // 64) * 64
28
+ height = (height // 64) * 64
29
+ return int(width), int(height)
30
+
31
+
32
+ def _load_init_image(
33
+ image: Union[Image.Image, str, Path],
34
+ width: int,
35
+ height: int,
36
+ ) -> Image.Image:
37
+ """Load and preprocess the init image for img2img."""
38
+ if isinstance(image, (str, Path)):
39
+ image = Image.open(image)
40
+
41
+ if not isinstance(image, Image.Image):
42
+ raise TypeError("init_image must be a PIL.Image or a valid image path.")
43
+
44
+ image = image.convert("RGB")
45
+ image = image.resize((width, height), resample=Image.LANCZOS)
46
+ return image
47
+
48
+
49
+ def prepare_img2img_pipeline(
50
+ base_pipe,
51
+ model_id: str = "runwayml/stable-diffusion-v1-5",
52
+ ) -> StableDiffusionImg2ImgPipeline:
53
+ """Create an Img2Img pipeline that shares weights with the base txt2img pipe.
54
+
55
+ Tries to use StableDiffusionImg2ImgPipeline.from_pipe to reuse:
56
+ - UNet
57
+ - VAE
58
+ - text encoder
59
+ - tokenizer
60
+ - scheduler
61
+ """
62
+ try:
63
+ img2img_pipe = StableDiffusionImg2ImgPipeline.from_pipe(base_pipe)
64
+ logger.info("Created Img2Img pipeline from existing base pipeline.")
65
+ except Exception as err:
66
+ logger.info("from_pipe failed (%s); falling back to from_pretrained.", err)
67
+ img2img_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
68
+ model_id,
69
+ torch_dtype=base_pipe.unet.dtype,
70
+ safety_checker=None,
71
+ )
72
+ device = next(base_pipe.unet.parameters()).device
73
+ img2img_pipe = img2img_pipe.to(device)
74
+
75
+ # memory optimizations similar to txt2img pipeline
76
+ try:
77
+ img2img_pipe.enable_attention_slicing()
78
+ logger.info("Enabled attention slicing on Img2Img pipeline.")
79
+ except Exception:
80
+ logger.info("Attention slicing not available on Img2Img pipeline.")
81
+
82
+ try:
83
+ if hasattr(img2img_pipe.vae, "enable_tiling"):
84
+ img2img_pipe.vae.enable_tiling()
85
+ logger.info("Enabled VAE tiling on Img2Img pipeline.")
86
+ except Exception:
87
+ pass
88
+
89
+ return img2img_pipe
90
+
91
+
92
+ def generate_img2img(
93
+ pipe: StableDiffusionImg2ImgPipeline,
94
+ init_image: Union[Image.Image, str, Path],
95
+ prompt: str,
96
+ negative_prompt: Optional[str] = None,
97
+ strength: float = 0.7,
98
+ steps: int = 30,
99
+ guidance_scale: float = 7.5,
100
+ width: int = 512,
101
+ height: int = 512,
102
+ seed: Optional[int] = None,
103
+ device: str = "cuda",
104
+ ) -> tuple[Image.Image, Dict[str, Any]]:
105
+ """Run image-to-image generation.
106
+
107
+ Args:
108
+ pipe: A StableDiffusionImg2ImgPipeline.
109
+ init_image: Base image (PIL or path).
110
+ prompt: Text prompt to guide the transformation.
111
+ negative_prompt: What to avoid in the output.
112
+ strength: How strong the transformation is (0-1).
113
+ steps: Number of inference steps.
114
+ guidance_scale: Prompt adherence strength.
115
+ width: Target width (snapped to 64 multiple).
116
+ height: Target height (snapped to 64 multiple).
117
+ seed: Optional random seed for reproducibility.
118
+ device: "cuda" or "cpu".
119
+
120
+ Returns:
121
+ (PIL.Image, metadata dict)
122
+ """
123
+ if not (0.0 < strength <= 1.0):
124
+ raise ValueError("strength must be in (0, 1].")
125
+
126
+ start = time.time()
127
+ width, height = _validate_resolution(width, height)
128
+ init_image = _load_init_image(init_image, width, height)
129
+
130
+ # Seed handling
131
+ if seed is None:
132
+ seed = int(torch.seed() & ((1 << 63) - 1))
133
+
134
+ gen = torch.Generator(device if device != "cpu" else "cpu").manual_seed(int(seed))
135
+
136
+ logger.info(
137
+ "Img2Img: steps=%s cfg=%s strength=%.2f res=%sx%s seed=%s",
138
+ steps,
139
+ guidance_scale,
140
+ strength,
141
+ width,
142
+ height,
143
+ seed,
144
+ )
145
+
146
+ device_type = "cuda" if device != "cpu" else "cpu"
147
+ with torch.autocast(device_type=device_type):
148
+ result = pipe(
149
+ prompt=prompt,
150
+ negative_prompt=negative_prompt if negative_prompt else None,
151
+ image=init_image,
152
+ strength=float(strength),
153
+ num_inference_steps=int(steps),
154
+ guidance_scale=float(guidance_scale),
155
+ generator=gen,
156
+ )
157
+
158
+ out_image = result.images[0]
159
+ elapsed = time.time() - start
160
+
161
+ metadata: Dict[str, Any] = {
162
+ "mode": "img2img",
163
+ "prompt": prompt,
164
+ "negative_prompt": negative_prompt,
165
+ "steps": steps,
166
+ "guidance_scale": guidance_scale,
167
+ "width": width,
168
+ "height": height,
169
+ "seed": int(seed),
170
+ "strength": float(strength),
171
+ "elapsed_seconds": elapsed,
172
+ }
173
+
174
+ logger.info("Img2Img finished in %.2fs", elapsed)
175
+ return out_image, metadata