# CoMix-Demo / app.py
# mclemcrew's picture
# update to 30 seconds
# c3bc466
import os
import sys
import logging
import gradio as gr
import torch
import numpy as np
import librosa
import requests
from io import BytesIO
from urllib.request import urlopen, Request
import gc
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
# Send every log record at INFO level or above to stdout, stamped with
# the time and the severity.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Identifier handed to ``from_pretrained`` for both the model and its processor.
MODEL_ID = "mclemcrew/Qwen-Audio-Mix-Instruct"

# Module-level cache for the model and processor; both start empty and are
# populated on first use by load_model().
model = processor = None
def log_gpu_memory(message=""):
    """Write the current CUDA memory usage to the module logger.

    Nothing is logged when CUDA is unavailable.

    Args:
        message: Free-form prefix placed at the start of the log line.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gib = 1024**3
    allocated = torch.cuda.memory_allocated() / bytes_per_gib
    reserved = torch.cuda.memory_reserved() / bytes_per_gib
    logger.info(f"{message} - GPU: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
def load_model():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    Both objects are cached in the module-level ``model`` and ``processor``
    globals, so later calls return immediately. With CUDA available the
    model is loaded in float16 with ``device_map="auto"`` and no
    quantization config; otherwise it is loaded with the library defaults.

    Returns:
        Tuple of ``(model, processor)``.

    Raises:
        Exception: Whatever the loading calls raise; the error is logged
            and then re-raised unchanged.
    """
    global model, processor

    # Fast path: a previous call already populated both globals.
    already_loaded = model is not None and processor is not None
    if already_loaded:
        return model, processor

    try:
        # Release cached CUDA blocks before the large allocation below.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        logger.info(f"Loading processor from {MODEL_ID}")
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        logger.info("Processor loaded successfully")

        # Per the original author's note this is meant to force-disable the
        # bitsandbytes integration.
        os.environ["DISABLE_BITSANDBYTES_CUDA_SETUP"] = "TRUE"

        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logger.info(f"GPU memory: {gpu_memory:.2f} GB")
            logger.info("Loading model with FP16 precision")
            load_kwargs = {
                "torch_dtype": torch.float16,
                "device_map": "auto",
                "quantization_config": None,  # explicitly no quantization
                "low_cpu_mem_usage": True,
            }
        else:
            logger.info("Loading model on CPU")
            load_kwargs = {}

        model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_ID, **load_kwargs)
        model.eval()
        log_gpu_memory("After model loading")
        return model, processor
    except Exception as e:
        logger.error(f"Error loading model or processor: {e}")
        raise
def process_audio_file(audio_path, max_seconds=30):
    """Load the leading part of an audio file at the processor's sampling rate.

    The sampling rate is read from ``processor.feature_extractor``; the
    processor is loaded here if the module-level global is still empty.

    Args:
        audio_path: Path of the audio file, passed to ``librosa.load``.
        max_seconds: Upper bound on the returned duration, in seconds.
            Defaults to 30, the limit this app analyses.

    Returns:
        The waveform returned by ``librosa.load``, capped at
        ``max_seconds * sampling_rate`` samples, or ``None`` if loading
        failed (the error and traceback are logged).
    """
    global processor

    logger.info(f"Processing audio file: {audio_path}")
    try:
        if processor is None:
            processor = AutoProcessor.from_pretrained(MODEL_ID)
        target_sr = processor.feature_extractor.sampling_rate

        # Ask librosa for only the part we keep, so a long track is not
        # fully decoded and resampled just to be thrown away.
        audio_data, _ = librosa.load(audio_path, sr=target_sr, duration=max_seconds)

        # Resampling can round the length up slightly; enforce the exact cap.
        max_samples = int(max_seconds * target_sr)
        if len(audio_data) > max_samples:
            logger.info(f"Truncating audio to {max_seconds} seconds")
            audio_data = audio_data[:max_samples]
        return audio_data
    except Exception as e:
        # Best-effort by design: callers treat None as "audio unavailable".
        logger.exception(f"Error processing audio file: {e}")
        return None
