Commit: updates
Files changed:
- app.py: +262 -459
- requirements.txt: +2 -2
- setup_examples.py: +0 -52 (deleted)
app.py
CHANGED
@@ -2,14 +2,16 @@ import os
 import gradio as gr
 import spaces
 import torch
-from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, BitsAndBytesConfig
 import numpy as np
 import librosa
-from urllib.request import urlopen
-from io import BytesIO
 import logging
 import sys
 import gc
+import time
+from io import BytesIO
+from urllib.request import urlopen, Request
+import requests
+from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, BitsAndBytesConfig
 
 # Configure logging
 logging.basicConfig(
@@ -19,19 +21,22 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-#
-MODEL_ID = "mclemcrew/
+# Use your fine-tuned model
+MODEL_ID = "mclemcrew/MixInstruct"
 
-# Cache for model and processor
+# Cache for model and processor to avoid reloading
 model_cache = None
 processor_cache = None
 
-# Memory tracking
+# Memory tracking function
 def log_gpu_memory(message=""):
+    """Log current GPU memory usage with a descriptive message"""
     if torch.cuda.is_available():
         allocated = torch.cuda.memory_allocated() / 1024**3
         reserved = torch.cuda.memory_reserved() / 1024**3
         logger.info(f"{message} - GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
+    else:
+        logger.info(f"{message} - Running on CPU, no GPU available")
 
 def load_model():
     """Load the fine-tuned model with optimized memory usage"""
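An aside for readers of this hunk: the cache check itself sits in an unshown stretch of `load_model()` (old lines 38-44 fall between hunks), so the diff only implies the contract behind the `model_cache` / `processor_cache` globals: load once, then reuse the same objects on every later Gradio call. A minimal stand-alone sketch of the same pattern using the standard library instead of module globals (illustrative only, not the committed code):

    import functools

    @functools.lru_cache(maxsize=1)
    def get_model_and_processor():
        # Runs load_model() once; every later call returns the cached tuple.
        return load_model()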
@@ -45,50 +50,51 @@ def load_model():
     # Log initial GPU state
     log_gpu_memory("Before model loading")
 
-    #
-    logger.info(f"Loading processor from {MODEL_ID}")
-    processor = AutoProcessor.from_pretrained(MODEL_ID)
-
-    # Clean up memory
+    # First clear any existing cache
     gc.collect()
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    #
+    # Load processor first
+    logger.info(f"Loading processor from {MODEL_ID}")
+    try:
+        processor = AutoProcessor.from_pretrained(MODEL_ID)
+        logger.info("Processor loaded successfully")
+    except Exception as e:
+        logger.error(f"Error loading processor: {e}")
+        raise RuntimeError(f"Failed to load processor: {str(e)}")
+
+    # Define quantization config - use 4-bit quantization for memory efficiency
     quant_config = BitsAndBytesConfig(
         load_in_4bit=True,
-        bnb_4bit_compute_dtype=torch.float16,
+        bnb_4bit_compute_dtype=torch.float16,
         bnb_4bit_use_double_quant=True,
-        bnb_4bit_quant_type="nf4"
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_quant_storage=torch.uint8
     )
 
+    # Load the model with progressive fallbacks
+    logger.info(f"Loading model from {MODEL_ID} with 4-bit quantization")
     try:
-
-
-        # Load with quantization and offloading for memory efficiency
+        # Primary approach: 4-bit quantization
         model = Qwen2AudioForConditionalGeneration.from_pretrained(
             MODEL_ID,
             quantization_config=quant_config,
-            device_map="auto",
+            device_map="auto",  # Let Hugging Face determine optimal device mapping
             torch_dtype=torch.float16,
-            offload_folder="offload",
-            offload_state_dict=True,
             low_cpu_mem_usage=True
         )
-
-        log_gpu_memory("After optimized model loading")
-        logger.info("Model loaded successfully with optimized approach")
+        logger.info("Model loaded successfully with 4-bit quantization")
     except Exception as e:
-
+        # Clean up memory before fallback
+        logger.error(f"Error loading model with 4-bit quantization: {e}")
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
 
+        # Try 8-bit quantization as fallback
        try:
-
-            logger.info("Attempting 8-bit quantization fallback")
-
-            from transformers import BitsAndBytesConfig
+            logger.info("Attempting fallback to 8-bit quantization")
            quant_config_8bit = BitsAndBytesConfig(
                load_in_8bit=True,
                llm_int8_threshold=6.0
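For scale, a rough back-of-envelope on why the 4-bit config above matters on a small Space GPU. The parameter count below is an assumption for a Qwen2-Audio-7B-class checkpoint, not a measured value:

    # Illustrative arithmetic only; params is assumed, not measured
    params = 8.4e9
    print(f"fp16 weights: {params * 2 / 1024**3:.1f} GB")    # ~15.6 GB
    print(f"int8 weights: {params * 1 / 1024**3:.1f} GB")    # ~7.8 GB
    print(f"nf4 weights:  {params * 0.5 / 1024**3:.1f} GB")  # ~3.9 GB plus quant constants

Double quantization (`bnb_4bit_use_double_quant=True`) shaves off a further few hundred megabytes by quantizing the quantization constants themselves, which is why the 4-bit path is tried first and 8-bit only as a fallback.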
@@ -97,187 +103,226 @@ def load_model():
             model = Qwen2AudioForConditionalGeneration.from_pretrained(
                 MODEL_ID,
                 quantization_config=quant_config_8bit,
-                device_map="auto",
+                device_map="auto",
                 torch_dtype=torch.float16
             )
-
-            log_gpu_memory("After 8-bit fallback loading")
             logger.info("Model loaded successfully with 8-bit quantization")
         except Exception as e2:
-
+            # Clean up memory before final fallback
+            logger.error(f"Error loading model with 8-bit quantization: {e2}")
             gc.collect()
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
 
+            # Final fallback - try to load with fp16 and CPU offloading
             try:
-
-
-
-
-
-
-
-
-
-                if input_ids is not None:
-                    batch_size, seq_len = input_ids.shape
-                    dummy_output = torch.ones((batch_size, seq_len + 20), dtype=torch.long)
-                    dummy_output[:, :seq_len] = input_ids
-                    dummy_output[:, seq_len:] = 100
-                    return dummy_output
-                else:
-                    return torch.ones((1, 30), dtype=torch.long)
-
-                def parameters(self):
-                    return iter(self.dummy_parameters)
-
-                def to(self, device):
-                    self.device = device
-                    return self
-
-                model = DummyAudioModel()
-                logger.warning("Created dummy model placeholder - no real functionality available")
+                logger.info("Attempting final fallback with CPU offloading")
+                model = Qwen2AudioForConditionalGeneration.from_pretrained(
+                    MODEL_ID,
+                    torch_dtype=torch.float16,
+                    device_map="auto",
+                    offload_folder="offload",
+                    offload_state_dict=True
+                )
+                logger.info("Model loaded successfully with CPU offloading")
             except Exception as e3:
-                logger.error(f"
-                raise RuntimeError(
+                logger.error(f"All loading attempts failed: {e3}")
+                raise RuntimeError("Could not load model after multiple attempts")
+
+    # Verify model loaded correctly
+    if model is None:
+        raise RuntimeError("Model failed to load but no exception was raised")
+
+    # Set model to evaluation mode
+    model.eval()
 
     # Cache the model and processor
     model_cache = model
     processor_cache = processor
 
+    # Log final memory state
+    log_gpu_memory("After model loading")
+
     return model, processor
 
-def process_audio_from_url(audio_url, processor):
-    """
+def process_audio(audio_path, processor):
+    """
+    Process audio file from URL or local path
+
+    Args:
+        audio_path: URL or path to audio file
+        processor: Model processor
+
+    Returns:
+        Processed audio data as numpy array
+    """
+    logger.info(f"Processing audio from: {audio_path}")
+
     try:
-
-        # Get processor's sampling rate
+        # Get target sampling rate from processor
         target_sr = int(processor.feature_extractor.sampling_rate)
         logger.info(f"Target sampling rate: {target_sr}")
 
-        #
-
+        # Determine maximum audio length (15 seconds)
+        max_seconds = 15
+        max_samples = max_seconds * target_sr
 
-        #
-        if 
-            #
+        # Load audio data based on source
+        if audio_path.startswith(('http://', 'https://')):
+            # Web URL handling with proper headers to avoid 403 errors
             try:
-
-
+                # First try with requests for better error handling
+                headers = {
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
+                }
+                response = requests.get(audio_path, headers=headers)
                 response.raise_for_status()
                 audio_bytes = BytesIO(response.content)
-
-
-
-                logger.
-
-
-
-
-
-
+                logger.info(f"Successfully downloaded audio with requests: {len(response.content)} bytes")
+            except Exception as req_err:
+                # Fallback to urlopen
+                logger.warning(f"Requests download failed, trying urlopen: {req_err}")
+                request = Request(audio_path, headers={
+                    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
+                })
+                audio_bytes = BytesIO(urlopen(request).read())
+
+            # Load audio with librosa
+            audio_data, sr_loaded = librosa.load(audio_bytes, sr=None)
         else:
-            #
-
-            audio_bytes = BytesIO(f.read())
+            # Local file handling
+            audio_data, sr_loaded = librosa.load(audio_path, sr=None)
 
-        # Load and resample audio
-        audio_data, sr_loaded = librosa.load(audio_bytes, sr=None)
         logger.info(f"Audio loaded with shape: {audio_data.shape}, original SR: {sr_loaded}")
 
-        # Free memory
-        del audio_bytes
-        gc.collect()
-
         # Resample if needed
         if sr_loaded != target_sr:
             logger.info(f"Resampling from {sr_loaded} Hz to {target_sr} Hz")
             audio_data = librosa.resample(audio_data, orig_sr=sr_loaded, target_sr=target_sr)
 
-        #
-        max_seconds = 15
-        max_samples = max_seconds * target_sr
+        # Truncate to maximum length
         if len(audio_data) > max_samples:
-            logger.info(f"
+            logger.info(f"Truncating audio from {len(audio_data)} to {max_samples} samples ({max_seconds} seconds)")
             audio_data = audio_data[:max_samples]
 
         # Ensure audio is float32
         audio_data = audio_data.astype(np.float32)
 
+        # Print audio stats
+        logger.info(f"Processed audio shape: {audio_data.shape}, min: {audio_data.min()}, max: {audio_data.max()}")
+
         return audio_data
+
     except Exception as e:
-        logger.error(f"Error processing audio
-
-
-        # Clean up any lingering memory
-        if 'audio_bytes' in locals() and audio_bytes is not None:
-            del audio_bytes
-        gc.collect()
+        logger.error(f"Error processing audio: {e}", exc_info=True)
+        # Return a small empty array instead of None to avoid downstream errors
+        return np.zeros(target_sr * 3, dtype=np.float32)  # 3 seconds of silence as fallback
 
-
+# Add retry decorator for reliability
+def with_retry(max_retries=3, delay=1.0):
+    """
+    Decorator to retry functions with exponential backoff
+    """
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            retries = 0
+            current_delay = delay
+
+            while retries < max_retries:
+                try:
+                    return func(*args, **kwargs)
+                except Exception as e:
+                    retries += 1
+                    if retries >= max_retries:
+                        logger.error(f"Function {func.__name__} failed after {max_retries} retries: {e}")
+                        raise
+
+                    logger.warning(f"Retry {retries}/{max_retries} for {func.__name__}: {e}")
+                    time.sleep(current_delay)
+                    current_delay *= 2  # Exponential backoff
+
+            return None  # Should never reach here
+        return wrapper
+    return decorator
+
+# Function to validate audio URLs
+def is_valid_audio_url(url):
+    """Check if a URL likely points to an audio file"""
+    if not url or not isinstance(url, str) or not url.strip():
+        return False
+
+    url = url.strip().lower()
+    audio_extensions = ('.wav', '.mp3', '.ogg', '.flac', '.m4a', '.aac', '.wma')
+
+    # Check if URL ends with a known audio extension
+    if any(url.endswith(ext) for ext in audio_extensions):
+        return True
+
+    # Check for common audio hosting patterns
+    audio_hosts = ('soundcloud.com', 'bandcamp.com', 'freesound.org')
+    if any(host in url for host in audio_hosts):
+        return True
+
+    return False
+
+@with_retry(max_retries=2, delay=1.0)
+@spaces.GPU(duration=60)  # Reduced duration for more reliable performance
 def chat_with_model(audio_url, message, chat_history):
-    """
-
+    """
+    Generate response from the model using an audio URL
 
-
-
+    Args:
+        audio_url: URL to audio file
+        message: User message text
+        chat_history: Previous conversation history
+
+    Returns:
+        Model's response text
+    """
+    logger.info(f"Starting chat with audio_url: {audio_url}, message: {message}")
+    log_gpu_memory("Starting chat_with_model")
 
-    # Validate
+    # Validate inputs
     if not audio_url or not audio_url.strip():
-        return "⚠️ Please 
+        return "⚠️ Please provide an audio URL before sending a message."
+
+    if not message or not message.strip():
+        return "⚠️ Please enter a message to get a response."
 
     try:
-        # Load model and processor
+        # Load model and processor
         model, processor = load_model()
 
-        #
-
+        # Process audio file
+        audio_data = process_audio(audio_url, processor)
+        if audio_data is None:
+            return "⚠️ Could not process the audio file. Please check the URL and try again."
 
-        #
-
-
-        if is_dummy:
-            logger.warning("Using dummy model - providing generic response")
-            return (
-                "⚠️ I'm currently having trouble analyzing your audio due to technical limitations "
-                "in this environment. The model requires more GPU memory than is available. "
-                "Please try a different audio file or contact the developer for assistance."
-            )
-
-        # Process audio
-        audios = []
-        audio_data = process_audio_from_url(audio_url, processor)
-        if audio_data is not None:
-            audios.append(audio_data)
-        else:
-            return "⚠️ Failed to process audio from the provided URL. Please check that the URL is valid and accessible."
+        # Store processed audio in a list for model input
+        audios = [audio_data]
 
-        #
-
+        # Define system prompt
+        system_prompt = "You are an expert audio engineer assisting with music production and mixing. Provide clear, specific advice on audio engineering techniques, mixing adjustments, and production decisions based on the audio samples and the user's questions. Focus on practical, actionable guidance. Be as specific as possible when answering the user's questions about the mix."
 
-        #
-        SYSTEM_PROMPT = "You are an expert audio engineer assisting with music production and mixing. Provide clear, specific advice on audio engineering techniques, mixing adjustments, and production decisions based on the audio samples and the user's questions. Focus on practical, actionable guidance. Be as specific as possible when answering the user's questions about the mix."
-
-        # Start with system prompt
+        # Build conversation structure
         conversation = [
-            {"role": "system", "content": SYSTEM_PROMPT}
+            {"role": "system", "content": system_prompt}
         ]
 
-        # Add chat history
-        history_limit = min(len(chat_history), 
+        # Add chat history (limited to last 3 exchanges to save memory)
+        history_limit = min(len(chat_history), 3)
         for user_msg, bot_msg in chat_history[-history_limit:]:
             if user_msg:
                 conversation.append({"role": "user", "content": user_msg})
             if bot_msg:
                 conversation.append({"role": "assistant", "content": bot_msg})
 
-        # Determine if this is the first message with 
-
+        # Determine if this is the first message with this audio
+        is_first_message = len(chat_history) == 0
 
-        #
-        if 
+        # Add current message with audio if it's the first message
+        if is_first_message:
             # First message includes audio
-            logger.info("First message with audio, including audio in content")
             conversation.append({
                 "role": "user",
                 "content": [
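To make the two new helpers in the hunk above concrete: `with_retry` re-runs a failing call with doubling waits, and `is_valid_audio_url` is a cheap heuristic over extensions and known hosts. A small usage sketch; the `fetch_bytes` helper is illustrative and not part of the commit:

    @with_retry(max_retries=3, delay=1.0)
    def fetch_bytes(url):
        # raise_for_status() raises on HTTP errors, so with_retry re-attempts
        # after 1 s and then 2 s before letting the third failure propagate.
        r = requests.get(url, timeout=10)
        r.raise_for_status()
        return r.content

    # Heuristic checks:
    # is_valid_audio_url("https://example.com/mix.WAV")   -> True  (extension, case-folded)
    # is_valid_audio_url("https://freesound.org/s/1/")    -> True  (known audio host)
    # is_valid_audio_url("https://example.com/notes.txt") -> False

Note that the decorated `chat_with_model` stacks `@with_retry(max_retries=2, delay=1.0)` outside `@spaces.GPU(duration=60)`, so a failed GPU call gets exactly one more attempt before the error surfaces.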
@@ -286,333 +331,91 @@ def chat_with_model(audio_url, message, chat_history):
                 ]
             })
         else:
-            # Follow-up
-            logger.info("Follow-up message, including only text in content")
+            # Follow-up messages just include text
             conversation.append({
                 "role": "user",
                 "content": message
             })
 
-        # Apply chat template
-
-
-
-
-
-
-            logger.info(f"Chat template applied successfully")
-        except Exception as e:
-            logger.error(f"Error applying chat template: {e}")
-            # Use a simplified approach if template fails
-            text = f"{SYSTEM_PROMPT}\n\nUser: {message}\n\nAssistant:"
-
-        # Generate model inputs
-        try:
-            inputs = processor(
-                text=text,
-                audios=audios,
-                return_tensors="pt",
-                padding=True,
-                truncation=True
-            )
-
-            # Move inputs to the appropriate device
-            if hasattr(model, 'device'):
-                device = model.device
-            else:
-                device = next(model.parameters()).device
-            logger.info(f"Using device: {device}")
-            inputs = {k: v.to(device) for k, v in inputs.items()}
-
-            logger.info(f"Model inputs generated")
-            log_gpu_memory("After input preparation")
-        except Exception as e:
-            logger.error(f"Error generating model inputs: {e}")
-            return f"⚠️ Error generating model inputs: {str(e)}"
-
-        # Generate response from model
-        with torch.no_grad():
-            try:
-                generate_ids = model.generate(
-                    **inputs,
-                    max_new_tokens=128, # Reduced from 256
-                    temperature=0.7,
-                    do_sample=True,
-                    top_p=0.9,
-                    use_cache=True # Ensure KV cache is used
-                )
-                logger.info(f"Response generated successfully")
-                log_gpu_memory("After generation")
-            except Exception as e:
-                logger.error(f"Error during model.generate: {e}")
-                return f"⚠️ Model generation error: {str(e)}"
+        # Apply chat template
+        logger.info(f"Formatting conversation with {len(conversation)} messages")
+        text = processor.apply_chat_template(
+            conversation,
+            add_generation_prompt=True,
+            tokenize=False
+        )
 
-        #
-
-
-
-
-
-
-
-
-
-            # Quick validation of response
-            if not response or response.isspace():
-                logger.error("Empty response received from model")
-                return "⚠️ Model returned an empty response. Please try again."
-
-            # Clean up memory
-            del inputs, generate_ids
-            gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            return response
-        except Exception as e:
-            logger.error(f"Error decoding response: {e}")
-            return f"⚠️ Error decoding response: {str(e)}"
-
-    except Exception as e:
-        logger.error(f"Unexpected error in chat_with_model: {e}", exc_info=True)
-        return f"⚠️ An error occurred: {str(e)}"
-    finally:
-        # Final memory cleanup
-        gc.collect()
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
-        log_gpu_memory("End of chat_with_model")
-
-# Function to check if URL is a valid audio file
-def is_valid_audio_url(url):
-    if not url or not url.strip():
-        return False
-
-    url = url.strip().lower()
-    return url.endswith(('.wav', '.mp3', '.ogg', '.flac', '.m4a'))
-
-# Custom theme with orange primary color and dark background
-orange_black_theme = gr.themes.Base(
-    primary_hue="orange",
-    secondary_hue="gray",
-    neutral_hue="gray",
-    font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
-)
-
-# Custom CSS for darker theme with orange accents
-custom_css = """
-:root {
-    --orange-primary: #ff7700;
-    --dark-bg: #1a1a1a;
-    --darker-bg: #121212;
-    --lightest-gray: #e0e0e0;
-}
-
-body {
-    background-color: var(--darker-bg) !important;
-    color: var(--lightest-gray) !important;
-    font-family: 'Poppins', sans-serif !important;
-}
-
-.gradio-container {
-    background-color: var(--darker-bg) !important;
-}
-
-button.primary {
-    background-color: var(--orange-primary) !important;
-}
-
-.message.bot {
-    background-color: var(--dark-bg) !important;
-}
-"""
-
-# Gradio interface
-with gr.Blocks(theme=orange_black_theme, css=custom_css) as demo:
-    gr.Markdown(
-        """
-        # 🎧 Music Mixing Assistant
-        Enter an audio URL (.wav format recommended) and chat with your co-creative mixing agent!
+        # Create model inputs
+        logger.info("Preparing model inputs")
+        inputs = processor(
+            text=text,
+            audios=audios if is_first_message else None, # Only include audio on first message
+            return_tensors="pt",
+            padding=True,
+            truncation=True
+        )
 
-
-
-
-    )
-
-    # Create states for chat history and audio URL
-    audio_url_state = gr.State("")
-
-    with gr.Row():
-        with gr.Column(scale=3):
-            # Chat interface with customized settings
-            chatbot = gr.Chatbot(
-                height=500,
-                avatar_images=(None, "🎧"), # Removed user icon
-                show_label=False,
-                container=True,
-                bubble_full_width=False,
-                show_copy_button=False, # Removed copy button
-                show_share_button=False, # Removed share button
-                render_markdown=True
-            )
-
-            # Input area
-            with gr.Row():
-                message = gr.Textbox(
-                    placeholder="Ask about your mix...",
-                    show_label=False,
-                    container=False,
-                    scale=10
-                )
-                submit_btn = gr.Button("Send", variant="primary", scale=1)
-
-            # Control buttons
-            with gr.Row():
-                clear_btn = gr.Button("Clear Chat", variant="secondary")
+        # Move inputs to GPU if available
+        device = next(model.parameters()).device
+        inputs = {k: v.to(device) for k, v in inputs.items()}
 
-
-        # Audio URL input
-        audio_input = gr.Textbox(
-            label="Audio URL (.wav format)",
-            placeholder="https://example.com/your-audio-file.wav",
-            info="Enter URL to a WAV audio file - first 15 seconds will be analyzed"
-        )
-
-        # Add a button to set the URL
-        set_url_btn = gr.Button("Set Audio Track", variant="primary")
-
-        # Preview player (optional)
-        audio_preview = gr.Audio(
-            label="Audio Preview (if available)",
-            interactive=False,
-            visible=True
-        )
-
-        # Memory usage indicator
-        if torch.cuda.is_available():
-            memory_status = gr.Markdown("*GPU Memory: Initializing...*")
-            def update_memory_status():
-                if torch.cuda.is_available():
-                    allocated = torch.cuda.memory_allocated() / 1024**3
-                    reserved = torch.cuda.memory_reserved() / 1024**3
-                    return f"*GPU Memory: {allocated:.2f}GB allocated / {reserved:.2f}GB reserved*"
-                return "*GPU Memory: Not available*"
-        else:
-            memory_status = gr.Markdown("*GPU Memory: Not available*")
-            def update_memory_status():
-                return "*GPU Memory: Not available*"
-
-        # Display status
-        status = gr.Markdown("*Status: Ready to assist with your mix!*")
-
-    # Function to update the audio URL state and preview
-    def update_audio_url(url):
-        # Basic validation
-        if not is_valid_audio_url(url):
-            return "", gr.update(value=None), "*Status: Invalid audio URL. Please use .wav, .mp3, .ogg, .flac, or .m4a format*", update_memory_status()
+        log_gpu_memory("Before generation")
 
-        #
+        # Generate response with optimized settings
+        logger.info("Generating response")
         try:
-
-
-
-
-
-
-
-
-
-        #
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # Handle submit button
-    def respond(audio_url, message, chat_history):
-        if not message.strip():
-            return chat_history, "*Status: Please enter a message*", update_memory_status()
-
-        # Check if audio URL is set
-        if not audio_url or not audio_url.strip():
-            error_msg = "No audio track set. Please set an audio URL first."
-            chat_history.append((message, f"⚠️ {error_msg}"))
-            return chat_history, f"*Status: {error_msg}*", update_memory_status()
-
-        # Update chat history with user message immediately
-        chat_history.append((message, None))
-        yield chat_history, "🎵 *Analyzing your mix...*", update_memory_status()
+            with torch.no_grad():
+                generate_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=150, # Slightly reduced for reliability
+                    do_sample=True, # Enable sampling for more natural responses
+                    temperature=0.7, # Moderate temperature
+                    top_p=0.9, # Nucleus sampling for focused yet diverse outputs
+                    num_beams=1, # Disable beam search for faster generation
+                    use_cache=True, # Use KV cache
+                    repetition_penalty=1.1 # Light penalty to avoid repetition
+                )
+        except Exception as gen_error:
+            logger.error(f"Generation error: {gen_error}")
+            # Try a simpler generation approach as fallback
+            with torch.no_grad():
+                generate_ids = model.generate(
+                    **inputs,
+                    max_new_tokens=100, # Even shorter for reliability
+                    do_sample=False, # Disable sampling
+                    num_beams=1, # No beam search
+                    use_cache=True # Still use KV cache
+                )
 
-
-
-
-
-
-
-
-
-
-
-
-            chat_history[-1] = (message, f"⚠️ {error_msg}")
-            yield chat_history, f"*Status: {error_msg}*", update_memory_status()
-
-    # Handle submit with clear input
-    def respond_and_clear_input(audio_url, message, chat_history):
-        # First get response updates
-        for result in respond(audio_url, message, chat_history):
-            # Yield each result with empty message input
-            yield result[0], result[1], result[2], ""
-
-    # Connect UI components
-    submit_btn.click(
-        respond_and_clear_input,
-        inputs=[audio_url_state, message, chatbot],
-        outputs=[chatbot, status, memory_status, message],
-        queue=True
-    )
-
-    message.submit(
-        respond_and_clear_input,
-        inputs=[audio_url_state, message, chatbot],
-        outputs=[chatbot, status, memory_status, message],
-        queue=True
-    )
-
-    # Clear button functionality to reset everything
-    def clear_all():
+        # Extract only the generated response (not the input)
+        logger.info("Processing generated response")
+        generate_ids = generate_ids[:, inputs["input_ids"].size(1):]
+        response = processor.batch_decode(
+            generate_ids,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False
+        )[0]
+
+        # Clean up memory
+        del inputs, generate_ids, audios, audio_data
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
-        return [], "", None, "*Status: Chat cleared!*", update_memory_status(), ""
 
-
-        clear_all,
-        None,
-        [chatbot, audio_input, audio_preview, status, memory_status, audio_url_state],
-        queue=False
-    )
-
-    # Launch the interface
-    if __name__ == "__main__":
-        # Display version warning at startup
-        try:
-            import pkg_resources
-            gradio_version = pkg_resources.get_distribution("gradio").version
-            recommended_version = "4.44.1" # Update this as needed
-            if gradio_version != recommended_version:
-                print(f"⚠️ WARNING: You are using gradio version {gradio_version}, however version {recommended_version} is available.")
-                print(f"⚠️ Please upgrade: pip install gradio=={recommended_version}")
-        except:
-            pass
+        log_gpu_memory("After generation")
 
-
-
+        # Format and return response
+        logger.info(f"Generated response of length {len(response)}")
+        return response.strip()
+
+    except RuntimeError as e:
+        # Handle CUDA out of memory errors specially
+        if "CUDA out of memory" in str(e):
+            logger.error(f"CUDA OOM error: {e}")
+            return "⚠️ Out of GPU memory. Please try with a shorter audio clip (under 15 seconds) or refresh the page."
+        else:
+            logger.error(f"Runtime error: {e}", exc_info=True)
+            return f"⚠️ An error occurred: {str(e)}"
+    except Exception as e:
+        logger.error(f"Unexpected error in chat_with_model: {e}", exc_info=True)
+        return f"⚠️ Something went wrong: {str(e)}"
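Two closing notes on the new app.py. First, the entries inside the first user message's `content` list (new lines 329-330) fall in the gap between hunks, so the committed payload is not visible here; the conversation format documented for Qwen2-Audio pairs an audio item with a text item, so those elided lines presumably resemble this sketch (not the committed code):

    conversation.append({
        "role": "user",
        "content": [
            {"type": "audio", "audio_url": audio_url},
            {"type": "text", "text": message},
        ],
    })

Second, the app.py diff now ends at `chat_with_model`: old lines 394-618 (the theme, CSS, and all the Gradio wiring) are deleted outright, leaving a plain function pipeline. A hedged sketch of driving it directly, with a placeholder URL and question:

    history = []
    question = "Does the vocal sit right against the drums?"
    reply = chat_with_model("https://example.com/rough-mix.wav", question, history)
    history.append((question, reply))
    print(reply)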
requirements.txt
CHANGED
@@ -1,5 +1,5 @@
 gradio==4.44.1
-
+transformers>=4.35.0
 torch>=2.0.1
 numpy>=1.24.3
 librosa>=0.10.1
@@ -7,7 +7,7 @@ accelerate>=0.23.0
 requests>=2.32.0
 pillow>=9.5.0
 huggingface_hub>=0.16.4
-spaces
+spaces>=0.19.1
 urllib3>=1.26.16
 soundfile>=0.12.1
 bitsandbytes>=0.42.0
setup_examples.py
DELETED
@@ -1,52 +0,0 @@
-import os
-import requests
-import logging
-
-# Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
-logger = logging.getLogger(__name__)
-
-def setup_examples():
-    """Download example audio files for the app"""
-    # Create examples directory if it doesn't exist
-    examples_dir = "examples"
-    os.makedirs(examples_dir, exist_ok=True)
-
-    # Example files to download - you can replace these with your own examples
-    examples = [
-        {
-            "name": "guitar_mix_example.mp3",
-            "url": "https://freesound.org/data/previews/612/612850_5674468-lq.mp3" # Guitar example from freesound
-        },
-        {
-            "name": "vocals_example.mp3",
-            "url": "https://freesound.org/data/previews/336/336590_5674468-lq.mp3" # Vocal example from freesound
-        }
-    ]
-
-    # Download each example
-    for example in examples:
-        file_path = os.path.join(examples_dir, example["name"])
-
-        # Skip if file already exists
-        if os.path.exists(file_path):
-            logger.info(f"File {example['name']} already exists, skipping download")
-            continue
-
-        try:
-            logger.info(f"Downloading {example['name']} from {example['url']}")
-            response = requests.get(example["url"])
-            response.raise_for_status()
-
-            with open(file_path, "wb") as f:
-                f.write(response.content)
-
-            logger.info(f"Successfully downloaded {example['name']}")
-        except Exception as e:
-            logger.error(f"Error downloading {example['name']}: {e}")
-
-if __name__ == "__main__":
-    setup_examples()