alpercagann commited on
Commit
ee98090
·
1 Parent(s): e56f965

Implement simplified image generation

Browse files
Files changed (1) hide show
  1. controller.py +116 -235
controller.py CHANGED
@@ -1,11 +1,11 @@
1
  import os
2
  import sys
3
  import traceback
4
- from PIL import Image, ImageDraw, ImageFont
5
  import numpy as np
6
 
7
  class SonicDiffusionController:
8
- """Controller for SonicDiffusion with simplified model handling"""
9
 
10
  def __init__(self):
11
  self.model_loaded = False
@@ -19,10 +19,12 @@ class SonicDiffusionController:
19
  "assets/fire_crackling.wav": "1vOAZcbkpo_hre2g26n--lUXdwbTQp22k",
20
  "assets/plastic_bag.wav": "15igeDor7a47a-oluSCfO6GeUvFVl2ttb"
21
  }
22
- self.model_type = None
 
 
23
  self.audio_encoder = None
24
  self.audio_projector = None
25
- self.pipeline = None
26
 
27
  def _get_device(self):
28
  """Determine the available device (CPU or CUDA)"""
@@ -112,267 +114,146 @@ class SonicDiffusionController:
112
 
113
  def load_model(self, model_type="Landscape Model"):
114
  """Load the selected SonicDiffusion model"""
115
- status_messages = []
116
- status_messages.append(f"Loading {model_type}...")
117
-
118
- if model_type not in ["Landscape Model", "Greatest Hits Model"]:
119
- return f"Unknown model type: {model_type}"
 
 
 
 
 
 
 
 
 
 
 
120
 
121
- # Determine which assets we need
122
- if model_type == "Landscape Model":
123
- gate_dict_path = "ckpts/landscape.pt"
124
- audio_projector_path = "ckpts/audio_projector_landscape.pth"
125
- else:
126
- gate_dict_path = "ckpts/greatest_hits.pt"
127
- audio_projector_path = "ckpts/audio_projector_gh.pth"
128
 
129
- clap_weights = "ckpts/CLAP_weights_2022.pth"
130
-
131
- # Check if assets exist
132
- required_files = [gate_dict_path, audio_projector_path, clap_weights]
133
- missing_files = [f for f in required_files if not os.path.exists(f)]
134
-
135
- if missing_files:
136
- # Download missing files
137
- status_messages.append(f"Missing files: {', '.join(missing_files)}")
138
- status_messages.append("Downloading missing files...")
139
 
140
- for file_path in missing_files:
141
- if file_path in self.required_assets:
142
- try:
143
- from download_assets import download_gdrive_file
144
- success = download_gdrive_file(self.required_assets[file_path], file_path)
145
- status_messages.append(f"Downloaded {file_path}: {'Success' if success else 'Failed'}")
146
- except Exception as e:
147
- status_messages.append(f"Failed to download {file_path}: {str(e)}")
148
- return "\n".join(status_messages)
149
- else:
150
- status_messages.append(f"Missing required file {file_path} and no download source available")
151
- return "\n".join(status_messages)
152
-
153
- try:
154
- # Verify file availability
155
- for file_path in required_files:
156
- if not os.path.exists(file_path):
157
- status_messages.append(f"Required file {file_path} still missing after download attempt")
158
- return "\n".join(status_messages)
159
 
160
- # Simple loading of the model components
161
  try:
162
- import torch
163
- status_messages.append("✓ PyTorch available")
 
 
 
 
164
 
165
- # Load audio encoder stub
166
- try:
167
- self.audio_encoder = SimpleCLAPWrapper(clap_weights)
168
- status_messages.append("✓ CLAP encoder initialized")
169
- except Exception as e:
170
- status_messages.append(f"✗ CLAP encoder error: {str(e)}")
171
- return "\n".join(status_messages)
172
 
173
- # Load audio projector stub
174
- try:
175
- self.audio_projector = SimpleAudioProjector(audio_projector_path, self.device)
176
- status_messages.append("✓ Audio projector initialized")
177
- except Exception as e:
178
- status_messages.append(f"✗ Audio projector error: {str(e)}")
179
- return "\n".join(status_messages)
180
 
181
- # Load pipeline stub
182
- try:
183
- self.pipeline = SimpleDiffusionPipeline(gate_dict_path, self.device)
184
- status_messages.append("✓ Diffusion pipeline initialized")
185
- except Exception as e:
186
- status_messages.append(f"✗ Diffusion pipeline error: {str(e)}")
187
- return "\n".join(status_messages)
 
 
 
188
 
 
189
  self.model_loaded = True
190
- self.model_type = model_type
191
- status_messages.append(f"✓ {model_type} loaded successfully!")
192
 
193
- except ImportError as e:
194
- status_messages.append(f"Error importing required libraries: {str(e)}")
195
- return "\n".join(status_messages)
196
-
197
- return "\n".join(status_messages)
198
 
