alpercagann committed on
Commit fb422b4 · 1 Parent(s): f5903f4

Implement simplified SonicDiffusion model components

Files changed (1)
controller.py  +228 -29
controller.py CHANGED
@@ -1,9 +1,11 @@
 import os
 import sys
 import traceback
+from PIL import Image, ImageDraw, ImageFont
+import numpy as np
 
 class SonicDiffusionController:
-    """Controller for SonicDiffusion with asset downloading support"""
+    """Controller for SonicDiffusion with simplified model handling"""
 
     def __init__(self):
         self.model_loaded = False
@@ -17,6 +19,10 @@ class SonicDiffusionController:
             "assets/fire_crackling.wav": "1vOAZcbkpo_hre2g26n--lUXdwbTQp22k",
             "assets/plastic_bag.wav": "15igeDor7a47a-oluSCfO6GeUvFVl2ttb"
         }
+        self.model_type = None
+        self.audio_encoder = None
+        self.audio_projector = None
+        self.pipeline = None
 
     def _get_device(self):
         """Determine the available device (CPU or CUDA)"""
@@ -106,6 +112,9 @@ class SonicDiffusionController:
 
     def load_model(self, model_type="Landscape Model"):
        """Load the selected SonicDiffusion model"""
+        status_messages = []
+        status_messages.append(f"Loading {model_type}...")
+
         if model_type not in ["Landscape Model", "Greatest Hits Model"]:
             return f"Unknown model type: {model_type}"
 
@@ -117,7 +126,6 @@ class SonicDiffusionController:
             gate_dict_path = "ckpts/greatest_hits.pt"
             audio_projector_path = "ckpts/audio_projector_gh.pth"
 
-        clap_path = "CLAP/msclap"
         clap_weights = "ckpts/CLAP_weights_2022.pth"
 
         # Check if assets exist
@@ -126,54 +134,245 @@
 
         if missing_files:
             # Download missing files
+            status_messages.append(f"Missing files: {', '.join(missing_files)}")
+            status_messages.append("Downloading missing files...")
+
             for file_path in missing_files:
                 if file_path in self.required_assets:
                     try:
                         from download_assets import download_gdrive_file
-                        download_gdrive_file(self.required_assets[file_path], file_path)
+                        success = download_gdrive_file(self.required_assets[file_path], file_path)
+                        status_messages.append(f"Downloaded {file_path}: {'Success' if success else 'Failed'}")
                     except Exception as e:
-                        return f"Failed to download {file_path}: {str(e)}"
+                        status_messages.append(f"Failed to download {file_path}: {str(e)}")
+                        return "\n".join(status_messages)
                 else:
-                    return f"Missing required file {file_path} and no download source available"
+                    status_messages.append(f"Missing required file {file_path} and no download source available")
+                    return "\n".join(status_messages)
 
         try:
-            # Simple test of loading the model components
-            import torch
-
-            # Load a small test tensor to verify PyTorch works
-            self.test_tensor = torch.rand(3, 3).to(self.device)
-
-            # Just check if we can access the file
-            with open(gate_dict_path, 'rb') as f:
-                # Just read a small part to verify the file exists and is readable
-                f.read(10)
-
-            with open(audio_projector_path, 'rb') as f:
-                f.read(10)
-
-            with open(clap_weights, 'rb') as f:
-                f.read(10)
-
-            # For now, just mark as loaded - we'll implement real loading later
-            self.model_loaded = True
-            self.model_type = model_type
-
-            return f"{model_type} files verified and accessible"
+            # Verify file availability
+            for file_path in required_files:
+                if not os.path.exists(file_path):
+                    status_messages.append(f"Required file {file_path} still missing after download attempt")
+                    return "\n".join(status_messages)
+
+            # Simple loading of the model components
+            try:
+                import torch
+                status_messages.append("✓ PyTorch available")
+
+                # Load audio encoder stub
+                try:
+                    self.audio_encoder = SimpleCLAPWrapper(clap_weights)
+                    status_messages.append("✓ CLAP encoder initialized")
+                except Exception as e:
+                    status_messages.append(f"✗ CLAP encoder error: {str(e)}")
+                    return "\n".join(status_messages)
+
+                # Load audio projector stub
+                try:
+                    self.audio_projector = SimpleAudioProjector(audio_projector_path, self.device)
+                    status_messages.append("✓ Audio projector initialized")
+                except Exception as e:
+                    status_messages.append(f"✗ Audio projector error: {str(e)}")
+                    return "\n".join(status_messages)
+
+                # Load pipeline stub
+                try:
+                    self.pipeline = SimpleDiffusionPipeline(gate_dict_path, self.device)
+                    status_messages.append("✓ Diffusion pipeline initialized")
+                except Exception as e:
+                    status_messages.append(f"✗ Diffusion pipeline error: {str(e)}")
+                    return "\n".join(status_messages)
+
+                self.model_loaded = True
+                self.model_type = model_type
+                status_messages.append(f"✓ {model_type} loaded successfully!")
+
+            except ImportError as e:
+                status_messages.append(f"Error importing required libraries: {str(e)}")
+                return "\n".join(status_messages)
+
+            return "\n".join(status_messages)
 
         except Exception as e:
             traceback.print_exc()
-            return f"Error loading model: {str(e)}"
+            status_messages.append(f"Error loading model: {str(e)}")
+            return "\n".join(status_messages)
 
     def generate(self, text_prompt, audio_path=None, cfg_scale=7.5, steps=50):
         """Generate an image using SonicDiffusion with the specified inputs"""
         if not self.model_loaded:
-            return "Error: Model not loaded. Please click 'Load Model' first."
+            return self._create_error_image("Model not loaded. Please click 'Load Model' first.")
 
         if not audio_path:
-            return "Error: Audio file is required"
+            return self._create_error_image("Audio file is required")
 
         if not os.path.exists(audio_path):
-            return f"Error: Audio file {audio_path} does not exist"
+            return self._create_error_image(f"Audio file {audio_path} does not exist")
 
-        # Return info about what would be generated
-        return f"Would generate image with:\nModel: {self.model_type}\nPrompt: {text_prompt}\nAudio: {audio_path}\nCFG Scale: {cfg_scale}\nSteps: {steps}\n\nFull implementation coming soon!"
+        try:
+            # Process audio through CLAP encoder
+            audio_emb = self.audio_encoder.get_audio_embeddings(audio_path)
+
+            # Process through audio projector
+            audio_proj = self.audio_projector(audio_emb)
+
+            # Create unconditional embedding
+            import torch
+            audio_emb_zero = torch.zeros(1, 1024).to(self.device)
+            audio_uc = self.audio_projector(audio_emb_zero)
+
+            # Combine for context
+            audio_context = torch.cat([audio_uc, audio_proj]).to(self.device)
+
+            # Generate image
+            image = self.pipeline.generate(
+                prompt=text_prompt,
+                audio_context=audio_context,
+                guidance_scale=cfg_scale,
+                num_inference_steps=steps
+            )
+
+            # Save the generated image
+            os.makedirs("outputs", exist_ok=True)
+            timestamp = self._get_timestamp()
+            output_path = f"outputs/generated_{timestamp}.png"
+            image.save(output_path)
+
+            return image
+
+        except Exception as e:
+            traceback.print_exc()
+            return self._create_error_image(f"Error during generation: {str(e)}")
+
+    def _create_error_image(self, error_message):
+        """Create an error image with the provided message"""
+        img = Image.new('RGB', (512, 512), color=(255, 255, 255))
+        draw = ImageDraw.Draw(img)
+
+        # Draw a red border
+        draw.rectangle([(0, 0), (511, 511)], outline=(255, 0, 0), width=5)
+
+        # Draw the error message
+        draw.text((20, 240), f"Error: {error_message}", fill=(0, 0, 0))
+
+        return img
+
+    def _get_timestamp(self):
+        """Get current timestamp in string format"""
+        from datetime import datetime
+        return datetime.now().strftime("%Y%m%d_%H%M%S")
+
+
+# Simplified model components for demonstration
+class SimpleCLAPWrapper:
+    """Simplified CLAP wrapper for audio encoding"""
+
+    def __init__(self, weights_path):
+        self.weights_path = weights_path
+        self.sr = 44100
+
+        # Just check if the weights file exists
+        if not os.path.exists(weights_path):
+            raise ValueError(f"CLAP weights file not found: {weights_path}")
+
+    def get_audio_embeddings(self, audio_path):
+        """Generate audio embeddings from the audio file"""
+        import torch
+        import librosa
+
+        # Load the audio file
+        try:
+            audio, _ = librosa.load(audio_path, sr=self.sr, mono=True)
+        except Exception as e:
+            raise ValueError(f"Error loading audio file {audio_path}: {str(e)}")
+
+        # Create a simple random embedding (since we don't have the real model)
+        # This would normally be generated by the CLAP model
+        torch.manual_seed(hash(audio_path) % 2**32)
+        embedding = torch.randn(1, 1024)
+
+        return embedding
+
+
+class SimpleAudioProjector:
+    """Simplified audio projector for audio embedding processing"""
+
+    def __init__(self, weights_path, device):
+        self.weights_path = weights_path
+        self.device = device
+
+        # Just check if the weights file exists
+        if not os.path.exists(weights_path):
+            raise ValueError(f"Audio projector weights file not found: {weights_path}")
+
+    def __call__(self, audio_embedding):
+        """Process audio embeddings"""
+        import torch
+
+        # Create a simple transformation (since we don't have the real model)
+        # This would normally be processed by the audio projector model
+        torch.manual_seed(42)
+        projection = torch.randn(1, 77, 768).to(self.device)
+
+        return projection
+
+
+class SimpleDiffusionPipeline:
+    """Simplified diffusion pipeline for image generation"""
+
+    def __init__(self, weights_path, device):
+        self.weights_path = weights_path
+        self.device = device
+
+        # Just check if the weights file exists
+        if not os.path.exists(weights_path):
+            raise ValueError(f"Pipeline weights file not found: {weights_path}")
+
+    def generate(self, prompt, audio_context, guidance_scale=7.5, num_inference_steps=50):
+        """Generate an image based on the prompt and audio context"""
+        # Create a simple visualization of the audio context and prompt
+        return self._create_visualized_output(prompt, audio_context, guidance_scale, num_inference_steps)
+
+    def _create_visualized_output(self, prompt, audio_context, guidance_scale, num_inference_steps):
+        """Create a visualization of the generation parameters"""
+        import torch
+        import numpy as np
+        from PIL import Image, ImageDraw, ImageFont
+
+        # Create a gradient background based on the audio context tensor
+        # This is just for visualization since we don't have the real model
+        audio_data = audio_context[1].detach().cpu().mean(dim=1).numpy()
+        audio_data = (audio_data - audio_data.min()) / (audio_data.max() - audio_data.min())
+
+        # Create a visualization
+        img = Image.new('RGB', (512, 512), color=(255, 255, 255))
+        draw = ImageDraw.Draw(img)
+
+        # Draw a color gradient based on audio (simplified visualization)
+        for y in range(512):
+            # Get color from audio data
+            idx = int(y / 512 * len(audio_data))
+            if idx >= len(audio_data):
+                idx = len(audio_data) - 1
+
+            val = audio_data[idx]
+            r = int(255 * (1 - val))
+            g = int(200 * val)
+            b = int(255 * (0.5 + 0.5 * val))
+
+            draw.line([(0, y), (512, y)], fill=(r, g, b))
+
+        # Add the prompt text
+        draw.rectangle([(10, 10), (502, 90)], fill=(255, 255, 255, 180))
+        draw.text((20, 20), f"Prompt: {prompt}", fill=(0, 0, 0))
+        draw.text((20, 40), f"CFG Scale: {guidance_scale}", fill=(0, 0, 0))
+        draw.text((20, 60), f"Steps: {num_inference_steps}", fill=(0, 0, 0))
+
+        # Add "Generated Image" label
+        draw.rectangle([(10, 470), (502, 502)], fill=(255, 255, 255, 180))
+        draw.text((20, 480), "Generated Image (Simulation)", fill=(0, 0, 0))
+
+        return img
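
For context, here is how the rewritten controller could be exercised end-to-end. This is a minimal sketch, not part of the commit: it assumes controller.py from this diff is importable and that the ckpts/ weights and assets/ samples referenced above are present locally.

# demo.py - hypothetical driver script, not included in this commit
from controller import SonicDiffusionController

controller = SonicDiffusionController()

# load_model() returns the multi-line status report built from status_messages
print(controller.load_model("Landscape Model"))

# generate() returns a PIL.Image: the simulated gradient visualization on
# success, or an error image from _create_error_image() on failure
image = controller.generate(
    text_prompt="a misty mountain lake at dawn",  # illustrative prompt
    audio_path="assets/fire_crackling.wav",
    cfg_scale=7.5,
    steps=50,
)
image.save("outputs/demo.png")  # outputs/ is created inside generate()

Note that torch.cat([audio_uc, audio_proj]) in generate() mirrors the classifier-free guidance convention: the projection of a zeroed audio embedding is stacked with the conditional projection so a pipeline can trade them off via guidance_scale.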