mclemcrew committed
Commit edb4194 · 1 Parent(s): 265c639
Files changed (3)
  1. app.py +262 -459
  2. requirements.txt +2 -2
  3. setup_examples.py +0 -52
app.py CHANGED
@@ -2,14 +2,16 @@ import os
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, BitsAndBytesConfig
 import numpy as np
 import librosa
-from urllib.request import urlopen
-from io import BytesIO
 import logging
 import sys
 import gc
+import time
+from io import BytesIO
+from urllib.request import urlopen, Request
+import requests
+from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, BitsAndBytesConfig

 # Configure logging
 logging.basicConfig(
@@ -19,19 +21,22 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)

-# Update to use the merged model
-MODEL_ID = "mclemcrew/Qwen-Audio-Mix-Instruct"
+# Use your fine-tuned model
+MODEL_ID = "mclemcrew/MixInstruct"

-# Cache for model and processor
+# Cache for model and processor to avoid reloading
 model_cache = None
 processor_cache = None

-# Memory tracking
+# Memory tracking function
 def log_gpu_memory(message=""):
+    """Log current GPU memory usage with a descriptive message"""
     if torch.cuda.is_available():
         allocated = torch.cuda.memory_allocated() / 1024**3
         reserved = torch.cuda.memory_reserved() / 1024**3
         logger.info(f"{message} - GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
+    else:
+        logger.info(f"{message} - Running on CPU, no GPU available")

 def load_model():
     """Load the fine-tuned model with optimized memory usage"""
@@ -45,50 +50,51 @@ def load_model():
     # Log initial GPU state
     log_gpu_memory("Before model loading")

-    # Load processor
-    logger.info(f"Loading processor from {MODEL_ID}")
-    processor = AutoProcessor.from_pretrained(MODEL_ID)
-
-    # Clean up memory
+    # First clear any existing cache
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()

-    # Define proper quantization config - using 4-bit quantization
+    # Load processor first
+    logger.info(f"Loading processor from {MODEL_ID}")
+    try:
+        processor = AutoProcessor.from_pretrained(MODEL_ID)
+        logger.info("Processor loaded successfully")
+    except Exception as e:
+        logger.error(f"Error loading processor: {e}")
+        raise RuntimeError(f"Failed to load processor: {str(e)}")
+
+    # Define quantization config - use 4-bit quantization for memory efficiency
     quant_config = BitsAndBytesConfig(
         load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.float16,  # Match training dtype
+        bnb_4bit_compute_dtype=torch.float16,
         bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type="nf4"
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_quant_storage=torch.uint8
     )

+    # Load the model with progressive fallbacks
+    logger.info(f"Loading model from {MODEL_ID} with 4-bit quantization")
     try:
-        logger.info("Loading model with optimized 4-bit quantization")
-
-        # Load with quantization and offloading for memory efficiency
+        # Primary approach: 4-bit quantization
         model = Qwen2AudioForConditionalGeneration.from_pretrained(
             MODEL_ID,
             quantization_config=quant_config,
-            device_map="auto",
+            device_map="auto",  # Let Hugging Face determine optimal device mapping
             torch_dtype=torch.float16,
-            offload_folder="offload",
-            offload_state_dict=True,
             low_cpu_mem_usage=True
         )
-
-        log_gpu_memory("After optimized model loading")
-        logger.info("Model loaded successfully with optimized approach")
+        logger.info("Model loaded successfully with 4-bit quantization")
     except Exception as e:
-        logger.error(f"Error loading model with optimized approach: {e}")
+        # Clean up memory before fallback
+        logger.error(f"Error loading model with 4-bit quantization: {e}")
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()

+        # Try 8-bit quantization as fallback
         try:
-            # Fallback to 8-bit quantization (more stable but less compression)
-            logger.info("Attempting 8-bit quantization fallback")
-
-            from transformers import BitsAndBytesConfig
+            logger.info("Attempting fallback to 8-bit quantization")
             quant_config_8bit = BitsAndBytesConfig(
                 load_in_8bit=True,
                 llm_int8_threshold=6.0
@@ -97,187 +103,226 @@ def load_model():
             model = Qwen2AudioForConditionalGeneration.from_pretrained(
                 MODEL_ID,
                 quantization_config=quant_config_8bit,
-                device_map="auto",
+                device_map="auto",
                 torch_dtype=torch.float16
             )
-
-            log_gpu_memory("After 8-bit fallback loading")
             logger.info("Model loaded successfully with 8-bit quantization")
         except Exception as e2:
-            logger.error(f"Error loading with 8-bit quantization: {e2}")
+            # Clean up memory before final fallback
+            logger.error(f"Error loading model with 8-bit quantization: {e2}")
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()

+            # Final fallback - try to load with fp16 and CPU offloading
             try:
-                # Fallback to dummy model as last resort
-                logger.warning("Loading dummy placeholder model")
-                class DummyAudioModel:
-                    def __init__(self):
-                        self.device = torch.device("cpu")
-                        self.dummy_parameters = [torch.tensor([0.0])]
-
-                    def generate(self, **kwargs):
-                        input_ids = kwargs.get("input_ids", None)
-                        if input_ids is not None:
-                            batch_size, seq_len = input_ids.shape
-                            dummy_output = torch.ones((batch_size, seq_len + 20), dtype=torch.long)
-                            dummy_output[:, :seq_len] = input_ids
-                            dummy_output[:, seq_len:] = 100
-                            return dummy_output
-                        else:
-                            return torch.ones((1, 30), dtype=torch.long)
-
-                    def parameters(self):
-                        return iter(self.dummy_parameters)
-
-                    def to(self, device):
-                        self.device = device
-                        return self
-
-                model = DummyAudioModel()
-                logger.warning("Created dummy model placeholder - no real functionality available")
+                logger.info("Attempting final fallback with CPU offloading")
+                model = Qwen2AudioForConditionalGeneration.from_pretrained(
+                    MODEL_ID,
+                    torch_dtype=torch.float16,
+                    device_map="auto",
+                    offload_folder="offload",
+                    offload_state_dict=True
+                )
+                logger.info("Model loaded successfully with CPU offloading")
             except Exception as e3:
-                logger.error(f"Failed to create dummy model: {e3}")
-                raise RuntimeError(f"Could not load any model version after multiple attempts")
+                logger.error(f"All loading attempts failed: {e3}")
+                raise RuntimeError("Could not load model after multiple attempts")
+
+    # Verify model loaded correctly
+    if model is None:
+        raise RuntimeError("Model failed to load but no exception was raised")
+
+    # Set model to evaluation mode
+    model.eval()

     # Cache the model and processor
     model_cache = model
     processor_cache = processor

+    # Log final memory state
+    log_gpu_memory("After model loading")
+
     return model, processor

-def process_audio_from_url(audio_url, processor):
-    """Process audio file from URL for model input with optimized memory usage"""
+def process_audio(audio_path, processor):
+    """
+    Process audio file from URL or local path
+
+    Args:
+        audio_path: URL or path to audio file
+        processor: Model processor
+
+    Returns:
+        Processed audio data as numpy array
+    """
+    logger.info(f"Processing audio from: {audio_path}")
+
     try:
-        logger.info(f"Processing audio from URL: {audio_url}")
-        # Get processor's sampling rate
+        # Get target sampling rate from processor
         target_sr = int(processor.feature_extractor.sampling_rate)
         logger.info(f"Target sampling rate: {target_sr}")

-        # Audio bytes container
-        audio_bytes = None
+        # Determine maximum audio length (15 seconds)
+        max_seconds = 15
+        max_samples = max_seconds * target_sr

-        # Handle various URL formats
-        if audio_url.startswith(('http://', 'https://')):
-            # For web URLs
+        # Load audio data based on source
+        if audio_path.startswith(('http://', 'https://')):
+            # Web URL handling with proper headers to avoid 403 errors
             try:
-                import requests
-                response = requests.get(audio_url)
+                # First try with requests for better error handling
+                headers = {
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
+                }
+                response = requests.get(audio_path, headers=headers)
                 response.raise_for_status()
                 audio_bytes = BytesIO(response.content)
-                # Free memory
-                del response
-            except Exception as req_error:
-                logger.info(f"Requests failed, falling back to urlopen: {req_error}")
-                audio_bytes = BytesIO(urlopen(audio_url).read())
-        elif audio_url.startswith('file://'):
-            # For local file URLs
-            file_path = audio_url[7:]  # Remove 'file://' prefix
-            with open(file_path, 'rb') as f:
-                audio_bytes = BytesIO(f.read())
+                logger.info(f"Successfully downloaded audio with requests: {len(response.content)} bytes")
+            except Exception as req_err:
+                # Fallback to urlopen
+                logger.warning(f"Requests download failed, trying urlopen: {req_err}")
+                request = Request(audio_path, headers={
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
+                })
+                audio_bytes = BytesIO(urlopen(request).read())
+
+            # Load audio with librosa
+            audio_data, sr_loaded = librosa.load(audio_bytes, sr=None)
         else:
-            # Try as a local file path
-            with open(audio_url, 'rb') as f:
-                audio_bytes = BytesIO(f.read())
+            # Local file handling
+            audio_data, sr_loaded = librosa.load(audio_path, sr=None)

-        # Load and resample audio
-        audio_data, sr_loaded = librosa.load(audio_bytes, sr=None)
         logger.info(f"Audio loaded with shape: {audio_data.shape}, original SR: {sr_loaded}")

-        # Free memory
-        del audio_bytes
-        gc.collect()
-
         # Resample if needed
         if sr_loaded != target_sr:
             logger.info(f"Resampling from {sr_loaded} Hz to {target_sr} Hz")
             audio_data = librosa.resample(audio_data, orig_sr=sr_loaded, target_sr=target_sr)

-        # Reduce to 15 seconds maximum (was 30 seconds before)
-        max_seconds = 15
-        max_samples = max_seconds * target_sr
+        # Truncate to maximum length
         if len(audio_data) > max_samples:
-            logger.info(f"Limiting audio to {max_seconds} seconds for memory efficiency")
+            logger.info(f"Truncating audio from {len(audio_data)} to {max_samples} samples ({max_seconds} seconds)")
             audio_data = audio_data[:max_samples]

         # Ensure audio is float32
         audio_data = audio_data.astype(np.float32)

+        # Print audio stats
+        logger.info(f"Processed audio shape: {audio_data.shape}, min: {audio_data.min()}, max: {audio_data.max()}")
+
         return audio_data
+
     except Exception as e:
-        logger.error(f"Error processing audio from URL {audio_url}: {e}", exc_info=True)
-        return None
-    finally:
-        # Clean up any lingering memory
-        if 'audio_bytes' in locals() and audio_bytes is not None:
-            del audio_bytes
-            gc.collect()
+        logger.error(f"Error processing audio: {e}", exc_info=True)
+        # Return a small empty array instead of None to avoid downstream errors
+        return np.zeros(target_sr * 3, dtype=np.float32)  # 3 seconds of silence as fallback

-@spaces.GPU(duration=120)
+# Add retry decorator for reliability
+def with_retry(max_retries=3, delay=1.0):
+    """
+    Decorator to retry functions with exponential backoff
+    """
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            retries = 0
+            current_delay = delay
+
+            while retries < max_retries:
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    retries += 1
+                    if retries >= max_retries:
+                        logger.error(f"Function {func.__name__} failed after {max_retries} retries: {e}")
+                        raise
+
+                    logger.warning(f"Retry {retries}/{max_retries} for {func.__name__}: {e}")
+                    time.sleep(current_delay)
+                    current_delay *= 2  # Exponential backoff
+
+            return None  # Should never reach here
+        return wrapper
+    return decorator
+
+# Function to validate audio URLs
+def is_valid_audio_url(url):
+    """Check if a URL likely points to an audio file"""
+    if not url or not isinstance(url, str) or not url.strip():
+        return False
+
+    url = url.strip().lower()
+    audio_extensions = ('.wav', '.mp3', '.ogg', '.flac', '.m4a', '.aac', '.wma')
+
+    # Check if URL ends with a known audio extension
+    if any(url.endswith(ext) for ext in audio_extensions):
+        return True
+
+    # Check for common audio hosting patterns
+    audio_hosts = ('soundcloud.com', 'bandcamp.com', 'freesound.org')
+    if any(host in url for host in audio_hosts):
+        return True
+
+    return False
+
+@with_retry(max_retries=2, delay=1.0)
+@spaces.GPU(duration=60)  # Reduced duration for more reliable performance
 def chat_with_model(audio_url, message, chat_history):
-    """Generate response from the model using an audio URL"""
-    logger.info(f"Starting chat_with_model with audio_url: {audio_url}, message: {message}")
+    """
+    Generate response from the model using an audio URL
+
+    Args:
+        audio_url: URL to audio file
+        message: User message text
+        chat_history: Previous conversation history
+
+    Returns:
+        Model's response text
+    """
+    logger.info(f"Starting chat with audio_url: {audio_url}, message: {message}")
+    log_gpu_memory("Starting chat_with_model")

-    # Log initial memory state
-    log_gpu_memory("At start of chat_with_model")
-
-    # Validate that audio URL is provided
+    # Validate inputs
     if not audio_url or not audio_url.strip():
-        return "⚠️ Please set an audio track URL first before chatting."
+        return "⚠️ Please provide an audio URL before sending a message."
+
+    if not message or not message.strip():
+        return "⚠️ Please enter a message to get a response."

     try:
-        # Load model and processor on demand
+        # Load model and processor
         model, processor = load_model()

-        # Log memory after model load
-        log_gpu_memory("After model load")
+        # Process audio file
+        audio_data = process_audio(audio_url, processor)
+        if audio_data is None:
+            return "⚠️ Could not process the audio file. Please check the URL and try again."

-        # Check if we're using a dummy model
-        is_dummy = hasattr(model, '__class__') and model.__class__.__name__ == 'DummyAudioModel'
-
-        if is_dummy:
-            logger.warning("Using dummy model - providing generic response")
-            return (
-                "⚠️ I'm currently having trouble analyzing your audio due to technical limitations "
-                "in this environment. The model requires more GPU memory than is available. "
-                "Please try a different audio file or contact the developer for assistance."
-            )
-
-        # Process audio
-        audios = []
-        audio_data = process_audio_from_url(audio_url, processor)
-        if audio_data is not None:
-            audios.append(audio_data)
-        else:
-            return "⚠️ Failed to process audio from the provided URL. Please check that the URL is valid and accessible."
+        # Store processed audio in a list for model input
+        audios = [audio_data]

-        # Log memory after audio processing
-        log_gpu_memory("After audio processing")
+        # Define system prompt
+        system_prompt = "You are an expert audio engineer assisting with music production and mixing. Provide clear, specific advice on audio engineering techniques, mixing adjustments, and production decisions based on the audio samples and the user's questions. Focus on practical, actionable guidance. Be as specific as possible when answering the user's questions about the mix."

-        # System prompt for the model
-        SYSTEM_PROMPT = "You are an expert audio engineer assisting with music production and mixing. Provide clear, specific advice on audio engineering techniques, mixing adjustments, and production decisions based on the audio samples and the user's questions. Focus on practical, actionable guidance. Be as specific as possible when answering the user's questions about the mix."
-
-        # Start with system prompt
+        # Build conversation structure
         conversation = [
-            {"role": "system", "content": SYSTEM_PROMPT}
+            {"role": "system", "content": system_prompt}
         ]

-        # Add chat history - limited to last 5 exchanges to save memory
-        history_limit = min(len(chat_history), 5)
+        # Add chat history (limited to last 3 exchanges to save memory)
+        history_limit = min(len(chat_history), 3)
         for user_msg, bot_msg in chat_history[-history_limit:]:
             if user_msg:
                 conversation.append({"role": "user", "content": user_msg})
             if bot_msg:
                 conversation.append({"role": "assistant", "content": bot_msg})

-        # Determine if this is the first message with a new audio
-        is_first_message_with_audio = len(chat_history) == 0
+        # Determine if this is the first message with this audio
+        is_first_message = len(chat_history) == 0

-        # Format user message based on whether it's the first message with audio
-        if is_first_message_with_audio:
+        # Add current message with audio if it's the first message
+        if is_first_message:
             # First message includes audio
-            logger.info("First message with audio, including audio in content")
             conversation.append({
                 "role": "user",
                 "content": [
@@ -286,333 +331,91 @@ def chat_with_model(audio_url, message, chat_history):
                 ]
             })
         else:
-            # Follow-up message about the same audio
-            logger.info("Follow-up message, including only text in content")
+            # Follow-up messages just include text
             conversation.append({
                 "role": "user",
                 "content": message
             })

-        # Apply chat template with error handling
-        try:
-            text = processor.apply_chat_template(
-                conversation,
-                add_generation_prompt=True,
-                tokenize=False
-            )
-            logger.info(f"Chat template applied successfully")
-        except Exception as e:
-            logger.error(f"Error applying chat template: {e}")
-            # Use a simplified approach if template fails
-            text = f"{SYSTEM_PROMPT}\n\nUser: {message}\n\nAssistant:"
-
-        # Generate model inputs
-        try:
-            inputs = processor(
-                text=text,
-                audios=audios,
-                return_tensors="pt",
-                padding=True,
-                truncation=True
-            )
-
-            # Move inputs to the appropriate device
-            if hasattr(model, 'device'):
-                device = model.device
-            else:
-                device = next(model.parameters()).device
-            logger.info(f"Using device: {device}")
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-
-            logger.info(f"Model inputs generated")
-            log_gpu_memory("After input preparation")
-        except Exception as e:
-            logger.error(f"Error generating model inputs: {e}")
-            return f"⚠️ Error generating model inputs: {str(e)}"
-
-        # Generate response from model
-        with torch.no_grad():
-            try:
-                generate_ids = model.generate(
-                    **inputs,
-                    max_new_tokens=128,  # Reduced from 256
-                    temperature=0.7,
-                    do_sample=True,
-                    top_p=0.9,
-                    use_cache=True  # Ensure KV cache is used
-                )
-                logger.info(f"Response generated successfully")
-                log_gpu_memory("After generation")
-            except Exception as e:
-                logger.error(f"Error during model.generate: {e}")
-                return f"⚠️ Model generation error: {str(e)}"
+        # Apply chat template
+        logger.info(f"Formatting conversation with {len(conversation)} messages")
+        text = processor.apply_chat_template(
+            conversation,
+            add_generation_prompt=True,
+            tokenize=False
+        )

-        # Decode the response
-        try:
-            generate_ids = generate_ids[:, inputs["input_ids"].size(1):]
-            response = processor.batch_decode(
-                generate_ids,
-                skip_special_tokens=True,
-                clean_up_tokenization_spaces=False
-            )[0]
-            logger.info(f"Response decoded successfully, length: {len(response)}")
-
-            # Quick validation of response
-            if not response or response.isspace():
-                logger.error("Empty response received from model")
-                return "⚠️ Model returned an empty response. Please try again."
-
-            # Clean up memory
-            del inputs, generate_ids
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            return response
-        except Exception as e:
-            logger.error(f"Error decoding response: {e}")
-            return f"⚠️ Error decoding response: {str(e)}"
-
-    except Exception as e:
-        logger.error(f"Unexpected error in chat_with_model: {e}", exc_info=True)
-        return f"⚠️ An error occurred: {str(e)}"
-    finally:
-        # Final memory cleanup
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        log_gpu_memory("End of chat_with_model")
-
-# Function to check if URL is a valid audio file
-def is_valid_audio_url(url):
-    if not url or not url.strip():
-        return False
-
-    url = url.strip().lower()
-    return url.endswith(('.wav', '.mp3', '.ogg', '.flac', '.m4a'))
-
-# Custom theme with orange primary color and dark background
-orange_black_theme = gr.themes.Base(
-    primary_hue="orange",
-    secondary_hue="gray",
-    neutral_hue="gray",
-    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
-)
-
-# Custom CSS for darker theme with orange accents
-custom_css = """
-:root {
-    --orange-primary: #ff7700;
-    --dark-bg: #1a1a1a;
-    --darker-bg: #121212;
-    --lightest-gray: #e0e0e0;
-}
-
-body {
-    background-color: var(--darker-bg) !important;
-    color: var(--lightest-gray) !important;
-    font-family: 'Poppins', sans-serif !important;
-}
-
-.gradio-container {
-    background-color: var(--darker-bg) !important;
-}
-
-button.primary {
-    background-color: var(--orange-primary) !important;
-}
-
-.message.bot {
-    background-color: var(--dark-bg) !important;
-}
-"""
-
-# Gradio interface
-with gr.Blocks(theme=orange_black_theme, css=custom_css) as demo:
-    gr.Markdown(
-        """
-        # 🎧 Music Mixing Assistant
-        Enter an audio URL (.wav format recommended) and chat with your co-creative mixing agent!
+        # Create model inputs
+        logger.info("Preparing model inputs")
+        inputs = processor(
+            text=text,
+            audios=audios if is_first_message else None,  # Only include audio on first message
+            return_tensors="pt",
+            padding=True,
+            truncation=True
+        )

-        Set your audio track once, then have an extended conversation about mixing and improving that specific track.
-        *(Note: Audio samples are limited to 15 seconds for optimal performance)*
-        """
-    )
-
-    # Create states for chat history and audio URL
-    audio_url_state = gr.State("")
-
-    with gr.Row():
-        with gr.Column(scale=3):
-            # Chat interface with customized settings
-            chatbot = gr.Chatbot(
-                height=500,
-                avatar_images=(None, "🎧"),  # Removed user icon
-                show_label=False,
-                container=True,
-                bubble_full_width=False,
-                show_copy_button=False,  # Removed copy button
-                show_share_button=False,  # Removed share button
-                render_markdown=True
-            )
-
-            # Input area
-            with gr.Row():
-                message = gr.Textbox(
-                    placeholder="Ask about your mix...",
-                    show_label=False,
-                    container=False,
-                    scale=10
-                )
-                submit_btn = gr.Button("Send", variant="primary", scale=1)
-
-            # Control buttons
-            with gr.Row():
-                clear_btn = gr.Button("Clear Chat", variant="secondary")
+        # Move inputs to GPU if available
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}

-        with gr.Column(scale=1):
-            # Audio URL input
-            audio_input = gr.Textbox(
-                label="Audio URL (.wav format)",
-                placeholder="https://example.com/your-audio-file.wav",
-                info="Enter URL to a WAV audio file - first 15 seconds will be analyzed"
-            )
-
-            # Add a button to set the URL
-            set_url_btn = gr.Button("Set Audio Track", variant="primary")
-
-            # Preview player (optional)
-            audio_preview = gr.Audio(
-                label="Audio Preview (if available)",
-                interactive=False,
-                visible=True
-            )
-
-            # Memory usage indicator
-            if torch.cuda.is_available():
-                memory_status = gr.Markdown("*GPU Memory: Initializing...*")
-                def update_memory_status():
-                    if torch.cuda.is_available():
-                        allocated = torch.cuda.memory_allocated() / 1024**3
-                        reserved = torch.cuda.memory_reserved() / 1024**3
-                        return f"*GPU Memory: {allocated:.2f}GB allocated / {reserved:.2f}GB reserved*"
-                    return "*GPU Memory: Not available*"
-            else:
-                memory_status = gr.Markdown("*GPU Memory: Not available*")
-                def update_memory_status():
-                    return "*GPU Memory: Not available*"
-
-            # Display status
-            status = gr.Markdown("*Status: Ready to assist with your mix!*")
-
-    # Function to update the audio URL state and preview
-    def update_audio_url(url):
-        # Basic validation
-        if not is_valid_audio_url(url):
-            return "", gr.update(value=None), "*Status: Invalid audio URL. Please use .wav, .mp3, .ogg, .flac, or .m4a format*", update_memory_status()
+        log_gpu_memory("Before generation")

-        # Try to provide a preview if possible
+        # Generate response with optimized settings
+        logger.info("Generating response")
         try:
-            return url, gr.update(value=url), "*Status: Audio track set! First 15 seconds will be analyzed.*", update_memory_status()
-        except Exception as e:
-            # If preview fails, still set the URL but show warning
-            return url, gr.update(value=None), f"*Status: Audio track set, but preview failed: {str(e)}*", update_memory_status()
-
-    # Function to clear chat
-    def clear_chat():
-        return []
-
-    # Set URL button logic - Combined update and clear in one function
-    def update_and_clear_chat(url):
-        # First update the URL
-        result = update_audio_url(url)
-        # Then return the values including an empty chat
-        return result[0], result[1], [], result[2], result[3]
-
-    # Set URL button
-    set_url_btn.click(
-        update_and_clear_chat,
-        inputs=[audio_input],
-        outputs=[audio_url_state, audio_preview, chatbot, status, memory_status]
-    )
-
-    # Handle submit button
-    def respond(audio_url, message, chat_history):
-        if not message.strip():
-            return chat_history, "*Status: Please enter a message*", update_memory_status()
-
-        # Check if audio URL is set
-        if not audio_url or not audio_url.strip():
-            error_msg = "No audio track set. Please set an audio URL first."
-            chat_history.append((message, f"⚠️ {error_msg}"))
-            return chat_history, f"*Status: {error_msg}*", update_memory_status()
-
-        # Update chat history with user message immediately
-        chat_history.append((message, None))
-        yield chat_history, "🎵 *Analyzing your mix...*", update_memory_status()
+            with torch.no_grad():
+                generate_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=150,  # Slightly reduced for reliability
+                    do_sample=True,  # Enable sampling for more natural responses
+                    temperature=0.7,  # Moderate temperature
+                    top_p=0.9,  # Nucleus sampling for focused yet diverse outputs
+                    num_beams=1,  # Disable beam search for faster generation
+                    use_cache=True,  # Use KV cache
+                    repetition_penalty=1.1  # Light penalty to avoid repetition
+                )
+        except Exception as gen_error:
+            logger.error(f"Generation error: {gen_error}")
+            # Try a simpler generation approach as fallback
+            with torch.no_grad():
+                generate_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=100,  # Even shorter for reliability
+                    do_sample=False,  # Disable sampling
+                    num_beams=1,  # No beam search
+                    use_cache=True  # Still use KV cache
+                )

-        try:
-            # Process and get response
-            bot_message = chat_with_model(audio_url, message, chat_history[:-1])
-
-            # Update the last message with the bot's response
-            chat_history[-1] = (message, bot_message)
-
-            # Return updated chat history
-            yield chat_history, "*Status: Ready to assist with your mix!*", update_memory_status()
-        except Exception as e:
-            error_msg = f"Error generating response: {str(e)}"
-            chat_history[-1] = (message, f"⚠️ {error_msg}")
-            yield chat_history, f"*Status: {error_msg}*", update_memory_status()
-
-    # Handle submit with clear input
-    def respond_and_clear_input(audio_url, message, chat_history):
-        # First get response updates
-        for result in respond(audio_url, message, chat_history):
-            # Yield each result with empty message input
-            yield result[0], result[1], result[2], ""
-
-    # Connect UI components
-    submit_btn.click(
-        respond_and_clear_input,
-        inputs=[audio_url_state, message, chatbot],
-        outputs=[chatbot, status, memory_status, message],
-        queue=True
-    )
-
-    message.submit(
-        respond_and_clear_input,
-        inputs=[audio_url_state, message, chatbot],
-        outputs=[chatbot, status, memory_status, message],
-        queue=True
-    )
-
-    # Clear button functionality to reset everything
-    def clear_all():
+        # Extract only the generated response (not the input)
+        logger.info("Processing generated response")
+        generate_ids = generate_ids[:, inputs["input_ids"].size(1):]
+        response = processor.batch_decode(
+            generate_ids,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False
+        )[0]
+
+        # Clean up memory
+        del inputs, generate_ids, audios, audio_data
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        return [], "", None, "*Status: Chat cleared!*", update_memory_status(), ""

-    clear_btn.click(
-        clear_all,
-        None,
-        [chatbot, audio_input, audio_preview, status, memory_status, audio_url_state],
-        queue=False
-    )
-
-# Launch the interface
-if __name__ == "__main__":
-    # Display version warning at startup
-    try:
-        import pkg_resources
-        gradio_version = pkg_resources.get_distribution("gradio").version
-        recommended_version = "4.44.1"  # Update this as needed
-        if gradio_version != recommended_version:
-            print(f"⚠️ WARNING: You are using gradio version {gradio_version}, however version {recommended_version} is available.")
-            print(f"⚠️ Please upgrade: pip install gradio=={recommended_version}")
-    except:
-        pass
+        log_gpu_memory("After generation")

-    # Launch with optimized settings
-    demo.launch(share=False, debug=False)
+        # Format and return response
+        logger.info(f"Generated response of length {len(response)}")
+        return response.strip()
+
+    except RuntimeError as e:
+        # Handle CUDA out of memory errors specially
+        if "CUDA out of memory" in str(e):
+            logger.error(f"CUDA OOM error: {e}")
+            return "⚠️ Out of GPU memory. Please try with a shorter audio clip (under 15 seconds) or refresh the page."
+        else:
+            logger.error(f"Runtime error: {e}", exc_info=True)
+            return f"⚠️ An error occurred: {str(e)}"
+    except Exception as e:
+        logger.error(f"Unexpected error in chat_with_model: {e}", exc_info=True)
+        return f"⚠️ Something went wrong: {str(e)}"
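A note on the new decorator stack: chat_with_model is now wrapped in both with_retry and spaces.GPU, so a transient failure (a dropped download, a brief GPU allocation failure) is retried with exponential backoff before the error ever reaches the Gradio UI. A minimal, self-contained sketch of the same retry pattern; flaky_fetch is a hypothetical stand-in for illustration and is not part of this commit:

import time
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def with_retry(max_retries=3, delay=1.0):
    """Retry the wrapped function with exponential backoff (same shape as in app.py)."""
    def decorator(func):
        def wrapper(*args, **kwargs):
            retries = 0
            current_delay = delay
            while retries < max_retries:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    retries += 1
                    if retries >= max_retries:
                        raise  # out of attempts, surface the error
                    logger.warning(f"Retry {retries}/{max_retries} for {func.__name__}: {e}")
                    time.sleep(current_delay)
                    current_delay *= 2  # 1.0s, 2.0s, 4.0s, ...
        return wrapper
    return decorator

attempts = {"n": 0}

@with_retry(max_retries=3, delay=0.1)
def flaky_fetch():
    # Hypothetical function that fails twice, then succeeds
    attempts["n"] += 1
    if attempts["n"] < 3:
        raise ConnectionError("simulated transient failure")
    return "ok"

print(flaky_fetch())  # "ok" on the third attempt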
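The audio path is also normalized in one place now: process_audio loads from URL or disk, resamples to the processor's rate, truncates to 15 seconds, and casts to float32. A standalone sketch of that pipeline, assuming a 16 kHz target rate (app.py reads the real value from processor.feature_extractor.sampling_rate) and a hypothetical local file mix.wav:

import numpy as np
import librosa

target_sr = 16000                      # assumed; taken from the processor in app.py
max_seconds = 15
max_samples = max_seconds * target_sr  # 240000 samples at 16 kHz

# Load at the file's native rate, then resample only if needed
audio_data, sr_loaded = librosa.load("mix.wav", sr=None)
if sr_loaded != target_sr:
    audio_data = librosa.resample(audio_data, orig_sr=sr_loaded, target_sr=target_sr)

# Keep only the first 15 seconds and match the model's expected dtype
if len(audio_data) > max_samples:
    audio_data = audio_data[:max_samples]
audio_data = audio_data.astype(np.float32)

print(audio_data.shape)  # at most (240000,)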
requirements.txt CHANGED
@@ -1,5 +1,5 @@
 gradio==4.44.1
-git+https://github.com/huggingface/transformers.git
+transformers>=4.35.0
 torch>=2.0.1
 numpy>=1.24.3
 librosa>=0.10.1
@@ -7,7 +7,7 @@ accelerate>=0.23.0
 requests>=2.32.0
 pillow>=9.5.0
 huggingface_hub>=0.16.4
-spaces
+spaces>=0.19.1
 urllib3>=1.26.16
 soundfile>=0.12.1
 bitsandbytes>=0.42.0
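Pinning transformers>=4.35.0 in place of git+https://github.com/huggingface/transformers.git trades bleeding-edge code for reproducible builds. One thing worth verifying is that the floor is high enough: Qwen2AudioForConditionalGeneration only ships in fairly recent transformers releases, so 4.35.0 may be too low in practice. A quick sanity check of what is actually installed, sketched with the standard-library importlib.metadata:

from importlib.metadata import version, PackageNotFoundError

for pkg in ("gradio", "transformers", "spaces", "bitsandbytes"):
    try:
        print(f"{pkg}=={version(pkg)}")  # e.g. gradio==4.44.1
    except PackageNotFoundError:
        print(f"{pkg} is not installed")

# Importing the model class doubles as a feature check:
try:
    from transformers import Qwen2AudioForConditionalGeneration  # noqa: F401
    print("Qwen2-Audio support available")
except ImportError:
    print("transformers is too old for Qwen2-Audio; raise the version floor")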
setup_examples.py DELETED
@@ -1,52 +0,0 @@
-import os
-import requests
-import logging
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def setup_examples():
-    """Download example audio files for the app"""
-    # Create examples directory if it doesn't exist
-    examples_dir = "examples"
-    os.makedirs(examples_dir, exist_ok=True)
-
-    # Example files to download - you can replace these with your own examples
-    examples = [
-        {
-            "name": "guitar_mix_example.mp3",
-            "url": "https://freesound.org/data/previews/612/612850_5674468-lq.mp3"  # Guitar example from freesound
-        },
-        {
-            "name": "vocals_example.mp3",
-            "url": "https://freesound.org/data/previews/336/336590_5674468-lq.mp3"  # Vocal example from freesound
-        }
-    ]
-
-    # Download each example
-    for example in examples:
-        file_path = os.path.join(examples_dir, example["name"])
-
-        # Skip if file already exists
-        if os.path.exists(file_path):
-            logger.info(f"File {example['name']} already exists, skipping download")
-            continue
-
-        try:
-            logger.info(f"Downloading {example['name']} from {example['url']}")
-            response = requests.get(example["url"])
-            response.raise_for_status()
-
-            with open(file_path, "wb") as f:
-                f.write(response.content)
-
-            logger.info(f"Successfully downloaded {example['name']}")
-        except Exception as e:
-            logger.error(f"Error downloading {example['name']}: {e}")
-
-if __name__ == "__main__":
-    setup_examples()