def _build_conversation(audio_path, message, chat_history):
    """Assemble the chat-template message list: system prompt, prior turns, current turn."""
    system_prompt = "You are an expert audio engineer assisting with music production and mixing. Provide clear, specific advice on audio engineering techniques, mixing adjustments, and production decisions based on the audio samples and the user's questions. Focus on practical, actionable guidance."
    conversation = [{"role": "system", "content": system_prompt}]

    # Prior turns may be plain strings or already-structured content lists.
    for user_msg, bot_msg in chat_history or []:
        if user_msg is not None:
            conversation.append({"role": "user", "content": user_msg})
        # Skip empty replies and the UI's loading placeholder text.
        if bot_msg and bot_msg != "⏳ Generating response, please wait...":
            conversation.append({"role": "assistant", "content": bot_msg})

    # Attach the audio only to a plain-text message; anything else is
    # passed through unchanged (it is either text-only or pre-structured).
    if isinstance(message, str) and audio_path:
        content = [
            {"type": "audio", "audio_url": audio_path},
            {"type": "text", "text": message},
        ]
    else:
        content = message
    conversation.append({"role": "user", "content": content})
    return conversation


def _collect_audios(conversation):
    """Load the waveform for every audio element in the conversation, in order."""
    audios = []
    for msg in conversation:
        if not isinstance(msg["content"], list):
            continue
        for ele in msg["content"]:
            if ele["type"] != "audio":
                continue
            logger.info(f"Processing audio from path: {ele['audio_url']}")
            audio_data = process_audio_file(ele['audio_url'])
            if audio_data is not None:
                audios.append(audio_data)
                logger.info(f"Added audio: length={len(audio_data)}")
            else:
                logger.error(f"Failed to process audio from {ele['audio_url']}")
    return audios


def generate_response(audio_path, message, chat_history=None):
    """Generate the assistant's reply for the current turn.

    Args:
        audio_path: Path of the audio file to attach to the current message;
            ignored when falsy or when ``message`` is not a string.
        message: The user's message, either a string or an already
            structured content list.
        chat_history: Optional iterable of ``(user_msg, bot_msg)`` pairs
            from earlier turns.

    Returns:
        The decoded model output as a string. Any exception is logged with
        its traceback and turned into an error-message string instead of
        propagating.
    """
    global model, processor

    # Pre-bind so the cleanup in ``finally`` is valid on every path.
    inputs = generate_ids = None
    try:
        if model is None or processor is None:
            model, processor = load_model()

        conversation = _build_conversation(audio_path, message, chat_history)
        logger.info(f"Conversation structure being sent to model: {conversation}")

        audios = _collect_audios(conversation)
        logger.info(f"Number of audio inputs collected: {len(audios)}")

        text = processor.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False
        )
        logger.info(f"Templated conversation (truncated): {text[:300]}...")

        inputs = processor(
            text=text,
            audios=audios if audios else None,
            return_tensors="pt",
            padding=True
        )

        logger.info(f"Input shape: {inputs.input_ids.shape}")
        # The Qwen2Audio processor names the audio tensor "input_features";
        # checking any other key makes the warning fire on every request.
        if "input_features" in inputs:
            logger.info(f"Audio features shape: {inputs['input_features'].shape}")
        else:
            logger.warning("No audio features in inputs!")

        # Move every input tensor to the device holding the model weights.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        log_gpu_memory("Before generation")

        with torch.no_grad():
            generate_ids = model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=processor.tokenizer.pad_token_id
            )

        # Drop the prompt tokens so only the newly generated part is decoded.
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
        response = processor.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]
        return response
    except Exception as e:
        logger.exception(f"Error generating response: {e}")
        return f"I encountered an error while processing your request: {str(e)}"
    finally:
        # Release tensors and cached GPU memory on success and failure alike.
        inputs = generate_ids = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        log_gpu_memory("After generation")
