MogensR committed on
Commit
038cd8a
·
verified ·
1 Parent(s): 646a504

Create ai_background.py

Browse files
Files changed (1) hide show
  1. processing/ai_background.py +219 -0
processing/ai_background.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ AI Background Generator Module
4
+ Handles Stable Diffusion background generation with proper dependency management.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+ import tempfile
10
+ import random
11
+ import logging
12
+ from pathlib import Path
13
+ from typing import Optional
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
class AIBackgroundGenerator:
    """
    Stable Diffusion background generator with dependency isolation.

    Heavy dependencies (torch / diffusers / PIL) are probed once at
    construction and imported lazily inside :meth:`generate_background`,
    so this module can always be imported even without a GPU stack.
    """

    def __init__(self, temp_dir: Optional[str] = None):
        """
        Args:
            temp_dir: Directory where generated images are written.
                Defaults to the system temporary directory.
        """
        self.temp_dir = temp_dir or tempfile.gettempdir()
        self.available = False       # flipped to True once dependencies check out
        self.error_message = None    # human-readable reason when unavailable
        self._check_dependencies()

    def _check_dependencies(self):
        """Probe torch/CUDA/diffusers and record availability state."""
        try:
            import torch
            if not torch.cuda.is_available():
                self.error_message = "CUDA not available - AI background generation requires GPU"
                return

            # Test diffusers import (names unused here; only the import must succeed)
            from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline  # noqa: F401

            self.available = True
            logger.info("AI Background Generator: Dependencies OK")

        except ImportError as e:
            # A torch/diffusers ABI mismatch surfaces as this specific ImportError.
            if "torch.library" in str(e) and "custom_op" in str(e):
                self.error_message = (
                    "PyTorch/Diffusers version mismatch. Please update:\n"
                    "pip install --upgrade torch diffusers transformers accelerate"
                )
            else:
                self.error_message = f"Missing dependencies: {e}"
            logger.warning("AI Background Generator unavailable: %s", self.error_message)
        except Exception as e:
            self.error_message = f"Unexpected error checking dependencies: {e}"
            logger.error("AI Background Generator error: %s", self.error_message)

    def is_available(self) -> bool:
        """Check if AI background generation is available."""
        return self.available

    def get_error_message(self) -> Optional[str]:
        """Get error message if dependencies are not available."""
        return self.error_message

    @staticmethod
    def _tune_pipeline(pipe) -> None:
        """Apply best-effort memory optimizations to a loaded pipeline.

        enable_model_cpu_offload() requires `accelerate` and may raise
        ImportError (not just AttributeError, which the original caught);
        these tweaks are optional, so any failure is non-fatal.
        """
        try:
            pipe.enable_attention_slicing()
            pipe.enable_model_cpu_offload()
        except Exception:
            pass

    def generate_background(
        self,
        width: int,
        height: int,
        prompt: str,
        init_image_path: Optional[str] = None,
        model_id: str = "runwayml/stable-diffusion-v1-5",
        num_steps: int = 25,
        guidance_scale: float = 7.5,
        strength: float = 0.6,
        seed: Optional[int] = None,
    ) -> str:
        """
        Generate AI background image.

        Args:
            width: Output image width in pixels.
            height: Output image height in pixels.
            prompt: Text prompt describing the desired background.
            init_image_path: Optional source image; when given and the file
                exists, the img2img pipeline is used instead of txt2img.
            model_id: Hugging Face model identifier to load.
            num_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance strength.
            strength: img2img denoising strength (ignored for txt2img).
            seed: RNG seed; a random one is drawn when None.

        Returns:
            Path to generated background image

        Raises:
            RuntimeError: If dependencies are not available or generation fails
        """
        if not self.available:
            raise RuntimeError(f"AI Background not available: {self.error_message}")

        try:
            # Import here to avoid issues if dependencies not available
            import torch
            from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
            from PIL import Image

            device = "cuda" if torch.cuda.is_available() else "cpu"
            torch_dtype = torch.float16 if device == "cuda" else torch.float32

            # Seed a generator so results are reproducible per seed.
            if seed is None:
                seed = random.randint(0, 2**31 - 1)
            generator = torch.Generator(device=device).manual_seed(seed)

            logger.info("Generating %dx%d background: '%s' (seed: %d)",
                        width, height, prompt, seed)

            # Choose pipeline based on whether we have an init image
            if init_image_path and os.path.exists(init_image_path):
                logger.info("Using img2img pipeline")
                pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                    model_id,
                    torch_dtype=torch_dtype,
                    safety_checker=None,
                    requires_safety_checker=False,
                ).to(device)
                self._tune_pipeline(pipe)

                # Output size follows the init image, so resize it up front.
                init_image = Image.open(init_image_path).convert("RGB")
                init_image = init_image.resize((width, height), Image.LANCZOS)

                # BUGFIX: StableDiffusionImg2ImgPipeline.__call__ does not
                # accept height/width (the original passed them and would
                # raise TypeError); the resized init image fixes the size.
                result = pipe(
                    prompt=prompt,
                    image=init_image,
                    strength=strength,
                    num_inference_steps=num_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                ).images[0]
            else:
                logger.info("Using txt2img pipeline")
                pipe = StableDiffusionPipeline.from_pretrained(
                    model_id,
                    torch_dtype=torch_dtype,
                    safety_checker=None,
                    requires_safety_checker=False,
                ).to(device)
                self._tune_pipeline(pipe)

                result = pipe(
                    prompt=prompt,
                    height=height,
                    width=width,
                    num_inference_steps=num_steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                ).images[0]

            # Unique output path. mkstemp is portable and collision-free,
            # unlike the original os.times().elapsed-based name (which can
            # repeat within a tick and collide for equal seeds).
            fd, output_path = tempfile.mkstemp(
                prefix=f"ai_bg_{seed:08x}_", suffix=".jpg", dir=self.temp_dir
            )
            os.close(fd)
            result.save(output_path, quality=95, optimize=True)

            # Cleanup GPU memory before returning.
            try:
                del pipe
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            except Exception:
                pass

            logger.info("AI background generated: %s", output_path)
            return output_path

        except Exception as e:
            logger.error("AI background generation failed: %s", e)
            # Chain the cause so callers can inspect the underlying error.
            raise RuntimeError(f"Background generation failed: {e}") from e
185
+
186
+
187
# Convenience function for easy import
def create_ai_background_generator(temp_dir: Optional[str] = None) -> AIBackgroundGenerator:
    """Construct and return a ready-to-use :class:`AIBackgroundGenerator`.

    Args:
        temp_dir: Directory for generated images; defaults to the system
            temporary directory when omitted.
    """
    return AIBackgroundGenerator(temp_dir=temp_dir)
191
+
192
+
193
# Test function
def test_ai_background():
    """Run a small end-to-end generation and report success as a bool."""
    gen = create_ai_background_generator()

    # Bail out early (with the recorded reason) on machines without the stack.
    if not gen.is_available():
        print(f"AI Background not available: {gen.get_error_message()}")
        return False

    try:
        result_path = gen.generate_background(
            width=512,
            height=512,
            prompt="professional office background with soft lighting",
            num_steps=20,
        )
    except Exception as exc:
        print(f"Test failed: {exc}")
        return False

    print(f"Test successful: {result_path}")
    return True
215
+
216
+
217
if __name__ == "__main__":
    # Smoke test: running this module directly exercises the generator
    # (prints the unavailability reason on machines without a GPU stack).
    test_ai_background()