import os
import sys
import logging
import gradio as gr
import torch
import librosa
import gc

from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration
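# Assumed runtime dependencies (not pinned here): gradio, torch, librosa, and a
# transformers release recent enough to ship Qwen2AudioForConditionalGeneration.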
# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
# Model checkpoint on the Hugging Face Hub
MODEL_ID = "mclemcrew/Qwen-Audio-Mix-Instruct"

# Module-level cache for the model and processor (loaded lazily)
model = None
processor = None
def log_gpu_memory(message=""):
    """Log current GPU memory usage."""
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        logger.info(f"{message} - GPU: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
def load_model():
    """Load the model and processor, with error handling and caching."""
    global model, processor

    # Return the cached instances if already loaded
    if model is not None and processor is not None:
        return model, processor

    try:
        # Clear the CUDA cache before loading
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Load the processor first
        logger.info(f"Loading processor from {MODEL_ID}")
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        logger.info("Processor loaded successfully")

        # Force-disable bitsandbytes integration
        os.environ["DISABLE_BITSANDBYTES_CUDA_SETUP"] = "TRUE"

        if torch.cuda.is_available():
            gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logger.info(f"GPU memory: {gpu_memory:.2f} GB")

            # Load in FP16 without 8-bit quantization
            logger.info("Loading model with FP16 precision")
            model = Qwen2AudioForConditionalGeneration.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=None,  # Explicitly disable quantization
                low_cpu_mem_usage=True
            )
        else:
            # Fall back to CPU when no GPU is available
            logger.info("Loading model on CPU")
            model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_ID)

        model.eval()
        log_gpu_memory("After model loading")
        return model, processor
    except Exception as e:
        logger.error(f"Error loading model or processor: {e}")
        raise
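# Rough sizing note: FP16 stores 2 bytes per parameter, so a Qwen2-Audio-7B
# class checkpoint (assumed base of this fine-tune) needs on the order of
# 14 GB of GPU memory for the weights alone.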
def process_audio_file(audio_path):
    """Load an audio file, resample it to the model's rate, and cap it at 30 seconds."""
    logger.info(f"Processing audio file: {audio_path}")
    try:
        # Look up the feature extractor's expected sampling rate
        global processor
        if processor is None:
            processor = AutoProcessor.from_pretrained(MODEL_ID)
        target_sr = processor.feature_extractor.sampling_rate

        # Load and resample the audio (librosa returns mono float32 by default)
        audio_data, sr = librosa.load(audio_path, sr=target_sr)

        # Limit to 30 seconds to keep generation responsive
        max_samples = 30 * target_sr
        if len(audio_data) > max_samples:
            logger.info("Truncating audio to 30 seconds")
            audio_data = audio_data[:max_samples]
        return audio_data
    except Exception as e:
        logger.error(f"Error processing audio file: {e}")
        return None
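# Example (hypothetical path): process_audio_file("mix.wav") returns a 1-D
# float32 numpy array at the feature extractor's rate (16 kHz for Qwen2-Audio),
# holding at most 30 seconds of samples, or None on failure.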
def generate_response(audio_path, message, chat_history=None):
    """Generate a response, following the official Qwen2-Audio usage example."""
    global model, processor
    try:
        if model is None or processor is None:
            model, processor = load_model()

        # Build the conversation in the Qwen chat format
        conversation = []

        # Add the system prompt
        system_prompt = (
            "You are an expert audio engineer assisting with music production and mixing. "
            "Provide clear, specific advice on audio engineering techniques, mixing adjustments, "
            "and production decisions based on the audio samples and the user's questions. "
            "Focus on practical, actionable guidance."
        )
        conversation.append({"role": "system", "content": system_prompt})

        # Replay the chat history, which may contain structured messages with audio
        if chat_history:
            for user_msg, bot_msg in chat_history:
                if user_msg is not None:
                    conversation.append({"role": "user", "content": user_msg})
                if bot_msg and bot_msg != "⏳ Generating response, please wait...":
                    conversation.append({"role": "assistant", "content": bot_msg})

        # Add the current message, attaching the audio if the message is plain text
        if isinstance(message, str):
            if audio_path:
                conversation.append({
                    "role": "user",
                    "content": [
                        {"type": "audio", "audio_url": audio_path},
                        {"type": "text", "text": message}
                    ]
                })
            else:
                conversation.append({"role": "user", "content": message})
        else:
            # The message is already structured
            conversation.append({"role": "user", "content": message})

        logger.info(f"Conversation structure being sent to model: {conversation}")

        # Collect the waveforms for every audio entry in the conversation
        audios = []
        for msg in conversation:
            if isinstance(msg["content"], list):
                for ele in msg["content"]:
                    if ele["type"] == "audio":
                        logger.info(f"Processing audio from path: {ele['audio_url']}")
                        audio_data = process_audio_file(ele['audio_url'])
                        if audio_data is not None:
                            audios.append(audio_data)
                            logger.info(f"Added audio: length={len(audio_data)}")
                        else:
                            logger.error(f"Failed to process audio from {ele['audio_url']}")
        logger.info(f"Number of audio inputs collected: {len(audios)}")
        # Apply the chat template
        text = processor.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False
        )
        logger.info(f"Templated conversation (truncated): {text[:300]}...")

        # Process the inputs exactly as in the Qwen example
        inputs = processor(
            text=text,
            audios=audios if audios else None,
            return_tensors="pt",
            padding=True
        )
        # Log the input shapes (the Whisper-style audio features arrive under
        # the `input_features` key, not `audio_features`)
        logger.info(f"Input shape: {inputs.input_ids.shape}")
        if 'input_features' in inputs:
            logger.info(f"Audio features shape: {inputs['input_features'].shape}")
        else:
            logger.warning("No audio features in inputs!")
        # Move the inputs to the same device as the model
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        log_gpu_memory("Before generation")

        # Generate the response
        with torch.no_grad():
            generate_ids = model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=processor.tokenizer.pad_token_id
            )

        # Keep only the newly generated tokens (generate returns prompt + continuation)
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]

        # Decode the response
        response = processor.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        # Release memory between requests
        del inputs, generate_ids
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        log_gpu_memory("After generation")

        return response
    except Exception as e:
        logger.error(f"Error generating response: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return f"I encountered an error while processing your request: {str(e)}"
# Create the Gradio interface
def create_interface():
    """Create the Gradio interface."""
    with gr.Blocks(title="Music Mixing Assistant") as demo:
        gr.Markdown("# 🎧 Music Mixing Assistant")

        # Chat state
        audio_path_state = gr.State("")
        audio_loaded_state = gr.State(False)

        with gr.Row():
            with gr.Column(scale=2):
                # Chat interface
                chatbot = gr.Chatbot(height=400, label="Conversation")
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask about your mix...",
                        show_label=False,
                        container=False,
                        interactive=False  # Disabled until audio is loaded
                    )
                    submit_btn = gr.Button("Send", variant="primary", interactive=False)  # Disabled until audio is loaded
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")

            with gr.Column(scale=1):
                # Audio file upload
                audio_input = gr.Audio(
                    label="Upload Audio",
                    type="filepath",
                    sources=["upload", "microphone"]
                )
| gr.Markdown("*Upload an audio file (WAV or MP3, 30 seconds will be analyzed)*") | |
                # Set audio button
                set_audio_btn = gr.Button("Set Audio Track", variant="primary")

                # Status display
                status = gr.Markdown("*⚠️ Please load an audio file before chatting*")
        # Set audio handler
        def set_audio(filepath):
            """Store the audio filepath and enable the chat inputs."""
            if not filepath:
                return "", False, "*⚠️ Please upload an audio file*", gr.update(interactive=False), gr.update(interactive=False)
            try:
                # Report success and enable the chat input
                return filepath, True, "*✅ Audio loaded successfully! You can start chatting now.*", gr.update(interactive=True), gr.update(interactive=True)
            except Exception as e:
                return "", False, f"*❌ Error: {str(e)}*", gr.update(interactive=False), gr.update(interactive=False)

        set_audio_btn.click(
            set_audio,
            inputs=[audio_input],
            outputs=[audio_path_state, audio_loaded_state, status, msg, submit_btn]
        )
        # Chat response handler
        def chat_response(message, chat_history, audio_path, audio_loaded):
            """Handle a chat message and stream the response to the UI."""
            # This function is a generator, so early exits must yield:
            # a bare `return value` in a generator never reaches Gradio
            if not message or not message.strip():
                yield chat_history, "", gr.update()
                return
            if not audio_loaded:
                chat_history.append((message, "Please load an audio file before chatting."))
                yield chat_history, "", gr.update()
                return
            # Build the current user message, attaching the audio when available
            current_user_message = None
            if audio_path and audio_loaded:
                # Structured format with audio
                current_user_message = [
                    {"type": "audio", "audio_url": audio_path},
                    {"type": "text", "text": message}
                ]
            else:
                # Text only
                current_user_message = message

            # Add the user message to the chat history
            chat_history.append((current_user_message, None))

            # Create a display version of the chat history for the UI
            display_history = []
            for user_msg, bot_msg in chat_history:
                # For display purposes, reduce structured messages to their text part
                display_user_msg = user_msg
                if isinstance(user_msg, list):
                    for item in user_msg:
                        if item.get("type") == "text":
                            display_user_msg = item.get("text", "")
                            break
                display_history.append((display_user_msg, bot_msg))

            # Show a loading message while the model runs
            loading_display = display_history.copy()
            loading_display[-1] = (loading_display[-1][0], "⏳ Generating response, please wait...")
            yield loading_display, "", gr.update(value="*Processing audio analysis...*")

            try:
                logger.info(f"Sending to generate_response - message: {message}, chat_history: {chat_history[:-1]}")

                # Generate the response using the structured chat history
                response = generate_response(audio_path, message, chat_history[:-1])

                # Update both histories with the response
                chat_history[-1] = (current_user_message, response)
                display_history[-1] = (display_history[-1][0], response)
                yield display_history, "", gr.update(value="*✅ Audio loaded successfully! You can start chatting now.*")
            except Exception as e:
                # Surface the error in the chat
                logger.error(f"Error in chat_response: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                chat_history[-1] = (current_user_message, f"Error: {str(e)}")
                display_history[-1] = (display_history[-1][0], f"Error: {str(e)}")
                yield display_history, "", gr.update(value="*✅ Audio loaded successfully! You can start chatting now.*")
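        # Because chat_response yields, Gradio renders each yielded triple as a
        # UI update: first the "Generating response" placeholder, then the final
        # answer (or the error message) in its place.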
        # Connect the submit button
        submit_btn.click(
            chat_response,
            inputs=[msg, chatbot, audio_path_state, audio_loaded_state],
            outputs=[chatbot, msg, status],
            show_progress="full"  # Show a loading indicator during processing
        )

        # Connect the textbox submit (Enter key)
        msg.submit(
            chat_response,
            inputs=[msg, chatbot, audio_path_state, audio_loaded_state],
            outputs=[chatbot, msg, status],
            show_progress="full"  # Show a loading indicator during processing
        )

        # Clear button
        def clear_chat():
            """Clear the chat while keeping the audio loaded."""
            return [], "", "*✅ Chat cleared. Audio file remains loaded.*"

        clear_btn.click(
            clear_chat,
            outputs=[chatbot, msg, status]
        )

    return demo
# Launch the app
if __name__ == "__main__":
    logger.info(f"Starting Music Mixing Assistant with model {MODEL_ID}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
        logger.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

    demo = create_interface()
    demo.queue(max_size=5).launch(show_error=True)
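# For local testing outside Spaces you may want to bind explicitly, e.g.
# demo.queue(max_size=5).launch(server_name="0.0.0.0", show_error=True).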