mclemcrew committed on
Commit
8d531a8
·
1 Parent(s): e59de55

updates to gradio app

Files changed (1)
  1. app.py +603 -49
app.py CHANGED
@@ -1,64 +1,618 @@
 import gradio as gr
-from huggingface_hub import InferenceClient
-
-"""
-For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
-"""
-client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
-def respond(
-    message,
-    history: list[tuple[str, str]],
-    system_message,
-    max_tokens,
-    temperature,
-    top_p,
-):
-    messages = [{"role": "system", "content": system_message}]
-
-    for val in history:
-        if val[0]:
-            messages.append({"role": "user", "content": val[0]})
-        if val[1]:
-            messages.append({"role": "assistant", "content": val[1]})
-
-    messages.append({"role": "user", "content": message})
-
-    response = ""
-
-    for message in client.chat_completion(
-        messages,
-        max_tokens=max_tokens,
-        stream=True,
-        temperature=temperature,
-        top_p=top_p,
-    ):
-        token = message.choices[0].delta.content
-
-        response += token
-        yield response
-
-
 """
-For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
-"""
-demo = gr.ChatInterface(
-    respond,
-    additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-        gr.Slider(
-            minimum=0.1,
-            maximum=1.0,
-            value=0.95,
-            step=0.05,
-            label="Top-p (nucleus sampling)",
-        ),
-    ],
-)
-
-
 if __name__ == "__main__":
-    demo.launch()
+import os
 import gradio as gr
+import spaces
+import torch
+from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, BitsAndBytesConfig
+import numpy as np
+import librosa
+from urllib.request import urlopen
+from io import BytesIO
+import logging
+import sys
+import gc
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO,
+    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+    handlers=[logging.StreamHandler(sys.stdout)]
+)
+logger = logging.getLogger(__name__)
+
+# The merged fine-tuned model
+MODEL_ID = "mclemcrew/Qwen-Audio-Mix-Instruct"
+
+# Cache for model and processor
+model_cache = None
+processor_cache = None
+
+# Memory tracking
+def log_gpu_memory(message=""):
+    if torch.cuda.is_available():
+        allocated = torch.cuda.memory_allocated() / 1024**3
+        reserved = torch.cuda.memory_reserved() / 1024**3
+        logger.info(f"{message} - GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
+
+def load_model():
+    """Load the fine-tuned model with optimized memory usage"""
+    global model_cache, processor_cache
+
+    # Return the cached model if available
+    if model_cache is not None and processor_cache is not None:
+        logger.info("Using cached model and processor")
+        return model_cache, processor_cache
+
+    # Log initial GPU state
+    log_gpu_memory("Before model loading")
+
+    # Load the processor
+    logger.info(f"Loading processor from {MODEL_ID}")
+    processor = AutoProcessor.from_pretrained(MODEL_ID)
+
+    # Clean up memory
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+    # Quantization config: 4-bit NF4 with double quantization
+    quant_config = BitsAndBytesConfig(
+        load_in_4bit=True,
+        bnb_4bit_compute_dtype=torch.float16,  # Match training dtype
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_quant_type="nf4"
+    )
+
+    try:
+        logger.info("Loading model with optimized 4-bit quantization")
+
+        # Load with quantization and offloading for memory efficiency
+        model = Qwen2AudioForConditionalGeneration.from_pretrained(
+            MODEL_ID,
+            quantization_config=quant_config,
+            device_map="auto",
+            torch_dtype=torch.float16,
+            offload_folder="offload",
+            offload_state_dict=True,
+            low_cpu_mem_usage=True
+        )
+
+        log_gpu_memory("After optimized model loading")
+        logger.info("Model loaded successfully with optimized approach")
+    except Exception as e:
+        logger.error(f"Error loading model with optimized approach: {e}")
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+        try:
+            # Fall back to 8-bit quantization (more stable but less compression)
+            logger.info("Attempting 8-bit quantization fallback")
+
+            quant_config_8bit = BitsAndBytesConfig(
+                load_in_8bit=True,
+                llm_int8_threshold=6.0
+            )
+
+            model = Qwen2AudioForConditionalGeneration.from_pretrained(
+                MODEL_ID,
+                quantization_config=quant_config_8bit,
+                device_map="auto",
+                torch_dtype=torch.float16
+            )
+
+            log_gpu_memory("After 8-bit fallback loading")
+            logger.info("Model loaded successfully with 8-bit quantization")
+        except Exception as e2:
+            logger.error(f"Error loading with 8-bit quantization: {e2}")
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            try:
+                # Fall back to a dummy model as a last resort
+                logger.warning("Loading dummy placeholder model")
+                class DummyAudioModel:
+                    def __init__(self):
+                        self.device = torch.device("cpu")
+                        self.dummy_parameters = [torch.tensor([0.0])]
+
+                    def generate(self, **kwargs):
+                        input_ids = kwargs.get("input_ids", None)
+                        if input_ids is not None:
+                            batch_size, seq_len = input_ids.shape
+                            dummy_output = torch.ones((batch_size, seq_len + 20), dtype=torch.long)
+                            dummy_output[:, :seq_len] = input_ids
+                            dummy_output[:, seq_len:] = 100
+                            return dummy_output
+                        else:
+                            return torch.ones((1, 30), dtype=torch.long)
+
+                    def parameters(self):
+                        return iter(self.dummy_parameters)
+
+                    def to(self, device):
+                        self.device = device
+                        return self
+
+                model = DummyAudioModel()
+                logger.warning("Created dummy model placeholder - no real functionality available")
+            except Exception as e3:
+                logger.error(f"Failed to create dummy model: {e3}")
+                raise RuntimeError("Could not load any model version after multiple attempts")
+
+    # Cache the model and processor
+    model_cache = model
+    processor_cache = processor
+
+    return model, processor
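For context on why the 4-bit NF4 path is tried first, a rough estimate of the weight footprint at each precision is useful. The parameter count below is an assumption for a Qwen2-Audio-7B-class model, not something this commit pins down:

```python
# Back-of-the-envelope weight footprint per precision (weights only;
# activations, the KV cache, and CUDA overhead come on top).
# The ~8.4B parameter count is an assumption for a Qwen2-Audio-7B-class model.
params = 8.4e9
for name, bytes_per_param in [("fp16", 2.0), ("int8", 1.0), ("nf4", 0.5)]:
    print(f"{name}: {params * bytes_per_param / 1024**3:.1f} GiB")
# fp16: 15.6 GiB, int8: 7.8 GiB, nf4: 3.9 GiB
```

Double quantization shaves a little more off by also quantizing the quantization constants themselves, which is why the 4-bit config enables it.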
 
+def process_audio_from_url(audio_url, processor):
+    """Process audio file from URL for model input with optimized memory usage"""
+    try:
+        logger.info(f"Processing audio from URL: {audio_url}")
+        # Get the processor's expected sampling rate
+        target_sr = int(processor.feature_extractor.sampling_rate)
+        logger.info(f"Target sampling rate: {target_sr}")
+
+        # Audio bytes container
+        audio_bytes = None
+
+        # Handle various URL formats
+        if audio_url.startswith(('http://', 'https://')):
+            # For web URLs
+            try:
+                import requests
+                response = requests.get(audio_url)
+                response.raise_for_status()
+                audio_bytes = BytesIO(response.content)
+                # Free memory
+                del response
+            except Exception as req_error:
+                logger.info(f"Requests failed, falling back to urlopen: {req_error}")
+                audio_bytes = BytesIO(urlopen(audio_url).read())
+        elif audio_url.startswith('file://'):
+            # For local file URLs
+            file_path = audio_url[7:]  # Remove 'file://' prefix
+            with open(file_path, 'rb') as f:
+                audio_bytes = BytesIO(f.read())
+        else:
+            # Try as a local file path
+            with open(audio_url, 'rb') as f:
+                audio_bytes = BytesIO(f.read())
+
+        # Load the audio at its native sampling rate
+        audio_data, sr_loaded = librosa.load(audio_bytes, sr=None)
+        logger.info(f"Audio loaded with shape: {audio_data.shape}, original SR: {sr_loaded}")
+
+        # Free memory
+        del audio_bytes
+        gc.collect()
+
+        # Resample if needed
+        if sr_loaded != target_sr:
+            logger.info(f"Resampling from {sr_loaded} Hz to {target_sr} Hz")
+            audio_data = librosa.resample(audio_data, orig_sr=sr_loaded, target_sr=target_sr)
+
+        # Limit to 15 seconds maximum (was 30 seconds before)
+        max_seconds = 15
+        max_samples = max_seconds * target_sr
+        if len(audio_data) > max_samples:
+            logger.info(f"Limiting audio to {max_seconds} seconds for memory efficiency")
+            audio_data = audio_data[:max_samples]
+
+        # Ensure the audio is float32
+        audio_data = audio_data.astype(np.float32)
+
+        return audio_data
+    except Exception as e:
+        logger.error(f"Error processing audio from URL {audio_url}: {e}", exc_info=True)
+        return None
+    finally:
+        # Clean up any lingering memory
+        if 'audio_bytes' in locals() and audio_bytes is not None:
+            del audio_bytes
+            gc.collect()
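As a quick sanity check, the audio helper can be exercised on its own. The URL below is a placeholder, and a 30-second input should come back truncated to 15 seconds at the processor's sampling rate:

```python
# Hypothetical smoke test for process_audio_from_url (placeholder URL).
processor = AutoProcessor.from_pretrained(MODEL_ID)
audio = process_audio_from_url("https://example.com/mix.wav", processor)
if audio is not None:
    sr = processor.feature_extractor.sampling_rate
    print(f"{len(audio) / sr:.1f}s of mono audio at {sr} Hz, dtype={audio.dtype}")
```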
 
+@spaces.GPU(duration=120)
+def chat_with_model(audio_url, message, chat_history):
+    """Generate a response from the model using an audio URL"""
+    logger.info(f"Starting chat_with_model with audio_url: {audio_url}, message: {message}")
+
+    # Log initial memory state
+    log_gpu_memory("At start of chat_with_model")
+
+    # Validate that an audio URL is provided
+    if not audio_url or not audio_url.strip():
+        return "⚠️ Please set an audio track URL first before chatting."
+
+    try:
+        # Load the model and processor on demand
+        model, processor = load_model()
+
+        # Log memory after model load
+        log_gpu_memory("After model load")
+
+        # Check if we're using the dummy fallback model
+        is_dummy = model.__class__.__name__ == 'DummyAudioModel'
+
+        if is_dummy:
+            logger.warning("Using dummy model - providing generic response")
+            return (
+                "⚠️ I'm currently having trouble analyzing your audio due to technical limitations "
+                "in this environment. The model requires more GPU memory than is available. "
+                "Please try a different audio file or contact the developer for assistance."
+            )
+
+        # Process the audio
+        audios = []
+        audio_data = process_audio_from_url(audio_url, processor)
+        if audio_data is not None:
+            audios.append(audio_data)
+        else:
+            return "⚠️ Failed to process audio from the provided URL. Please check that the URL is valid and accessible."
+
+        # Log memory after audio processing
+        log_gpu_memory("After audio processing")
+
+        # System prompt for the model
+        SYSTEM_PROMPT = "You are an expert audio engineer assisting with music production and mixing. Provide clear, specific advice on audio engineering techniques, mixing adjustments, and production decisions based on the audio samples and the user's questions. Focus on practical, actionable guidance. Be as specific as possible when answering the user's questions about the mix."
+
+        # Start with the system prompt
+        conversation = [
+            {"role": "system", "content": SYSTEM_PROMPT}
+        ]
+
+        # Add chat history, limited to the last 5 exchanges to save memory
+        history_limit = min(len(chat_history), 5)
+        for user_msg, bot_msg in chat_history[-history_limit:]:
+            if user_msg:
+                conversation.append({"role": "user", "content": user_msg})
+            if bot_msg:
+                conversation.append({"role": "assistant", "content": bot_msg})
+
+        # Determine if this is the first message with a new audio
+        is_first_message_with_audio = len(chat_history) == 0
+
+        # Format the user message based on whether it's the first message with audio
+        if is_first_message_with_audio:
+            # First message includes the audio
+            logger.info("First message with audio, including audio in content")
+            conversation.append({
+                "role": "user",
+                "content": [
+                    {"type": "audio", "audio_url": audio_url},
+                    {"type": "text", "text": message}
+                ]
+            })
+        else:
+            # Follow-up message about the same audio
+            logger.info("Follow-up message, including only text in content")
+            conversation.append({
+                "role": "user",
+                "content": message
+            })
+
+        # Apply the chat template with error handling
+        try:
+            text = processor.apply_chat_template(
+                conversation,
+                add_generation_prompt=True,
+                tokenize=False
+            )
+            logger.info("Chat template applied successfully")
+        except Exception as e:
+            logger.error(f"Error applying chat template: {e}")
+            # Use a simplified prompt if the template fails
+            text = f"{SYSTEM_PROMPT}\n\nUser: {message}\n\nAssistant:"
+
+        # Generate model inputs
+        try:
+            inputs = processor(
+                text=text,
+                audios=audios,
+                return_tensors="pt",
+                padding=True,
+                truncation=True
+            )
+
+            # Move inputs to the appropriate device
+            if hasattr(model, 'device'):
+                device = model.device
+            else:
+                device = next(model.parameters()).device
+            logger.info(f"Using device: {device}")
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+
+            logger.info("Model inputs generated")
+            log_gpu_memory("After input preparation")
+        except Exception as e:
+            logger.error(f"Error generating model inputs: {e}")
+            return f"⚠️ Error generating model inputs: {str(e)}"
+
+        # Generate a response from the model
+        with torch.no_grad():
+            try:
+                generate_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=128,  # Reduced from 256
+                    temperature=0.7,
+                    do_sample=True,
+                    top_p=0.9,
+                    use_cache=True  # Ensure the KV cache is used
+                )
+                logger.info("Response generated successfully")
+                log_gpu_memory("After generation")
+            except Exception as e:
+                logger.error(f"Error during model.generate: {e}")
+                return f"⚠️ Model generation error: {str(e)}"
+
+        # Decode the response
+        try:
+            generate_ids = generate_ids[:, inputs["input_ids"].size(1):]
+            response = processor.batch_decode(
+                generate_ids,
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=False
+            )[0]
+            logger.info(f"Response decoded successfully, length: {len(response)}")
+
+            # Quick validation of the response
+            if not response or response.isspace():
+                logger.error("Empty response received from model")
+                return "⚠️ Model returned an empty response. Please try again."
+
+            # Clean up memory
+            del inputs, generate_ids
+            gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            return response
+        except Exception as e:
+            logger.error(f"Error decoding response: {e}")
+            return f"⚠️ Error decoding response: {str(e)}"
+
+    except Exception as e:
+        logger.error(f"Unexpected error in chat_with_model: {e}", exc_info=True)
+        return f"⚠️ An error occurred: {str(e)}"
+    finally:
+        # Final memory cleanup
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        log_gpu_memory("End of chat_with_model")
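For reference, a minimal sketch of the conversation structure Qwen2-Audio's chat template expects on the first, audio-bearing turn; the URL and question here are illustrative:

```python
# Illustrative first-turn conversation for Qwen2-Audio's chat template.
processor = AutoProcessor.from_pretrained(MODEL_ID)
conversation = [
    {"role": "system", "content": "You are an expert audio engineer."},
    {"role": "user", "content": [
        {"type": "audio", "audio_url": "https://example.com/mix.wav"},
        {"type": "text", "text": "Is the vocal sitting well in this mix?"},
    ]},
]
text = processor.apply_chat_template(conversation, add_generation_prompt=True, tokenize=False)
# The rendered prompt contains an audio placeholder token, which the
# processor later pairs with the waveform passed separately via `audios=[...]`.
```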
 
+# Function to check if a URL points to a supported audio file
+def is_valid_audio_url(url):
+    if not url or not url.strip():
+        return False
+
+    url = url.strip().lower()
+    return url.endswith(('.wav', '.mp3', '.ogg', '.flac', '.m4a'))
+
+# Custom theme with orange primary color and dark background
+orange_black_theme = gr.themes.Base(
+    primary_hue="orange",
+    secondary_hue="gray",
+    neutral_hue="gray",
+    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
+)
+
+# Custom CSS for a darker theme with orange accents
+custom_css = """
+:root {
+    --orange-primary: #ff7700;
+    --dark-bg: #1a1a1a;
+    --darker-bg: #121212;
+    --lightest-gray: #e0e0e0;
+}
+
+body {
+    background-color: var(--darker-bg) !important;
+    color: var(--lightest-gray) !important;
+    font-family: 'Poppins', sans-serif !important;
+}
+
+.gradio-container {
+    background-color: var(--darker-bg) !important;
+}
+
+button.primary {
+    background-color: var(--orange-primary) !important;
+}
+
+.message.bot {
+    background-color: var(--dark-bg) !important;
+}
 """
+
+# Gradio interface
+with gr.Blocks(theme=orange_black_theme, css=custom_css) as demo:
+    gr.Markdown(
+        """
+        # 🎧 Music Mixing Assistant
+        Enter an audio URL (.wav format recommended) and chat with your co-creative mixing agent!
+
+        Set your audio track once, then have an extended conversation about mixing and improving that specific track.
+        *(Note: Audio samples are limited to 15 seconds for optimal performance)*
+        """
+    )
+
+    # State for the current audio URL
+    audio_url_state = gr.State("")
+
+    with gr.Row():
+        with gr.Column(scale=3):
+            # Chat interface with customized settings
+            chatbot = gr.Chatbot(
+                height=500,
+                avatar_images=(None, "🎧"),  # No user icon
+                show_label=False,
+                container=True,
+                bubble_full_width=False,
+                show_copy_button=False,  # No copy button
+                show_share_button=False,  # No share button
+                render_markdown=True
+            )
+
+            # Input area
+            with gr.Row():
+                message = gr.Textbox(
+                    placeholder="Ask about your mix...",
+                    show_label=False,
+                    container=False,
+                    scale=10
+                )
+                submit_btn = gr.Button("Send", variant="primary", scale=1)
+
+            # Control buttons
+            with gr.Row():
+                clear_btn = gr.Button("Clear Chat", variant="secondary")
+
+        with gr.Column(scale=1):
+            # Audio URL input
+            audio_input = gr.Textbox(
+                label="Audio URL (.wav format)",
+                placeholder="https://example.com/your-audio-file.wav",
+                info="Enter the URL to a WAV audio file - the first 15 seconds will be analyzed"
+            )
+
+            # Button to set the URL
+            set_url_btn = gr.Button("Set Audio Track", variant="primary")
+
+            # Preview player (optional)
+            audio_preview = gr.Audio(
+                label="Audio Preview (if available)",
+                interactive=False,
+                visible=True
+            )
+
+            # Memory usage indicator
+            if torch.cuda.is_available():
+                memory_status = gr.Markdown("*GPU Memory: Initializing...*")
+                def update_memory_status():
+                    if torch.cuda.is_available():
+                        allocated = torch.cuda.memory_allocated() / 1024**3
+                        reserved = torch.cuda.memory_reserved() / 1024**3
+                        return f"*GPU Memory: {allocated:.2f}GB allocated / {reserved:.2f}GB reserved*"
+                    return "*GPU Memory: Not available*"
+            else:
+                memory_status = gr.Markdown("*GPU Memory: Not available*")
+                def update_memory_status():
+                    return "*GPU Memory: Not available*"
+
+            # Display status
+            status = gr.Markdown("*Status: Ready to assist with your mix!*")
+
+    # Update the audio URL state and preview
+    def update_audio_url(url):
+        # Basic validation
+        if not is_valid_audio_url(url):
+            return "", gr.update(value=None), "*Status: Invalid audio URL. Please use .wav, .mp3, .ogg, .flac, or .m4a format*", update_memory_status()
+
+        # Try to provide a preview if possible
+        try:
+            return url, gr.update(value=url), "*Status: Audio track set! First 15 seconds will be analyzed.*", update_memory_status()
+        except Exception as e:
+            # If the preview fails, still set the URL but show a warning
+            return url, gr.update(value=None), f"*Status: Audio track set, but preview failed: {str(e)}*", update_memory_status()
+
+    # Clear the chat history
+    def clear_chat():
+        return []
+
+    # Set URL button logic - combined update and clear in one function
+    def update_and_clear_chat(url):
+        # First update the URL
+        result = update_audio_url(url)
+        # Then return the values, including an empty chat
+        return result[0], result[1], [], result[2], result[3]
+
+    # Set URL button
+    set_url_btn.click(
+        update_and_clear_chat,
+        inputs=[audio_input],
+        outputs=[audio_url_state, audio_preview, chatbot, status, memory_status]
+    )
+
+    # Handle the submit button; note this is a generator, so every exit
+    # path must yield (a bare `return value` would never reach the UI)
+    def respond(audio_url, message, chat_history):
+        if not message.strip():
+            yield chat_history, "*Status: Please enter a message*", update_memory_status()
+            return
+
+        # Check if an audio URL is set
+        if not audio_url or not audio_url.strip():
+            error_msg = "No audio track set. Please set an audio URL first."
+            chat_history.append((message, f"⚠️ {error_msg}"))
+            yield chat_history, f"*Status: {error_msg}*", update_memory_status()
+            return
+
+        # Show the user message immediately
+        chat_history.append((message, None))
+        yield chat_history, "🎵 *Analyzing your mix...*", update_memory_status()
+
+        try:
+            # Process and get the response
+            bot_message = chat_with_model(audio_url, message, chat_history[:-1])
+
+            # Update the last message with the bot's response
+            chat_history[-1] = (message, bot_message)
+
+            # Return the updated chat history
+            yield chat_history, "*Status: Ready to assist with your mix!*", update_memory_status()
+        except Exception as e:
+            error_msg = f"Error generating response: {str(e)}"
+            chat_history[-1] = (message, f"⚠️ {error_msg}")
+            yield chat_history, f"*Status: {error_msg}*", update_memory_status()
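The streaming behavior above relies on Gradio treating generator callbacks as progressive updates, which is why every exit path in `respond` yields instead of returning a value. Stripped to its essentials, the pattern looks like this:

```python
# Minimal sketch of Gradio's generator-callback pattern: each yield
# pushes an intermediate state of the outputs to the UI.
def stream_reply(msg, history):
    history.append((msg, None))
    yield history                 # show the user's message right away
    history[-1] = (msg, "Done!")
    yield history                 # swap the placeholder for the final reply
```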
+
+    # Handle submit and clear the input box
+    def respond_and_clear_input(audio_url, message, chat_history):
+        # Relay each update from respond
+        for result in respond(audio_url, message, chat_history):
+            # Yield each result with an emptied message input
+            yield result[0], result[1], result[2], ""
+
+    # Connect UI components
+    submit_btn.click(
+        respond_and_clear_input,
+        inputs=[audio_url_state, message, chatbot],
+        outputs=[chatbot, status, memory_status, message],
+        queue=True
+    )
+
+    message.submit(
+        respond_and_clear_input,
+        inputs=[audio_url_state, message, chatbot],
+        outputs=[chatbot, status, memory_status, message],
+        queue=True
+    )
+
+    # Clear button functionality to reset everything
+    def clear_all():
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        return [], "", None, "*Status: Chat cleared!*", update_memory_status(), ""
+
+    clear_btn.click(
+        clear_all,
+        None,
+        [chatbot, audio_input, audio_preview, status, memory_status, audio_url_state],
+        queue=False
+    )
+
+# Launch the interface
 if __name__ == "__main__":
+    # Display a version warning at startup
+    try:
+        import pkg_resources
+        gradio_version = pkg_resources.get_distribution("gradio").version
+        recommended_version = "4.44.1"  # Update this as needed
+        if gradio_version != recommended_version:
+            print(f"⚠️ WARNING: You are using gradio version {gradio_version}, however version {recommended_version} is available.")
+            print(f"⚠️ Please upgrade: pip install gradio=={recommended_version}")
+    except Exception:
+        pass
+
+    # Launch with optimized settings
+    demo.launch(share=False, debug=False)
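One portability note: `spaces.GPU` only exists inside Hugging Face Spaces. A no-op shim like the following sketch, which is not part of this commit, would let the same file import cleanly on a local machine:

```python
# Hypothetical shim (not part of this commit) so the module also imports
# outside Hugging Face Spaces, where the `spaces` package is unavailable.
try:
    import spaces
except ImportError:
    class _NoOpSpaces:
        @staticmethod
        def GPU(**kwargs):
            def wrap(fn):
                return fn           # the decorator becomes a pass-through
            return wrap
    spaces = _NoOpSpaces()
```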