# Create Gradio Interface
def create_interface():
    """Build the Gradio Blocks UI for the mixing assistant.

    The layout has a chat column (conversation, message box, send and clear
    buttons) and an audio column (upload widget, "Set Audio Track" button,
    status line). The message box and send button start disabled and are
    enabled once an audio file has been set.

    Returns:
        The constructed ``gr.Blocks`` application (not yet launched).
    """
    # NOTE(review): the source's indentation was lost; the component nesting
    # below is reconstructed from the inline comments and component order.
    with gr.Blocks(title="Music Mixing Assistant") as demo:
        gr.Markdown("# 🎧 Music Mixing Assistant")

        # Per-session state: the selected file path and whether it is set.
        audio_path_state = gr.State("")
        audio_loaded_state = gr.State(False)

        with gr.Row():
            with gr.Column(scale=2):
                # Chat interface
                chatbot = gr.Chatbot(height=400, label="Conversation")
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask about your mix...",
                        show_label=False,
                        container=False,
                        interactive=False  # Disabled until audio is loaded
                    )
                    submit_btn = gr.Button("Send", variant="primary", interactive=False)  # Disabled until audio is loaded
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")
            with gr.Column(scale=1):
                # Audio file upload
                audio_input = gr.Audio(
                    label="Upload Audio",
                    type="filepath",
                    sources=["upload", "microphone"]
                )
                gr.Markdown("*Upload an audio file (WAV or MP3, 30 seconds will be analyzed)*")
                set_audio_btn = gr.Button("Set Audio Track", variant="primary")
                status = gr.Markdown("*⚠️ Please load an audio file before chatting*")

        def set_audio(filepath):
            """Record the chosen file path and enable or disable the chat inputs.

            Returns values for: path state, loaded flag, status text,
            message box update, send button update.
            """
            if not filepath:
                return "", False, "*⚠️ Please upload an audio file*", gr.update(interactive=False), gr.update(interactive=False)
            return filepath, True, "*✅ Audio loaded successfully! You can start chatting now.*", gr.update(interactive=True), gr.update(interactive=True)

        set_audio_btn.click(
            set_audio,
            inputs=[audio_input],
            outputs=[audio_path_state, audio_loaded_state, status, msg, submit_btn]
        )

        def chat_response(message, chat_history, audio_path, audio_loaded):
            """Handle one chat submission as a generator of UI updates.

            Each yielded tuple is ``(chatbot value, message box value,
            status update)``. A loading placeholder is yielded first, then
            the final reply or an error message.
            """
            chat_history = chat_history or []

            # This function is a generator, so early exits must *yield*
            # their update; a ``return <value>`` would send nothing to the UI.
            if not message or not message.strip():
                yield chat_history, "", gr.update()
                return
            if not audio_loaded:
                chat_history.append((message, "Please load an audio file before chatting."))
                yield chat_history, "", gr.update()
                return

            # Structured form of the current turn (audio + text) when a
            # path is available, plain text otherwise.
            if audio_path and audio_loaded:
                current_user_message = [
                    {"type": "audio", "audio_url": audio_path},
                    {"type": "text", "text": message}
                ]
            else:
                current_user_message = message
            chat_history.append((current_user_message, None))

            # NOTE(review): chat_history comes from the chatbot component,
            # which is fed display_history below, so structured entries do
            # not appear to survive to the next turn - confirm if intended.
            display_history = []
            for user_msg, bot_msg in chat_history:
                display_user_msg = user_msg
                if isinstance(user_msg, list):
                    # Show only the first text part of a structured message.
                    display_user_msg = next(
                        (item.get("text", "") for item in user_msg if item.get("type") == "text"),
                        user_msg,
                    )
                display_history.append((display_user_msg, bot_msg))

            # Show a placeholder reply while the model is running.
            loading_display = display_history.copy()
            loading_display[-1] = (loading_display[-1][0], "⏳ Generating response, please wait...")
            yield loading_display, "", gr.update(value="*Processing audio analysis...*")

            try:
                logger.info(f"Sending to generate_response - message: {message}, chat_history: {chat_history[:-1]}")
                # The current turn is passed separately, so it is excluded
                # from the history handed to the model.
                response = generate_response(audio_path, message, chat_history[:-1])
                chat_history[-1] = (current_user_message, response)
                display_history[-1] = (display_history[-1][0], response)
                yield display_history, "", gr.update(value="*✅ Audio loaded successfully! You can start chatting now.*")
            except Exception as e:
                logger.exception(f"Error in chat_response: {str(e)}")
                chat_history[-1] = (current_user_message, f"Error: {str(e)}")
                display_history[-1] = (display_history[-1][0], f"Error: {str(e)}")
                yield display_history, "", gr.update(value="*✅ Audio loaded successfully! You can start chatting now.*")

        # The send button and pressing Enter in the message box share one handler.
        for trigger in (submit_btn.click, msg.submit):
            trigger(
                chat_response,
                inputs=[msg, chatbot, audio_path_state, audio_loaded_state],
                outputs=[chatbot, msg, status],
                show_progress="full"  # Show loading indicator during processing
            )

        def clear_chat():
            """Reset the conversation and message box; audio state is untouched."""
            return [], "", "*✅ Chat cleared. Audio file remains loaded.*"

        clear_btn.click(
            clear_chat,
            outputs=[chatbot, msg, status]
        )

    return demo
# Launch app
if __name__ == "__main__":
    # Record the model and hardware in the startup log before building the UI.
    logger.info(f"Starting Music Mixing Assistant with model {MODEL_ID}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    demo = create_interface()
    # Enable the request queue (max_size=5), then start the server with
    # show_error turned on.
    demo.queue(max_size=5)
    demo.launch(show_error=True)