Spaces:
Sleeping
Sleeping
updates
Browse files- .hf-space +12 -0
- app.py +214 -132
- requirements.txt +17 -1
- setup_examples.py +52 -0
.hf-space
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"title": "Music Mixing Assistant",
|
| 3 |
+
"emoji": "🎧",
|
| 4 |
+
"colorFrom": "orange",
|
| 5 |
+
"colorTo": "blue",
|
| 6 |
+
"sdk": "gradio",
|
| 7 |
+
"sdk_version": "4.44.1",
|
| 8 |
+
"python_version": "3.10",
|
| 9 |
+
"app_file": "app.py",
|
| 10 |
+
"pinned": false,
|
| 11 |
+
"license": "mit"
|
| 12 |
+
}
|
app.py
CHANGED
|
@@ -10,6 +10,8 @@ from io import BytesIO
|
|
| 10 |
import logging
|
| 11 |
import sys
|
| 12 |
import gc
|
|
|
|
|
|
|
| 13 |
|
| 14 |
# Configure logging
|
| 15 |
logging.basicConfig(
|
|
@@ -57,7 +59,7 @@ def load_model():
|
|
| 57 |
# Define proper quantization config - using 4-bit quantization
|
| 58 |
quant_config = BitsAndBytesConfig(
|
| 59 |
load_in_4bit=True,
|
| 60 |
-
bnb_4bit_compute_dtype=torch.float16,
|
| 61 |
bnb_4bit_use_double_quant=True,
|
| 62 |
bnb_4bit_quant_type="nf4"
|
| 63 |
)
|
|
@@ -88,7 +90,6 @@ def load_model():
|
|
| 88 |
# Fallback to 8-bit quantization (more stable but less compression)
|
| 89 |
logger.info("Attempting 8-bit quantization fallback")
|
| 90 |
|
| 91 |
-
from transformers import BitsAndBytesConfig
|
| 92 |
quant_config_8bit = BitsAndBytesConfig(
|
| 93 |
load_in_8bit=True,
|
| 94 |
llm_int8_threshold=6.0
|
|
@@ -105,41 +106,13 @@ def load_model():
|
|
| 105 |
logger.info("Model loaded successfully with 8-bit quantization")
|
| 106 |
except Exception as e2:
|
| 107 |
logger.error(f"Error loading with 8-bit quantization: {e2}")
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
try
|
| 113 |
-
# Fallback to dummy model as last resort
|
| 114 |
-
logger.warning("Loading dummy placeholder model")
|
| 115 |
-
class DummyAudioModel:
|
| 116 |
-
def __init__(self):
|
| 117 |
-
self.device = torch.device("cpu")
|
| 118 |
-
self.dummy_parameters = [torch.tensor([0.0])]
|
| 119 |
-
|
| 120 |
-
def generate(self, **kwargs):
|
| 121 |
-
input_ids = kwargs.get("input_ids", None)
|
| 122 |
-
if input_ids is not None:
|
| 123 |
-
batch_size, seq_len = input_ids.shape
|
| 124 |
-
dummy_output = torch.ones((batch_size, seq_len + 20), dtype=torch.long)
|
| 125 |
-
dummy_output[:, :seq_len] = input_ids
|
| 126 |
-
dummy_output[:, seq_len:] = 100
|
| 127 |
-
return dummy_output
|
| 128 |
-
else:
|
| 129 |
-
return torch.ones((1, 30), dtype=torch.long)
|
| 130 |
-
|
| 131 |
-
def parameters(self):
|
| 132 |
-
return iter(self.dummy_parameters)
|
| 133 |
-
|
| 134 |
-
def to(self, device):
|
| 135 |
-
self.device = device
|
| 136 |
-
return self
|
| 137 |
-
|
| 138 |
-
model = DummyAudioModel()
|
| 139 |
-
logger.warning("Created dummy model placeholder - no real functionality available")
|
| 140 |
-
except Exception as e3:
|
| 141 |
-
logger.error(f"Failed to create dummy model: {e3}")
|
| 142 |
-
raise RuntimeError(f"Could not load any model version after multiple attempts")
|
| 143 |
|
| 144 |
# Cache the model and processor
|
| 145 |
model_cache = model
|
|
@@ -162,7 +135,6 @@ def process_audio_from_url(audio_url, processor):
|
|
| 162 |
if audio_url.startswith(('http://', 'https://')):
|
| 163 |
# For web URLs
|
| 164 |
try:
|
| 165 |
-
import requests
|
| 166 |
response = requests.get(audio_url)
|
| 167 |
response.raise_for_status()
|
| 168 |
audio_bytes = BytesIO(response.content)
|
|
@@ -214,17 +186,58 @@ def process_audio_from_url(audio_url, processor):
|
|
| 214 |
del audio_bytes
|
| 215 |
gc.collect()
|
| 216 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 217 |
@spaces.GPU(duration=120)
|
| 218 |
-
def chat_with_model(
|
| 219 |
-
"""Generate response from the model using
|
| 220 |
-
logger.info(f"Starting chat_with_model with
|
| 221 |
|
| 222 |
# Log initial memory state
|
| 223 |
log_gpu_memory("At start of chat_with_model")
|
| 224 |
|
| 225 |
-
# Validate that audio
|
| 226 |
-
if not
|
| 227 |
-
|
|
|
|
| 228 |
|
| 229 |
try:
|
| 230 |
# Load model and processor on demand
|
|
@@ -233,24 +246,18 @@ def chat_with_model(audio_url, message, chat_history):
|
|
| 233 |
# Log memory after model load
|
| 234 |
log_gpu_memory("After model load")
|
| 235 |
|
| 236 |
-
# Check if we're using a dummy model
|
| 237 |
-
is_dummy = hasattr(model, '__class__') and model.__class__.__name__ == 'DummyAudioModel'
|
| 238 |
-
|
| 239 |
-
if is_dummy:
|
| 240 |
-
logger.warning("Using dummy model - providing generic response")
|
| 241 |
-
return (
|
| 242 |
-
"⚠️ I'm currently having trouble analyzing your audio due to technical limitations "
|
| 243 |
-
"in this environment. The model requires more GPU memory than is available. "
|
| 244 |
-
"Please try a different audio file or contact the developer for assistance."
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
# Process audio
|
| 248 |
audios = []
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
if audio_data is not None:
|
| 251 |
audios.append(audio_data)
|
| 252 |
else:
|
| 253 |
-
return "⚠️ Failed to process audio from the provided
|
| 254 |
|
| 255 |
# Log memory after audio processing
|
| 256 |
log_gpu_memory("After audio processing")
|
|
@@ -281,7 +288,7 @@ def chat_with_model(audio_url, message, chat_history):
|
|
| 281 |
conversation.append({
|
| 282 |
"role": "user",
|
| 283 |
"content": [
|
| 284 |
-
{"type": "audio", "audio_url":
|
| 285 |
{"type": "text", "text": message}
|
| 286 |
]
|
| 287 |
})
|
|
@@ -328,21 +335,28 @@ def chat_with_model(audio_url, message, chat_history):
|
|
| 328 |
log_gpu_memory("After input preparation")
|
| 329 |
except Exception as e:
|
| 330 |
logger.error(f"Error generating model inputs: {e}")
|
| 331 |
-
return f"⚠️ Error
|
| 332 |
|
| 333 |
# Generate response from model
|
| 334 |
with torch.no_grad():
|
| 335 |
try:
|
| 336 |
generate_ids = model.generate(
|
| 337 |
**inputs,
|
| 338 |
-
max_new_tokens=128,
|
| 339 |
temperature=0.7,
|
| 340 |
do_sample=True,
|
| 341 |
top_p=0.9,
|
| 342 |
-
use_cache=True
|
| 343 |
)
|
| 344 |
logger.info(f"Response generated successfully")
|
| 345 |
log_gpu_memory("After generation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
except Exception as e:
|
| 347 |
logger.error(f"Error during model.generate: {e}")
|
| 348 |
return f"⚠️ Model generation error: {str(e)}"
|
|
@@ -360,7 +374,7 @@ def chat_with_model(audio_url, message, chat_history):
|
|
| 360 |
# Quick validation of response
|
| 361 |
if not response or response.isspace():
|
| 362 |
logger.error("Empty response received from model")
|
| 363 |
-
return "⚠️ Model returned an empty response. Please try again."
|
| 364 |
|
| 365 |
# Clean up memory
|
| 366 |
del inputs, generate_ids
|
|
@@ -373,6 +387,9 @@ def chat_with_model(audio_url, message, chat_history):
|
|
| 373 |
logger.error(f"Error decoding response: {e}")
|
| 374 |
return f"⚠️ Error decoding response: {str(e)}"
|
| 375 |
|
|
|
|
|
|
|
|
|
|
| 376 |
except Exception as e:
|
| 377 |
logger.error(f"Unexpected error in chat_with_model: {e}", exc_info=True)
|
| 378 |
return f"⚠️ An error occurred: {str(e)}"
|
|
@@ -383,14 +400,6 @@ def chat_with_model(audio_url, message, chat_history):
|
|
| 383 |
torch.cuda.empty_cache()
|
| 384 |
log_gpu_memory("End of chat_with_model")
|
| 385 |
|
| 386 |
-
# Function to check if URL is a valid audio file
|
| 387 |
-
def is_valid_audio_url(url):
|
| 388 |
-
if not url or not url.strip():
|
| 389 |
-
return False
|
| 390 |
-
|
| 391 |
-
url = url.strip().lower()
|
| 392 |
-
return url.endswith(('.wav', '.mp3', '.ogg', '.flac', '.m4a'))
|
| 393 |
-
|
| 394 |
# Custom theme with orange primary color and dark background
|
| 395 |
orange_black_theme = gr.themes.Base(
|
| 396 |
primary_hue="orange",
|
|
@@ -425,6 +434,20 @@ button.primary {
|
|
| 425 |
.message.bot {
|
| 426 |
background-color: var(--dark-bg) !important;
|
| 427 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
"""
|
| 429 |
|
| 430 |
# Gradio interface
|
|
@@ -432,27 +455,28 @@ with gr.Blocks(theme=orange_black_theme, css=custom_css) as demo:
|
|
| 432 |
gr.Markdown(
|
| 433 |
"""
|
| 434 |
# 🎧 Music Mixing Assistant
|
| 435 |
-
|
| 436 |
|
| 437 |
-
|
| 438 |
*(Note: Audio samples are limited to 15 seconds for optimal performance)*
|
| 439 |
"""
|
| 440 |
)
|
| 441 |
|
| 442 |
-
#
|
| 443 |
-
|
|
|
|
|
|
|
| 444 |
|
| 445 |
with gr.Row():
|
| 446 |
with gr.Column(scale=3):
|
| 447 |
# Chat interface with customized settings
|
| 448 |
chatbot = gr.Chatbot(
|
| 449 |
height=500,
|
| 450 |
-
avatar_images=(None, "🎧"),
|
| 451 |
show_label=False,
|
| 452 |
container=True,
|
| 453 |
bubble_full_width=False,
|
| 454 |
-
show_copy_button=
|
| 455 |
-
show_share_button=False, # Removed share button
|
| 456 |
render_markdown=True
|
| 457 |
)
|
| 458 |
|
|
@@ -471,23 +495,43 @@ with gr.Blocks(theme=orange_black_theme, css=custom_css) as demo:
|
|
| 471 |
clear_btn = gr.Button("Clear Chat", variant="secondary")
|
| 472 |
|
| 473 |
with gr.Column(scale=1):
|
| 474 |
-
#
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
|
| 481 |
-
|
| 482 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
|
| 484 |
-
# Preview player
|
| 485 |
audio_preview = gr.Audio(
|
| 486 |
-
label="Audio Preview
|
| 487 |
interactive=False,
|
| 488 |
visible=True
|
| 489 |
)
|
| 490 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 491 |
# Memory usage indicator
|
| 492 |
if torch.cuda.is_available():
|
| 493 |
memory_status = gr.Markdown("*GPU Memory: Initializing...*")
|
|
@@ -505,45 +549,70 @@ with gr.Blocks(theme=orange_black_theme, css=custom_css) as demo:
|
|
| 505 |
# Display status
|
| 506 |
status = gr.Markdown("*Status: Ready to assist with your mix!*")
|
| 507 |
|
| 508 |
-
# Function to update
|
| 509 |
-
def
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
|
| 514 |
-
# Try to provide a preview
|
| 515 |
try:
|
| 516 |
-
return url, gr.update(value=url), "*Status: Audio track set! First 15 seconds will be analyzed.*", update_memory_status()
|
| 517 |
except Exception as e:
|
| 518 |
-
|
| 519 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 520 |
|
| 521 |
# Function to clear chat
|
| 522 |
def clear_chat():
|
| 523 |
return []
|
| 524 |
|
| 525 |
-
#
|
| 526 |
-
def
|
| 527 |
-
|
| 528 |
-
|
| 529 |
-
# Then return the values including an empty chat
|
| 530 |
-
return result[0], result[1], [], result[2], result[3]
|
| 531 |
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
|
| 537 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 538 |
|
| 539 |
# Handle submit button
|
| 540 |
-
def respond(
|
| 541 |
if not message.strip():
|
| 542 |
return chat_history, "*Status: Please enter a message*", update_memory_status()
|
| 543 |
|
| 544 |
-
#
|
| 545 |
-
if
|
| 546 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 547 |
chat_history.append((message, f"⚠️ {error_msg}"))
|
| 548 |
return chat_history, f"*Status: {error_msg}*", update_memory_status()
|
| 549 |
|
|
@@ -553,66 +622,79 @@ with gr.Blocks(theme=orange_black_theme, css=custom_css) as demo:
|
|
| 553 |
|
| 554 |
try:
|
| 555 |
# Process and get response
|
| 556 |
-
bot_message = chat_with_model(
|
| 557 |
|
| 558 |
# Update the last message with the bot's response
|
| 559 |
chat_history[-1] = (message, bot_message)
|
| 560 |
|
| 561 |
# Return updated chat history
|
| 562 |
-
yield chat_history, "*Status: Ready
|
| 563 |
except Exception as e:
|
| 564 |
error_msg = f"Error generating response: {str(e)}"
|
| 565 |
chat_history[-1] = (message, f"⚠️ {error_msg}")
|
| 566 |
yield chat_history, f"*Status: {error_msg}*", update_memory_status()
|
| 567 |
|
| 568 |
# Handle submit with clear input
|
| 569 |
-
def respond_and_clear_input(
|
| 570 |
# First get response updates
|
| 571 |
-
for result in respond(
|
| 572 |
# Yield each result with empty message input
|
| 573 |
yield result[0], result[1], result[2], ""
|
| 574 |
|
| 575 |
# Connect UI components
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 576 |
submit_btn.click(
|
| 577 |
respond_and_clear_input,
|
| 578 |
-
inputs=[
|
| 579 |
outputs=[chatbot, status, memory_status, message],
|
| 580 |
queue=True
|
| 581 |
)
|
| 582 |
|
| 583 |
message.submit(
|
| 584 |
respond_and_clear_input,
|
| 585 |
-
inputs=[
|
| 586 |
outputs=[chatbot, status, memory_status, message],
|
| 587 |
queue=True
|
| 588 |
)
|
| 589 |
|
| 590 |
-
# Clear button functionality
|
| 591 |
def clear_all():
|
| 592 |
gc.collect()
|
| 593 |
if torch.cuda.is_available():
|
| 594 |
torch.cuda.empty_cache()
|
| 595 |
-
return [], "", None, "*Status: Chat cleared!*", update_memory_status(), ""
|
| 596 |
|
| 597 |
clear_btn.click(
|
| 598 |
clear_all,
|
| 599 |
None,
|
| 600 |
-
[chatbot, audio_input, audio_preview, status, memory_status,
|
| 601 |
queue=False
|
| 602 |
)
|
| 603 |
|
| 604 |
# Launch the interface
|
| 605 |
if __name__ == "__main__":
|
| 606 |
-
# Display version warning at startup
|
| 607 |
-
try:
|
| 608 |
-
import pkg_resources
|
| 609 |
-
gradio_version = pkg_resources.get_distribution("gradio").version
|
| 610 |
-
recommended_version = "4.44.1" # Update this as needed
|
| 611 |
-
if gradio_version != recommended_version:
|
| 612 |
-
print(f"⚠️ WARNING: You are using gradio version {gradio_version}, however version {recommended_version} is available.")
|
| 613 |
-
print(f"⚠️ Please upgrade: pip install gradio=={recommended_version}")
|
| 614 |
-
except:
|
| 615 |
-
pass
|
| 616 |
-
|
| 617 |
# Launch with optimized settings
|
| 618 |
demo.launch(share=False, debug=False)
|
|
|
|
| 10 |
import logging
|
| 11 |
import sys
|
| 12 |
import gc
|
| 13 |
+
from tempfile import NamedTemporaryFile
|
| 14 |
+
import requests
|
| 15 |
|
| 16 |
# Configure logging
|
| 17 |
logging.basicConfig(
|
|
|
|
| 59 |
# Define proper quantization config - using 4-bit quantization
|
| 60 |
quant_config = BitsAndBytesConfig(
|
| 61 |
load_in_4bit=True,
|
| 62 |
+
bnb_4bit_compute_dtype=torch.float16,
|
| 63 |
bnb_4bit_use_double_quant=True,
|
| 64 |
bnb_4bit_quant_type="nf4"
|
| 65 |
)
|
|
|
|
| 90 |
# Fallback to 8-bit quantization (more stable but less compression)
|
| 91 |
logger.info("Attempting 8-bit quantization fallback")
|
| 92 |
|
|
|
|
| 93 |
quant_config_8bit = BitsAndBytesConfig(
|
| 94 |
load_in_8bit=True,
|
| 95 |
llm_int8_threshold=6.0
|
|
|
|
| 106 |
logger.info("Model loaded successfully with 8-bit quantization")
|
| 107 |
except Exception as e2:
|
| 108 |
logger.error(f"Error loading with 8-bit quantization: {e2}")
|
| 109 |
+
# Create a more user-friendly error message for the UI
|
| 110 |
+
class ModelLoadError(Exception):
|
| 111 |
+
def __init__(self, message="Failed to load model due to memory limitations"):
|
| 112 |
+
self.message = message
|
| 113 |
+
super().__init__(self.message)
|
| 114 |
|
| 115 |
+
raise ModelLoadError("Model could not be loaded due to memory constraints. Please try again later.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
|
| 117 |
# Cache the model and processor
|
| 118 |
model_cache = model
|
|
|
|
| 135 |
if audio_url.startswith(('http://', 'https://')):
|
| 136 |
# For web URLs
|
| 137 |
try:
|
|
|
|
| 138 |
response = requests.get(audio_url)
|
| 139 |
response.raise_for_status()
|
| 140 |
audio_bytes = BytesIO(response.content)
|
|
|
|
| 186 |
del audio_bytes
|
| 187 |
gc.collect()
|
| 188 |
|
| 189 |
+
def process_uploaded_audio(audio_file, processor):
    """Convert an uploaded audio file into a model-ready float32 numpy array.

    Args:
        audio_file: Value produced by the Gradio Audio component — either a
            2-tuple or a plain file path string (presumably ``type="filepath"``
            — TODO confirm against the component config).
        processor: Model processor; only ``feature_extractor.sampling_rate``
            is read, to pick the target sample rate.

    Returns:
        A 1-D ``np.float32`` array at the processor's sampling rate, trimmed
        to at most 15 seconds, or ``None`` if anything goes wrong.
    """
    try:
        logger.info("Processing uploaded audio file")
        # Sample rate the model's feature extractor expects
        target_sr = int(processor.feature_extractor.sampling_rate)

        # Handle different Gradio Audio component return types
        if isinstance(audio_file, tuple) and len(audio_file) == 2:
            # Gradio returns (path, sr) tuple
            temp_path, sr_loaded = audio_file
            logger.info(f"Uploaded audio path: {temp_path}, SR: {sr_loaded}")

            # librosa.load(sr=target_sr) already resamples the audio, so the
            # effective rate is target_sr from here on. (Previously sr_loaded
            # kept the tuple's rate, triggering a second, incorrect resample
            # below that pitch-shifted the audio.)
            audio_data, _ = librosa.load(temp_path, sr=target_sr)
            sr_loaded = target_sr
        else:
            # Handle other cases (could be a file path string); load at the
            # file's native rate and resample explicitly below if needed.
            logger.info(f"Uploaded audio file type: {type(audio_file)}")
            audio_data, sr_loaded = librosa.load(audio_file, sr=None)

        # Resample if needed (only reachable for the native-rate path above)
        if sr_loaded != target_sr:
            logger.info(f"Resampling from {sr_loaded} Hz to {target_sr} Hz")
            audio_data = librosa.resample(audio_data, orig_sr=sr_loaded, target_sr=target_sr)

        # Reduce to 15 seconds maximum to bound GPU memory use downstream
        max_seconds = 15
        max_samples = max_seconds * target_sr
        if len(audio_data) > max_samples:
            logger.info(f"Limiting audio to {max_seconds} seconds for memory efficiency")
            audio_data = audio_data[:max_samples]

        # Ensure audio is float32, the dtype the feature extractor expects
        audio_data = audio_data.astype(np.float32)

        return audio_data
    except Exception as e:
        # Best-effort: callers treat None as "processing failed"
        logger.error(f"Error processing uploaded audio: {e}", exc_info=True)
        return None
|
| 228 |
+
|
| 229 |
@spaces.GPU(duration=120)
|
| 230 |
+
def chat_with_model(audio_source, audio_data_type, message, chat_history):
|
| 231 |
+
"""Generate response from the model using audio"""
|
| 232 |
+
logger.info(f"Starting chat_with_model with audio_source: {audio_source}, type: {audio_data_type}, message: {message}")
|
| 233 |
|
| 234 |
# Log initial memory state
|
| 235 |
log_gpu_memory("At start of chat_with_model")
|
| 236 |
|
| 237 |
+
# Validate that audio source is provided
|
| 238 |
+
if (audio_data_type == "url" and (not audio_source or not audio_source.strip())) or \
|
| 239 |
+
(audio_data_type == "upload" and not audio_source):
|
| 240 |
+
return "⚠️ Please provide an audio source (URL or upload) before chatting."
|
| 241 |
|
| 242 |
try:
|
| 243 |
# Load model and processor on demand
|
|
|
|
| 246 |
# Log memory after model load
|
| 247 |
log_gpu_memory("After model load")
|
| 248 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
# Process audio
|
| 250 |
audios = []
|
| 251 |
+
|
| 252 |
+
if audio_data_type == "url":
|
| 253 |
+
audio_data = process_audio_from_url(audio_source, processor)
|
| 254 |
+
else: # audio_data_type == "upload"
|
| 255 |
+
audio_data = process_uploaded_audio(audio_source, processor)
|
| 256 |
+
|
| 257 |
if audio_data is not None:
|
| 258 |
audios.append(audio_data)
|
| 259 |
else:
|
| 260 |
+
return "⚠️ Failed to process audio from the provided source. Please check that the audio file is valid and accessible."
|
| 261 |
|
| 262 |
# Log memory after audio processing
|
| 263 |
log_gpu_memory("After audio processing")
|
|
|
|
| 288 |
conversation.append({
|
| 289 |
"role": "user",
|
| 290 |
"content": [
|
| 291 |
+
{"type": "audio", "audio_url": "placeholder_for_processed_audio"},
|
| 292 |
{"type": "text", "text": message}
|
| 293 |
]
|
| 294 |
})
|
|
|
|
| 335 |
log_gpu_memory("After input preparation")
|
| 336 |
except Exception as e:
|
| 337 |
logger.error(f"Error generating model inputs: {e}")
|
| 338 |
+
return f"⚠️ Error preparing audio for analysis: {str(e)}"
|
| 339 |
|
| 340 |
# Generate response from model
|
| 341 |
with torch.no_grad():
|
| 342 |
try:
|
| 343 |
generate_ids = model.generate(
|
| 344 |
**inputs,
|
| 345 |
+
max_new_tokens=128,
|
| 346 |
temperature=0.7,
|
| 347 |
do_sample=True,
|
| 348 |
top_p=0.9,
|
| 349 |
+
use_cache=True
|
| 350 |
)
|
| 351 |
logger.info(f"Response generated successfully")
|
| 352 |
log_gpu_memory("After generation")
|
| 353 |
+
except RuntimeError as e:
|
| 354 |
+
if "CUDA out of memory" in str(e):
|
| 355 |
+
logger.error(f"CUDA OOM during generation: {e}")
|
| 356 |
+
return "⚠️ Insufficient GPU memory to analyze this audio. Please try with a simpler or shorter audio clip."
|
| 357 |
+
else:
|
| 358 |
+
logger.error(f"Error during model.generate: {e}")
|
| 359 |
+
return f"⚠️ Model generation error: {str(e)}"
|
| 360 |
except Exception as e:
|
| 361 |
logger.error(f"Error during model.generate: {e}")
|
| 362 |
return f"⚠️ Model generation error: {str(e)}"
|
|
|
|
| 374 |
# Quick validation of response
|
| 375 |
if not response or response.isspace():
|
| 376 |
logger.error("Empty response received from model")
|
| 377 |
+
return "⚠️ Model returned an empty response. Please try again with a different question or audio file."
|
| 378 |
|
| 379 |
# Clean up memory
|
| 380 |
del inputs, generate_ids
|
|
|
|
| 387 |
logger.error(f"Error decoding response: {e}")
|
| 388 |
return f"⚠️ Error decoding response: {str(e)}"
|
| 389 |
|
| 390 |
+
except ModelLoadError as e:
|
| 391 |
+
logger.error(f"Model load error: {e}")
|
| 392 |
+
return str(e)
|
| 393 |
except Exception as e:
|
| 394 |
logger.error(f"Unexpected error in chat_with_model: {e}", exc_info=True)
|
| 395 |
return f"⚠️ An error occurred: {str(e)}"
|
|
|
|
| 400 |
torch.cuda.empty_cache()
|
| 401 |
log_gpu_memory("End of chat_with_model")
|
| 402 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 403 |
# Custom theme with orange primary color and dark background
|
| 404 |
orange_black_theme = gr.themes.Base(
|
| 405 |
primary_hue="orange",
|
|
|
|
| 434 |
.message.bot {
|
| 435 |
background-color: var(--dark-bg) !important;
|
| 436 |
}
|
| 437 |
+
|
| 438 |
+
.error-message {
|
| 439 |
+
color: #ff4d4d;
|
| 440 |
+
font-weight: bold;
|
| 441 |
+
padding: 8px;
|
| 442 |
+
border-radius: 4px;
|
| 443 |
+
background-color: rgba(255, 77, 77, 0.1);
|
| 444 |
+
margin-top: 8px;
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
.processing-indicator {
|
| 448 |
+
color: var(--orange-primary);
|
| 449 |
+
font-style: italic;
|
| 450 |
+
}
|
| 451 |
"""
|
| 452 |
|
| 453 |
# Gradio interface
|
|
|
|
| 455 |
gr.Markdown(
|
| 456 |
"""
|
| 457 |
# 🎧 Music Mixing Assistant
|
| 458 |
+
Get professional feedback on your music production and mixing!
|
| 459 |
|
| 460 |
+
Enter an audio URL or upload a file, then chat with your AI mixing engineer.
|
| 461 |
*(Note: Audio samples are limited to 15 seconds for optimal performance)*
|
| 462 |
"""
|
| 463 |
)
|
| 464 |
|
| 465 |
+
# State variables
|
| 466 |
+
audio_source_state = gr.State("")
|
| 467 |
+
audio_type_state = gr.State("url") # "url" or "upload"
|
| 468 |
+
uploaded_audio_state = gr.State(None)
|
| 469 |
|
| 470 |
with gr.Row():
|
| 471 |
with gr.Column(scale=3):
|
| 472 |
# Chat interface with customized settings
|
| 473 |
chatbot = gr.Chatbot(
|
| 474 |
height=500,
|
| 475 |
+
avatar_images=(None, "🎧"),
|
| 476 |
show_label=False,
|
| 477 |
container=True,
|
| 478 |
bubble_full_width=False,
|
| 479 |
+
show_copy_button=True,
|
|
|
|
| 480 |
render_markdown=True
|
| 481 |
)
|
| 482 |
|
|
|
|
| 495 |
clear_btn = gr.Button("Clear Chat", variant="secondary")
|
| 496 |
|
| 497 |
with gr.Column(scale=1):
|
| 498 |
+
# Tabs for different input methods
|
| 499 |
+
with gr.Tabs():
|
| 500 |
+
with gr.TabItem("URL Input"):
|
| 501 |
+
# Audio URL input
|
| 502 |
+
audio_input = gr.Textbox(
|
| 503 |
+
label="Audio URL",
|
| 504 |
+
placeholder="https://example.com/your-audio-file.wav",
|
| 505 |
+
info="Enter URL to a WAV/MP3/OGG audio file"
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
# Set URL button
|
| 509 |
+
set_url_btn = gr.Button("Set Audio From URL", variant="primary")
|
| 510 |
+
|
| 511 |
+
with gr.TabItem("File Upload"):
|
| 512 |
+
# Audio upload component
|
| 513 |
+
audio_upload = gr.Audio(
|
| 514 |
+
label="Upload Audio File",
|
| 515 |
+
type="filepath",
|
| 516 |
+
format="mp3",
|
| 517 |
+
info="Upload WAV/MP3/OGG (15 sec max will be analyzed)"
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
# Set upload button
|
| 521 |
+
set_upload_btn = gr.Button("Set Uploaded Audio", variant="primary")
|
| 522 |
|
| 523 |
+
# Preview player
|
| 524 |
audio_preview = gr.Audio(
|
| 525 |
+
label="Audio Preview",
|
| 526 |
interactive=False,
|
| 527 |
visible=True
|
| 528 |
)
|
| 529 |
|
| 530 |
+
# Example audio files section
|
| 531 |
+
gr.Markdown("### Try an example:")
|
| 532 |
+
example_btn_1 = gr.Button("Example: Guitar Mix")
|
| 533 |
+
example_btn_2 = gr.Button("Example: Vocals Track")
|
| 534 |
+
|
| 535 |
# Memory usage indicator
|
| 536 |
if torch.cuda.is_available():
|
| 537 |
memory_status = gr.Markdown("*GPU Memory: Initializing...*")
|
|
|
|
| 549 |
# Display status
|
| 550 |
status = gr.Markdown("*Status: Ready to assist with your mix!*")
|
| 551 |
|
| 552 |
+
# Function to update from URL
|
| 553 |
+
def update_from_url(url):
    """Validate an audio URL and stage it as the active audio source.

    Returns a 6-tuple consumed by the Gradio wiring:
    (source value, preview update, source type, uploaded file,
    status markdown, memory markdown).
    """
    # Empty or whitespace-only input: nothing to set.
    if url is None or not url.strip():
        return "", gr.update(value=None), "url", None, "*Status: No URL provided*", update_memory_status()

    # Lightweight whitelist check on the file extension only.
    lowered = url.lower()
    if not lowered.endswith(('.wav', '.mp3', '.ogg', '.flac', '.m4a')):
        return "", gr.update(value=None), "url", None, "*Status: Invalid audio URL format*", update_memory_status()

    # Wire the URL into the preview player; degrade gracefully if that fails.
    try:
        preview_update = gr.update(value=url)
        return url, preview_update, "url", None, "*Status: Audio track set from URL! First 15 seconds will be analyzed.*", update_memory_status()
    except Exception as e:
        return url, gr.update(value=None), "url", None, f"*Status: Audio URL set, but preview failed: {str(e)}*", update_memory_status()
|
| 567 |
+
|
| 568 |
+
# Function to update from upload
|
| 569 |
+
def update_from_upload(audio_file):
|
| 570 |
+
if audio_file is None:
|
| 571 |
+
return "", gr.update(value=None), "upload", None, "*Status: No file uploaded*", update_memory_status()
|
| 572 |
+
|
| 573 |
+
try:
|
| 574 |
+
# Store the uploaded file and update the preview
|
| 575 |
+
return "", gr.update(value=audio_file), "upload", audio_file, "*Status: Audio track set from upload! First 15 seconds will be analyzed.*", update_memory_status()
|
| 576 |
+
except Exception as e:
|
| 577 |
+
return "", gr.update(value=None), "upload", None, f"*Status: Upload failed: {str(e)}*", update_memory_status()
|
| 578 |
|
| 579 |
# Function to clear chat
|
| 580 |
def clear_chat():
|
| 581 |
return []
|
| 582 |
|
| 583 |
+
# Update states and clear chat when setting audio
|
| 584 |
+
def update_url_and_clear(url):
|
| 585 |
+
audio_source, preview, audio_type, upload_file, status_msg, memory_msg = update_from_url(url)
|
| 586 |
+
return audio_source, preview, audio_type, upload_file, [], status_msg, memory_msg
|
|
|
|
|
|
|
| 587 |
|
| 588 |
+
def update_upload_and_clear(audio_file):
|
| 589 |
+
audio_source, preview, audio_type, upload_file, status_msg, memory_msg = update_from_upload(audio_file)
|
| 590 |
+
return audio_source, preview, audio_type, upload_file, [], status_msg, memory_msg
|
| 591 |
+
|
| 592 |
+
# Load example audio
|
| 593 |
+
def load_example(example_num):
    """Stage one of the bundled example tracks by delegating to update_from_url.

    Example 1 is the guitar mix; any other number falls back to the vocals
    track. Also clears the chat history.
    """
    guitar_url = "https://huggingface.co/spaces/mclemcrew/audio-mix-assistant/resolve/main/examples/guitar_mix_example.mp3"
    vocals_url = "https://huggingface.co/spaces/mclemcrew/audio-mix-assistant/resolve/main/examples/vocals_example.mp3"
    example_url = guitar_url if example_num == 1 else vocals_url

    audio_source, preview, audio_type, upload_file, status_msg, memory_msg = update_from_url(example_url)
    # Prepend the URL (to fill the textbox) and an empty list (to clear chat).
    return example_url, audio_source, preview, audio_type, upload_file, [], status_msg, memory_msg
|
| 603 |
|
| 604 |
# Handle submit button
|
| 605 |
+
def respond(audio_source, audio_type, uploaded_audio, message, chat_history):
|
| 606 |
if not message.strip():
|
| 607 |
return chat_history, "*Status: Please enter a message*", update_memory_status()
|
| 608 |
|
| 609 |
+
# Determine the actual audio source to use
|
| 610 |
+
actual_audio_source = uploaded_audio if audio_type == "upload" else audio_source
|
| 611 |
+
|
| 612 |
+
# Check if audio source is set
|
| 613 |
+
if (audio_type == "url" and (not audio_source or not audio_source.strip())) or \
|
| 614 |
+
(audio_type == "upload" and uploaded_audio is None):
|
| 615 |
+
error_msg = "No audio track set. Please set an audio URL or upload a file first."
|
| 616 |
chat_history.append((message, f"⚠️ {error_msg}"))
|
| 617 |
return chat_history, f"*Status: {error_msg}*", update_memory_status()
|
| 618 |
|
|
|
|
| 622 |
|
| 623 |
try:
|
| 624 |
# Process and get response
|
| 625 |
+
bot_message = chat_with_model(actual_audio_source, audio_type, message, chat_history[:-1])
|
| 626 |
|
| 627 |
# Update the last message with the bot's response
|
| 628 |
chat_history[-1] = (message, bot_message)
|
| 629 |
|
| 630 |
# Return updated chat history
|
| 631 |
+
yield chat_history, "*Status: Ready for your next question!*", update_memory_status()
|
| 632 |
except Exception as e:
|
| 633 |
error_msg = f"Error generating response: {str(e)}"
|
| 634 |
chat_history[-1] = (message, f"⚠️ {error_msg}")
|
| 635 |
yield chat_history, f"*Status: {error_msg}*", update_memory_status()
|
| 636 |
|
| 637 |
# Handle submit with clear input
|
| 638 |
+
def respond_and_clear_input(audio_source, audio_type, uploaded_audio, message, chat_history):
|
| 639 |
# First get response updates
|
| 640 |
+
for result in respond(audio_source, audio_type, uploaded_audio, message, chat_history):
|
| 641 |
# Yield each result with empty message input
|
| 642 |
yield result[0], result[1], result[2], ""
|
| 643 |
|
| 644 |
# Connect UI components
|
| 645 |
+
set_url_btn.click(
|
| 646 |
+
update_url_and_clear,
|
| 647 |
+
inputs=[audio_input],
|
| 648 |
+
outputs=[audio_source_state, audio_preview, audio_type_state, uploaded_audio_state, chatbot, status, memory_status]
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
set_upload_btn.click(
|
| 652 |
+
update_upload_and_clear,
|
| 653 |
+
inputs=[audio_upload],
|
| 654 |
+
outputs=[audio_source_state, audio_preview, audio_type_state, uploaded_audio_state, chatbot, status, memory_status]
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
example_btn_1.click(
|
| 658 |
+
lambda: load_example(1),
|
| 659 |
+
inputs=[],
|
| 660 |
+
outputs=[audio_input, audio_source_state, audio_preview, audio_type_state, uploaded_audio_state, chatbot, status, memory_status]
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
example_btn_2.click(
|
| 664 |
+
lambda: load_example(2),
|
| 665 |
+
inputs=[],
|
| 666 |
+
outputs=[audio_input, audio_source_state, audio_preview, audio_type_state, uploaded_audio_state, chatbot, status, memory_status]
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
submit_btn.click(
|
| 670 |
respond_and_clear_input,
|
| 671 |
+
inputs=[audio_source_state, audio_type_state, uploaded_audio_state, message, chatbot],
|
| 672 |
outputs=[chatbot, status, memory_status, message],
|
| 673 |
queue=True
|
| 674 |
)
|
| 675 |
|
| 676 |
message.submit(
|
| 677 |
respond_and_clear_input,
|
| 678 |
+
inputs=[audio_source_state, audio_type_state, uploaded_audio_state, message, chatbot],
|
| 679 |
outputs=[chatbot, status, memory_status, message],
|
| 680 |
queue=True
|
| 681 |
)
|
| 682 |
|
| 683 |
+
# Clear button functionality
|
| 684 |
def clear_all():
|
| 685 |
gc.collect()
|
| 686 |
if torch.cuda.is_available():
|
| 687 |
torch.cuda.empty_cache()
|
| 688 |
+
return [], "", None, "*Status: Chat cleared!*", update_memory_status(), "", "url", None
|
| 689 |
|
| 690 |
clear_btn.click(
|
| 691 |
clear_all,
|
| 692 |
None,
|
| 693 |
+
[chatbot, audio_input, audio_preview, status, memory_status, audio_source_state, audio_type_state, uploaded_audio_state],
|
| 694 |
queue=False
|
| 695 |
)
|
| 696 |
|
| 697 |
# Launch the interface
|
| 698 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 699 |
# Launch with optimized settings
|
| 700 |
demo.launch(share=False, debug=False)
|
requirements.txt
CHANGED
|
@@ -1 +1,17 @@
|
|
| 1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio==4.44.1
|
| 2 |
+
torch>=2.0.0
|
| 3 |
+
transformers>=4.38.0
|
| 4 |
+
accelerate>=0.25.0
|
| 5 |
+
bitsandbytes>=0.41.0
|
| 6 |
+
librosa>=0.10.0
|
| 7 |
+
numpy>=1.24.0
|
| 8 |
+
requests>=2.28.0
|
| 9 |
+
scipy>=1.10.0
|
| 10 |
+
tqdm>=4.65.0
|
| 11 |
+
huggingface_hub>=0.17.0
|
| 12 |
+
sentencepiece>=0.1.97
|
| 13 |
+
soundfile>=0.12.1
|
| 14 |
+
packaging>=21.0
|
| 15 |
+
peft>=0.4.0
|
| 16 |
+
audiomentations>=0.29.0
|
| 17 |
+
datasets>=2.12.0
|
setup_examples.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
import logging
|
| 4 |
+
|
| 5 |
+
# Configure logging
|
| 6 |
+
logging.basicConfig(
|
| 7 |
+
level=logging.INFO,
|
| 8 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
| 9 |
+
)
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
def setup_examples():
    """Download the example audio files used by the demo app.

    Creates an ``examples/`` directory (if missing) and fetches each sample
    file into it. Files already present locally are skipped, and a failed
    download is logged but does not abort the remaining downloads.
    """
    # Create examples directory if it doesn't exist
    examples_dir = "examples"
    os.makedirs(examples_dir, exist_ok=True)

    # Example files to download - you can replace these with your own examples
    examples = [
        {
            "name": "guitar_mix_example.mp3",
            "url": "https://freesound.org/data/previews/612/612850_5674468-lq.mp3"  # Guitar example from freesound
        },
        {
            "name": "vocals_example.mp3",
            "url": "https://freesound.org/data/previews/336/336590_5674468-lq.mp3"  # Vocal example from freesound
        }
    ]

    # Download each example
    for example in examples:
        file_path = os.path.join(examples_dir, example["name"])

        # Skip if file already exists
        if os.path.exists(file_path):
            logger.info(f"File {example['name']} already exists, skipping download")
            continue

        try:
            logger.info(f"Downloading {example['name']} from {example['url']}")
            # timeout prevents a dead/slow host from hanging setup indefinitely
            response = requests.get(example["url"], timeout=30)
            response.raise_for_status()

            with open(file_path, "wb") as f:
                f.write(response.content)

            logger.info(f"Successfully downloaded {example['name']}")
        except Exception as e:
            # Best-effort: log and continue with the next example
            logger.error(f"Error downloading {example['name']}: {e}")

if __name__ == "__main__":
    setup_examples()
|