alpercagann committed
Commit b4447cb · 1 parent: b76b715

Create minimal app without torch dependency

Files changed (1)
  app.py  +28 −447
app.py CHANGED
@@ -1,457 +1,38 @@
-# app.py
 import os
 import sys
-import subprocess
-import gradio as gr
-import torch
-import traceback
-from datetime import datetime
-
-# Ensure required packages are installed
-try:
-    import requests
-    import tqdm
-    import re
-except ImportError:
-    subprocess.check_call([sys.executable, "-m", "pip", "install", "requests", "tqdm"])
-    import requests
-    import tqdm
-    import re
-
-# Asset management functions
-def get_gdrive_file_id(url):
-    """Extract file ID from Google Drive URL"""
-    match = re.search(r"d/([a-zA-Z0-9_-]+)", url) or re.search(r"id=([a-zA-Z0-9_-]+)", url)
-    if match:
-        return match.group(1)
-    return None
-
-def download_gdrive_file(file_id, destination):
-    """Download a file from Google Drive with support for large files"""
-    if os.path.exists(destination):
-        print(f"File already exists: {destination}")
-        return True
-
-    # Make the directory if it doesn't exist
-    os.makedirs(os.path.dirname(destination), exist_ok=True)
-
-    # First, try the direct download URL
-    url = f"https://drive.google.com/uc?export=download&id={file_id}"
-
-    # Set up a session to handle cookies
-    session = requests.Session()
-
-    # First request to get the confirmation token for large files
-    response = session.get(url, stream=True)
-
-    # Check if there's a download confirmation page
-    if "confirm" in response.url:
-        # Extract confirmation token
-        token = response.url.split("confirm=")[1].split("&")[0]
-        url = f"{url}&confirm={token}"
-        response = session.get(url, stream=True)
-
-    # Get file size for progress bar
-    total_size = int(response.headers.get('content-length', 0))
-
-    # Download the file with progress bar
-    print(f"Downloading to {destination} ({total_size/(1024*1024):.1f} MB)...")
-    with open(destination, 'wb') as f:
-        with tqdm.tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
-            for chunk in response.iter_content(chunk_size=1024*1024):
-                if chunk:
-                    f.write(chunk)
-                    pbar.update(len(chunk))
-
-    print(f"Downloaded {destination} successfully!")
-    return True
-
-def check_and_download_assets():
-    """Check if required assets exist and download them if needed"""
-    # Define required files and their Google Drive URLs
-    gdrive_urls = {
-        "assets/fire_crackling.wav": "https://drive.google.com/file/d/1vOAZcbkpo_hre2g26n--lUXdwbTQp22k/view?usp=drive_link",
-        "assets/plastic_bag.wav": "https://drive.google.com/file/d/15igeDor7a47a-oluSCfO6GeUvFVl2ttb/view?usp=sharing",
-        "ckpts/landscape.pt": "https://drive.google.com/file/d/1-oTNIjCZq3_mGI1XRfzDyCnmjXCvd0Vh/view?usp=drive_link",
-        "ckpts/greatest_hits.pt": "https://drive.google.com/file/d/1wGDCB4iRFi4kf7bsFXV3qkc9_jvyNrCa/view?usp=drive_link",
-        "ckpts/audio_projector_landscape.pth": "https://drive.google.com/file/d/1BdjzRJOC8bvyPgrAkJJcCaN3EEJg3STm/view?usp=sharing",
-        "ckpts/audio_projector_gh.pth": "https://drive.google.com/file/d/19Uk68PXVOjE3TJl86H-IlMaM1URhU33a/view?usp=sharing",
-        "ckpts/CLAP_weights_2022.pth": "https://drive.google.com/file/d/1VK22jxHkFwpxknxQBLd6kIgO5WxQdLFP/view?usp=sharing"
-    }
-
-    # Create necessary directories
-    os.makedirs("assets", exist_ok=True)
-    os.makedirs("ckpts", exist_ok=True)
-
-    # Only download missing files
-    missing_files = {dest: url for dest, url in gdrive_urls.items() if not os.path.exists(dest)}
-
-    if missing_files:
-        print(f"Missing {len(missing_files)} required files. Downloading...")
-
-        for destination, url in missing_files.items():
-            file_id = get_gdrive_file_id(url)
-            if file_id:
-                try:
-                    download_gdrive_file(file_id, destination)
-                except Exception as e:
-                    print(f"Error downloading {destination}: {e}")
-                    return False
-            else:
-                print(f"Could not extract file ID from {url}")
-                return False
-
-    print("All required assets are available!")
-    return True
-
-
-# SonicDiffusion Controller Class
-class SonicDiffusionController:
-    def __init__(self, device=None):
-        if device is None:
-            self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        else:
-            self.device = device
-
-        print(f"Using device: {self.device}")
-        self.sr = 44100
-        self.model_loaded = False
-
-    def load_model(self,
-                   gate_dict_path="ckpts/landscape.pt",
-                   clap_path="CLAP/msclap",
-                   clap_weights="ckpts/CLAP_weights_2022.pth",
-                   adapter_ckpt_path="ckpts/audio_projector_landscape.pth"):
-        """Load the model conditionally based on environment and availability"""
-        try:
-            # First, check if the required files exist
-            for path in [gate_dict_path, adapter_ckpt_path, clap_weights]:
-                if not os.path.exists(path):
-                    return f"Error: Required file {path} not found"
-
-            print("Loading models - this may take a moment...")
-
-            # Import here to avoid import errors if files are missing
-            from unet2d_custom import UNet2DConditionModel
-            from pipeline_stable_diffusion_custom import StableDiffusionPipeline
-            from ldm.modules.encoders.audio_projector_res import Adapter
-
-            # Try to load the model with appropriate settings for the hardware
-            try:
-                model_id = "CompVis/stable-diffusion-v1-4"
-                self.unet = UNet2DConditionModel.from_pretrained(
-                    model_id,
-                    subfolder="unet",
-                    use_adapter_list=[False, True, True],
-                    low_cpu_mem_usage=True,
-                    device_map="auto" if self.device == "cuda" else None
-                )
-
-                self.pipeline = StableDiffusionPipeline.from_pretrained(
-                    model_id,
-                    use_safetensors=True,
-                    torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
-                )
-
-                # Move models to the appropriate device
-                self.unet = self.unet.to(self.device)
-                self.pipeline = self.pipeline.to(self.device)
-
-            except Exception as e:
-                print(f"Warning: Encountered issue with full model loading: {e}")
-                print("Trying with simplified loading...")
-
-                # Simplified loading for compatibility
-                model_id = "CompVis/stable-diffusion-v1-4"
-                self.unet = UNet2DConditionModel.from_pretrained(
-                    model_id,
-                    subfolder="unet",
-                    use_adapter_list=[False, True, True],
-                    low_cpu_mem_usage=True
-                ).to(self.device)
-
-                self.pipeline = StableDiffusionPipeline.from_pretrained(
-                    model_id,
-                    use_safetensors=True
-                ).to(self.device)
-
-            # Load gate dictionary
-            gate_dict = torch.load(gate_dict_path, map_location=self.device)
-            for name, param in self.unet.named_parameters():
-                if "adapter" in name:
-                    param.data = gate_dict[name].to(self.device)
-
-            # Set pipeline's UNet
-            self.pipeline.unet = self.unet
-
-            # Import and load audio encoder
-            import sys
-            sys.path.append(clap_path)
-            try:
-                from CLAPWrapper import CLAPWrapper
-
-                self.audio_encoder = CLAPWrapper(clap_weights, use_cuda=(self.device=="cuda"))
-                self.audio_projector = Adapter(audio_token_count=77, transformer_layer_count=4).to(self.device)
-                self.audio_projector.load_state_dict(torch.load(adapter_ckpt_path, map_location=self.device))
-                self.audio_projector.eval()
-
-                self.model_loaded = True
-                print("Model loaded successfully!")
-                return "Model loaded successfully"
-
-            except ImportError as e:
-                return f"Error importing CLAP: {str(e)}. Make sure the CLAP module is available."
-
-        except Exception as e:
-            error_msg = f"Failed to load model: {str(e)}"
-            print(error_msg)
-            traceback.print_exc()
-            return error_msg
-
-    def generate(self, audio_model=None, audio=None, prompt=None, cfg_scale=5, num_inference_steps=50):
-        """Generate an image from audio input"""
-        if not self.model_loaded:
-            from PIL import Image, ImageDraw
-            img = Image.new('RGB', (512, 512), color=(255, 255, 255))
-            d = ImageDraw.Draw(img)
-            d.text((10, 250), "Error: Model not loaded. Click 'Load Model' first.", fill=(0, 0, 0))
-            return img
-
-        try:
-            if audio is None:
-                raise ValueError("No audio file provided")
-
-            if prompt is None or prompt.strip() == "":
-                prompt = "a high quality image"
-
-            with torch.no_grad():
-                # Process audio input
-                audio_emb, _ = self.audio_encoder.get_audio_embeddings([audio], resample=self.sr)
-                audio_proj = self.audio_projector(audio_emb.unsqueeze(1))
-
-                # Create unconditional embedding
-                audio_emb = torch.zeros(1, 1024).to(self.device)
-                audio_uc = self.audio_projector(audio_emb.unsqueeze(1))
-
-                # Combine for context
-                audio_context = torch.cat([audio_uc, audio_proj]).to(self.device)
-
-                # Generate image
-                print(f"Generating image with prompt: '{prompt}', CFG: {cfg_scale}, Steps: {num_inference_steps}")
-                image = self.pipeline(
-                    prompt=prompt,
-                    audio_context=audio_context,
-                    guidance_scale=cfg_scale,
-                    num_inference_steps=num_inference_steps
-                )
-
-                # Save a copy of the generated image
-                os.makedirs("outputs", exist_ok=True)
-                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-                output_path = f"outputs/generated_{timestamp}.png"
-                image.images[0].save(output_path)
-                print(f"Image saved to {output_path}")
-
-                return image.images[0]
-
-        except Exception as e:
-            error_msg = f"Error in generation: {str(e)}"
-            print(error_msg)
-            traceback.print_exc()
-
-            # Return a blank error image
-            from PIL import Image, ImageDraw
-            img = Image.new('RGB', (512, 512), color=(255, 255, 255))
-            d = ImageDraw.Draw(img)
-            d.text((10, 250), f"Error: {str(e)}", fill=(0, 0, 0))
-            return img
-
-    def update_audio_model(self, audio_model_update):
-        """Update audio model based on selection"""
-        try:
-            if not self.model_loaded:
-                return "Error: Model not loaded. Click 'Load Model' first."
-
-            if audio_model_update == "Landscape Model":
-                audio_projector_path = "ckpts/audio_projector_landscape.pth"
-                gate_dict_path = "ckpts/landscape.pt"
-            else:
-                audio_projector_path = "ckpts/audio_projector_gh.pth"
-                gate_dict_path = "ckpts/greatest_hits.pt"
-
-            # Check if files exist
-            if not os.path.exists(audio_projector_path) or not os.path.exists(gate_dict_path):
-                return f"Error: Required model files not found. Need {audio_projector_path} and {gate_dict_path}"
-
-            # Load gate dictionary and update parameters
-            gate_dict = torch.load(gate_dict_path, map_location=self.device)
-            for name, param in self.pipeline.unet.named_parameters():
-                if "adapter" in name:
-                    param.data = gate_dict[name].to(self.device)
-
-            # Load audio projector state
-            self.audio_projector.load_state_dict(torch.load(audio_projector_path, map_location=self.device))
-
-            return f"Model updated to {audio_model_update}"
-        except Exception as e:
-            error_msg = f"Error updating audio model: {str(e)}"
-            print(error_msg)
-            return error_msg
 
-# CSS for styling the UI
-css = """
-.gradio-container {
-    font-family: 'IBM Plex Sans', sans-serif;
-}
-.toolbutton {
-    margin-bottom: 0em;
-    max-width: 2em;
-    min-width: 2em !important;
-    height: 2em;
-}
-.output-image {
-    border-radius: 0.5rem;
-    border: 1px solid #cccccc;
-}
-.info-text {
-    font-size: 14px;
-    color: #666;
-    margin-top: 5px;
-}
-"""
 
-# Initialize controller
-controller = SonicDiffusionController()
 
-def ui():
-    with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
-        gr.Markdown(
-            """
-            # 🎵 SonicDiffusion: Audio-Driven Image Generation
-
-            Upload an audio file and enter a prompt to generate audio-conditioned images.
-
-            *This model transforms audio characteristics into visual elements.*
-            """
-        )
-
-        with gr.Row():
-            with gr.Column(scale=1):
-                # Left column - inputs
-                gr.Markdown("### Model Controls")
-
-                # Load model button - explicitly load the model when ready
-                load_model_button = gr.Button(value="1️⃣ Load Model (click first)", variant='primary')
-
-                with gr.Accordion("Model Selection", open=True):
-                    audio_model_dropdown = gr.Dropdown(
-                        label="Select SonicDiffusion model",
-                        value="Landscape Model",
-                        choices=["Landscape Model", "Greatest Hits Model"],
-                        interactive=True,
-                    )
-                    model_info = gr.Markdown("""
-                    **Landscape Model**: Optimized for nature and environment sounds
-
-                    **Greatest Hits**: Better with music and rhythmic sounds
-                    """)
-
-                # Audio input
-                audio_input = gr.Audio(label="2️⃣ Upload or Record Audio", sources=["upload", "microphone"], type="filepath")
-
-                # Prompt input
-                prompt_textbox = gr.Textbox(label="3️⃣ Enter Prompt", lines=2, placeholder="Describe the image you want to generate...")
-
-                with gr.Accordion("Advanced Settings", open=False):
-                    # Generation parameters
-                    with gr.Row():
-                        cfg_scale_slider = gr.Slider(label="Guidance Scale", value=7.5, minimum=1.0, maximum=20.0, info="Higher values = more prompt adherence")
-                        num_steps_slider = gr.Slider(label="Inference Steps", value=50, minimum=20, maximum=100, step=5, info="Higher values = more detail, slower generation")
-
-                # Generate button
-                generate_button = gr.Button(value="4️⃣ Generate Image", variant='primary', size="lg")
-
-                # Status indicator
-                status_text = gr.Textbox(label="Status", value="Click 'Load Model' to begin")
-
-                gr.Markdown("### Example Audio Files")
-                with gr.Row():
-                    examples = [
-                        ['./assets/fire_crackling.wav'],
-                        ['./assets/plastic_bag.wav'],
-                    ]
-                    gr.Examples(examples=examples, inputs=[audio_input])
-
-            with gr.Column(scale=1):
-                # Right column - output
-                gr.Markdown("### Generated Image")
-                output = gr.Image(label="Output Image", height=512, width=512)
-                download_btn = gr.Button("💾 Download Image")
-                output_info = gr.Markdown("""
-                *Generated images are also automatically saved to the 'outputs' folder.*
-
-                #### How SonicDiffusion Works
-
-                SonicDiffusion extracts features from audio files and uses them to condition a Stable Diffusion model.
-                The audio influences how the image is generated, with different sounds creating different visual effects.
-
-                Try experimenting with different audio files and prompts!
-                """)
-
-        # Event handlers
-        load_model_button.click(
-            fn=controller.load_model,
-            inputs=[],
-            outputs=[status_text]
-        )
-
-        audio_model_dropdown.change(
-            fn=controller.update_audio_model,
-            inputs=[audio_model_dropdown],
-            outputs=[status_text]
-        )
-
-        generate_button.click(
-            fn=controller.generate,
-            inputs=[
-                audio_model_dropdown,
-                audio_input,
-                prompt_textbox,
-                cfg_scale_slider,
-                num_steps_slider,
-            ],
-            outputs=[output]
-        )
-
-        download_btn.click(
-            fn=lambda x: x,
-            inputs=[output],
-            outputs=[output],
-            _js="(img) => { if(img) { const a = document.createElement('a'); a.href = img; a.download = 'sonicDiffusion_' + Date.now() + '.png'; a.click(); } return img; }"
-        )
-
-    return demo
 
 if __name__ == "__main__":
-    # Create necessary directories
-    os.makedirs("assets", exist_ok=True)
-    os.makedirs("ckpts", exist_ok=True)
-    os.makedirs("outputs", exist_ok=True)
-
-    # Check environment
-    print(f"Python version: {sys.version}")
-    print(f"PyTorch version: {torch.__version__}")
-    print(f"CUDA available: {torch.cuda.is_available()}")
-    if torch.cuda.is_available():
-        print(f"CUDA device: {torch.cuda.get_device_name(0)}")
-
-    # Check and download assets if needed
-    print("Checking required assets...")
-    assets_ready = check_and_download_assets()
-    if not assets_ready:
-        print("Warning: Could not download all required assets. The app may not function correctly.")
 
     # Launch the demo
-    demo = ui()
-    demo.launch(share=True)
+# Minimal app.py that doesn't require torch
 import os
 import sys
 
+# Print environment information for debugging
+print("==== Environment Information ====")
+print(f"Python version: {sys.version}")
+print(f"Working directory: {os.getcwd()}")
+print(f"Directory contents: {os.listdir('.')}")
 
+# Simple Gradio interface
+import gradio as gr
 
+def hello(name):
+    if not name:
+        name = "World"
+    return f"Hello, {name}!"
 
+# Create a simple Gradio interface
+demo = gr.Interface(
+    fn=hello,
+    inputs="text",
+    outputs="text",
+    title="SonicDiffusion - Setup Test",
+    description="This is a test app to verify the environment is working."
+)
 
 if __name__ == "__main__":
+    # Try to print installed packages
+    try:
+        import subprocess
+        print("==== Installed Packages ====")
+        subprocess.run([sys.executable, "-m", "pip", "list"])
+    except Exception as e:
+        print(f"Error listing packages: {e}")
 
     # Launch the demo
+    demo.launch()
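
One possible follow-up, sketched rather than implemented: once the Space builds with this minimal file, the torch-dependent app can be restored behind a guarded import, so a missing torch install degrades back to the setup test instead of crashing at startup. The sketch below assumes only that gradio is installed; HAS_TORCH and build_demo are hypothetical names, not part of this commit.

# hypothetical sketch, not part of this commit
import gradio as gr

try:
    import torch  # the heavy dependency this commit removes
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False

def hello(name):
    return f"Hello, {name or 'World'}!"

def build_demo():
    if HAS_TORCH:
        # The full SonicDiffusion UI from the previous revision could be
        # rebuilt here (controller, model selection, generate button, etc.).
        pass
    # For now, fall back to the minimal setup-test interface either way.
    return gr.Interface(fn=hello, inputs="text", outputs="text",
                        title="SonicDiffusion - Setup Test")

if __name__ == "__main__":
    build_demo().launch()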