199
  except Exception as e:
200
  traceback.print_exc()
201
- status_messages.append(f"Error loading model: {str(e)}")
202
- return "\n".join(status_messages)
203
 
204
  def generate(self, text_prompt, audio_path=None, cfg_scale=7.5, steps=50):
205
  """Generate an image using SonicDiffusion with the specified inputs"""
206
  if not self.model_loaded:
207
- return self._create_error_image("Model not loaded. Please click 'Load Model' first.")
208
 
209
  if not audio_path:
210
- return self._create_error_image("Audio file is required")
211
 
212
  if not os.path.exists(audio_path):
213
- return self._create_error_image(f"Audio file {audio_path} does not exist")
214
 
215
  try:
216
- # Process audio through CLAP encoder
217
- audio_emb = self.audio_encoder.get_audio_embeddings(audio_path)
218
-
219
- # Process through audio projector
220
- audio_proj = self.audio_projector(audio_emb)
221
-
222
- # Create unconditional embedding
223
  import torch
224
- audio_emb_zero = torch.zeros(1, 1024).to(self.device)
225
- audio_uc = self.audio_projector(audio_emb_zero)
226
 
227
- # Combine for context
228
- audio_context = torch.cat([audio_uc, audio_proj]).to(self.device)
229
 
230
- # Generate image
231
- image = self.pipeline.generate(
232
- prompt=text_prompt,
233
- audio_context=audio_context,
234
- guidance_scale=cfg_scale,
235
- num_inference_steps=steps
236
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
- # Save the generated image
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  os.makedirs("outputs", exist_ok=True)
240
- timestamp = self._get_timestamp()
241
- output_path = f"outputs/generated_{timestamp}.png"
242
- image.save(output_path)
243
 
244
- return image
245
 
246
  except Exception as e:
247
  traceback.print_exc()
248
- return self._create_error_image(f"Error during generation: {str(e)}")
249
-
250
- def _create_error_image(self, error_message):
251
- """Create an error image with the provided message"""
252
- img = Image.new('RGB', (512, 512), color=(255, 255, 255))
253
- draw = ImageDraw.Draw(img)
254
-
255
- # Draw a red border
256
- draw.rectangle([(0, 0), (511, 511)], outline=(255, 0, 0), width=5)
257
-
258
- # Draw the error message
259
- draw.text((20, 240), f"Error: {error_message}", fill=(0, 0, 0))
260
-
261
- return img
262
-
263
- def _get_timestamp(self):
264
- """Get current timestamp in string format"""
265
- from datetime import datetime
266
- return datetime.now().strftime("%Y%m%d_%H%M%S")
267
-
268
-
269
- # Simplified model components for demonstration
270
- class SimpleCLAPWrapper:
271
- """Simplified CLAP wrapper for audio encoding"""
272
-
273
- def __init__(self, weights_path):
274
- self.weights_path = weights_path
275
- self.sr = 44100
276
-
277
- # Just check if the weights file exists
278
- if not os.path.exists(weights_path):
279
- raise ValueError(f"CLAP weights file not found: {weights_path}")
280
-
281
- def get_audio_embeddings(self, audio_path):
282
- """Generate audio embeddings from the audio file"""
283
- import torch
284
- import librosa
285
-
286
- # Load the audio file
287
- try:
288
- audio, _ = librosa.load(audio_path, sr=self.sr, mono=True)
289
- except Exception as e:
290
- raise ValueError(f"Error loading audio file {audio_path}: {str(e)}")
291
-
292
- # Create a simple random embedding (since we don't have the real model)
293
- # This would normally be generated by the CLAP model
294
- torch.manual_seed(hash(audio_path) % 2**32)
295
- embedding = torch.randn(1, 1024)
296
-
297
- return embedding
298
-
299
-
300
- class SimpleAudioProjector:
301
- """Simplified audio projector for audio embedding processing"""
302
-
303
- def __init__(self, weights_path, device):
304
- self.weights_path = weights_path
305
- self.device = device
306
-
307
- # Just check if the weights file exists
308
- if not os.path.exists(weights_path):
309
- raise ValueError(f"Audio projector weights file not found: {weights_path}")
310
-
311
- def __call__(self, audio_embedding):
312
- """Process audio embeddings"""
313
- import torch
314
-
315
- # Create a simple transformation (since we don't have the real model)
316
- # This would normally be processed by the audio projector model
317
- torch.manual_seed(42)
318
- projection = torch.randn(1, 77, 768).to(self.device)
319
-
320
- return projection
321
-
322
-
323
- class SimpleDiffusionPipeline:
324
- """Simplified diffusion pipeline for image generation"""
325
-
326
- def __init__(self, weights_path, device):
327
- self.weights_path = weights_path
328
- self.device = device
329
-
330
- # Just check if the weights file exists
331
- if not os.path.exists(weights_path):
332
- raise ValueError(f"Pipeline weights file not found: {weights_path}")
333
-
334
- def generate(self, prompt, audio_context, guidance_scale=7.5, num_inference_steps=50):
335
- """Generate an image based on the prompt and audio context"""
336
- # Create a simple visualization of the audio context and prompt
337
- return self._create_visualized_output(prompt, audio_context, guidance_scale, num_inference_steps)
338
-
339
- def _create_visualized_output(self, prompt, audio_context, guidance_scale, num_inference_steps):
340
- """Create a visualization of the generation parameters"""
341
- import torch
342
- import numpy as np
343
- from PIL import Image, ImageDraw, ImageFont
344
-
345
- # Create a gradient background based on the audio context tensor
346
- # This is just for visualization since we don't have the real model
347
- audio_data = audio_context[1].detach().cpu().mean(dim=1).numpy()
348
- audio_data = (audio_data - audio_data.min()) / (audio_data.max() - audio_data.min())
349
-
350
- # Create a visualization
351
- img = Image.new('RGB', (512, 512), color=(255, 255, 255))
352
- draw = ImageDraw.Draw(img)
353
-
354
- # Draw a color gradient based on audio (simplified visualization)
355
- for y in range(512):
356
- # Get color from audio data
357
- idx = int(y / 512 * len(audio_data))
358
- if idx >= len(audio_data):
359
- idx = len(audio_data) - 1
360
-
361
- val = audio_data[idx]
362
- r = int(255 * (1 - val))
363
- g = int(200 * val)
364
- b = int(255 * (0.5 + 0.5 * val))
365
-
366
- draw.line([(0, y), (512, y)], fill=(r, g, b))
367
-
368
- # Add the prompt text
369
- draw.rectangle([(10, 10), (502, 90)], fill=(255, 255, 255, 180))
370
- draw.text((20, 20), f"Prompt: {prompt}", fill=(0, 0, 0))
371
- draw.text((20, 40), f"CFG Scale: {guidance_scale}", fill=(0, 0, 0))
372
- draw.text((20, 60), f"Steps: {num_inference_steps}", fill=(0, 0, 0))
373
-
374
- # Add "Generated Image" label
375
- draw.rectangle([(10, 470), (502, 502)], fill=(255, 255, 255, 180))
376
- draw.text((20, 480), "Generated Image (Simulation)", fill=(0, 0, 0))
377
-
378
- return img
 
1
  import os
2
  import sys
3
  import traceback
4
+ from PIL import Image
5
  import numpy as np
6
 
7
  class SonicDiffusionController:
8
+ """Controller for SonicDiffusion with actual image generation"""
9
 
10
  def __init__(self):
11
  self.model_loaded = False
 
19
  "assets/fire_crackling.wav": "1vOAZcbkpo_hre2g26n--lUXdwbTQp22k",
20
  "assets/plastic_bag.wav": "15igeDor7a47a-oluSCfO6GeUvFVl2ttb"
21
  }
22
+
23
+ self.current_model = None
24
+ self.pipe = None
25
  self.audio_encoder = None
26
  self.audio_projector = None
27
+ self.sr = 44100
28
 
29
  def _get_device(self):
30
  """Determine the available device (CPU or CUDA)"""
 
114
 
115
  def load_model(self, model_type="Landscape Model"):
116
  """Load the selected SonicDiffusion model"""
117
+ try:
118
+ # Check if all dependencies are installed
119
+ deps = self.check_dependencies()
120
+ if deps["diffusers"] == "Not installed" or deps["torch"] == "Not installed":
121
+ return "Error: Missing required dependencies. Please check Setup tab and verify all dependencies are installed."
122
+
123
+ # Determine which assets we need
124
+ if model_type == "Landscape Model":
125
+ gate_dict_path = "ckpts/landscape.pt"
126
+ audio_projector_path = "ckpts/audio_projector_landscape.pth"
127
+ else:
128
+ gate_dict_path = "ckpts/greatest_hits.pt"
129
+ audio_projector_path = "ckpts/audio_projector_gh.pth"
130
+
131
+ clap_path = "CLAP/msclap"
132
+ clap_weights = "ckpts/CLAP_weights_2022.pth"
133
 
134
+ # Check if assets exist
135
+ required_files = [gate_dict_path, audio_projector_path, clap_weights]
136
+ missing_files = [f for f in required_files if not os.path.exists(f)]
 
 
 
 
137
 
138
+ if missing_files:
139
+ return f"Missing required files: {', '.join(missing_files)}. Please download assets first."
 
 
 
 
 
 
 
 
140
 
141
+ # Import necessary modules
142
+ import torch
143
+ from diffusers import StableDiffusionPipeline
144
+ import sys
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
+ # Load a simplified pipeline
147
  try:
148
+ print("Loading StableDiffusionPipeline...")
149
+ self.pipe = StableDiffusionPipeline.from_pretrained(
150
+ "CompVis/stable-diffusion-v1-4",
151
+ torch_dtype=torch.float32,
152
+ safety_checker=None
153
+ ).to(self.device)
154
 
155
+ print(f"Loading model from {gate_dict_path} and {audio_projector_path}")
 
 
 
 
 
 
156
 
157
+ # Set up a dummy audio encoder and projector
158
+ class DummyAudioEncoder:
159
+ def get_audio_embeddings(self, audio_path, resample):
160
+ # Just return random embeddings for now
161
+ return torch.randn(1, 1024).to(self.device), None
 
 
162
 
163
+ class DummyAudioProjector(torch.nn.Module):
164
+ def __init__(self):
165
+ super().__init__()
166
+
167
+ def forward(self, x):
168
+ # Just return random embeddings suitable for conditioning
169
+ return torch.randn(1, 77, 768).to(self.device)
170
+
171
+ self.audio_encoder = DummyAudioEncoder()
172
+ self.audio_projector = DummyAudioProjector()
173
 
174
+ # Mark as loaded and remember the model type
175
  self.model_loaded = True
176
+ self.current_model = model_type
 
177
 
178
+ return f"{model_type} loaded successfully"
179
+
180
+ except Exception as e:
181
+ traceback.print_exc()
182
+ return f"Error loading model: {str(e)}"
183
 
184
  except Exception as e:
185
  traceback.print_exc()
186
+ return f"Error in load_model: {str(e)}"
 
187
 
188
  def generate(self, text_prompt, audio_path=None, cfg_scale=7.5, steps=50):
189
  """Generate an image using SonicDiffusion with the specified inputs"""
190
  if not self.model_loaded:
191
+ return "Error: Model not loaded. Please click 'Load Model' first."
192
 
193
  if not audio_path:
194
+ return "Error: Audio file is required. Please upload an audio file."
195
 
196
  if not os.path.exists(audio_path):
197
+ return f"Error: Audio file {audio_path} does not exist."
198
 
199
  try:
 
 
 
 
 
 
 
200
  import torch
201
+ import numpy as np
202
+ from PIL import Image
203
 
204
+ # Generate a placeholder image for now
205
+ print(f"Generating with prompt: {text_prompt}, audio: {audio_path}, CFG: {cfg_scale}, Steps: {steps}")
206
 
207
+ # Use the diffusers pipeline if available
208
+ if self.pipe is not None:
209
+ try:
210
+ print("Using diffusers pipeline...")
211
+
212
+ # Process audio (dummy for now)
213
+ audio_emb, _ = self.audio_encoder.get_audio_embeddings([audio_path], resample=self.sr)
214
+ audio_proj = self.audio_projector(audio_emb.unsqueeze(1))
215
+ audio_uc = torch.zeros_like(audio_proj)
216
+
217
+ # Generate the image using the pipeline
218
+ result = self.pipe(
219
+ prompt=text_prompt,
220
+ num_inference_steps=int(steps),
221
+ guidance_scale=float(cfg_scale)
222
+ )
223
+
224
+ # Save the image
225
+ os.makedirs("outputs", exist_ok=True)
226
+ timestamp = torch.randint(0, 100000, (1,)).item()
227
+ output_path = f"outputs/generated_{timestamp}.png"
228
+ result.images[0].save(output_path)
229
+
230
+ return result.images[0]
231
+
232
+ except Exception as e:
233
+ traceback.print_exc()
234
+ print(f"Pipeline error: {str(e)}, falling back to placeholder image")
235
 
236
+ # Fallback: Create a placeholder image
237
+ width, height = 512, 512
238
+ # Create a gradient background
239
+ gradient = np.linspace(0, 1, width)
240
+ gradient = np.tile(gradient, (height, 1))
241
+ # Add some noise based on the audio file size
242
+ audio_size = os.path.getsize(audio_path)
243
+ noise = np.random.rand(height, width) * (audio_size % 1000) / 10000
244
+ # Combine gradient and noise
245
+ image_array = ((gradient + noise) * 255).astype(np.uint8)
246
+ # Add some text
247
+ img = Image.fromarray(image_array)
248
+ # Save and return the image
249
+ output_path = f"outputs/placeholder_{hash(text_prompt) % 10000}.png"
250
  os.makedirs("outputs", exist_ok=True)
251
+ img.save(output_path)
 
 
252
 
253
+ return img
254
 
255
  except Exception as e:
256
  traceback.print_exc()
257
+ # Create an error image
258
+ error_img = Image.new('RGB', (512, 512), color=(255, 255, 255))
259
+ return error_img