File size: 15,925 Bytes
8d531a8
ba95477
 
2be10dd
8d531a8
 
7b77fd5
 
 
 
 
f533e2c
2be10dd
7b77fd5
8d531a8
 
ba95477
8d531a8
 
 
2be10dd
7b77fd5
ea3f0e6
2be10dd
7b77fd5
 
 
 
 
 
8d531a8
7b77fd5
 
 
 
 
 
 
 
 
 
 
edb4194
8d531a8
7b77fd5
 
 
 
 
 
 
 
 
567c1ca
 
7324297
5e824e3
c1a0ce1
 
 
567c1ca
 
 
 
 
 
 
 
 
5e824e3
 
c1a0ce1
 
53c24d3
ed55f0b
7b77fd5
 
8d531a8
7b77fd5
 
567c1ca
74ca2c9
 
 
edb4194
7b77fd5
 
 
 
ba95477
7b77fd5
 
 
 
74ca2c9
7b77fd5
c3bc466
 
7b77fd5
c3bc466
7b77fd5
 
 
 
74ca2c9
7b77fd5
 
f68734c
5d08ae3
7324297
74ca2c9
7b77fd5
 
 
74ca2c9
5d08ae3
74ca2c9
5d08ae3
 
7b77fd5
f68734c
 
5d08ae3
f68734c
5d08ae3
 
 
288729f
5d08ae3
 
f68734c
 
 
5d08ae3
 
 
d11be63
 
 
5d08ae3
d11be63
 
 
 
 
f68734c
5d08ae3
f68734c
5d08ae3
 
 
f68734c
5d08ae3
 
e3d7dfa
 
5d08ae3
 
 
 
 
 
 
 
 
 
 
 
e3d7dfa
5d08ae3
7b77fd5
74ca2c9
7b77fd5
5d08ae3
e3d7dfa
 
5d08ae3
7b77fd5
 
74ca2c9
7b77fd5
5d08ae3
7b77fd5
74ca2c9
5d08ae3
 
d11be63
 
 
5d08ae3
d11be63
88dce5c
 
 
74ca2c9
88dce5c
74ca2c9
5d08ae3
88dce5c
5d08ae3
88dce5c
3228503
88dce5c
 
 
 
 
5d08ae3
 
 
 
 
 
 
 
 
88dce5c
74ca2c9
5d08ae3
88dce5c
 
 
74ca2c9
88dce5c
74ca2c9
5d08ae3
74ca2c9
88dce5c
 
5d08ae3
 
88dce5c
 
7b77fd5
ba95477
7b77fd5
 
 
 
 
f68734c
 
7b77fd5
 
 
 
288729f
7b77fd5
 
 
 
 
74ca2c9
 
7b77fd5
74ca2c9
7b77fd5
 
 
61144f6
7b77fd5
74ca2c9
 
 
 
025287d
7b77fd5
c3bc466
7b77fd5
 
 
 
 
74ca2c9
7b77fd5
 
f68734c
 
 
 
61144f6
7b77fd5
f68734c
 
7b77fd5
f68734c
61144f6
7b77fd5
 
 
f68734c
7b77fd5
 
 
f68734c
7b77fd5
 
74ca2c9
 
 
 
 
33edd95
5d08ae3
 
 
 
 
 
 
 
 
 
 
74ca2c9
5d08ae3
 
e3d7dfa
5d08ae3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61144f6
7b77fd5
5d08ae3
 
e3d7dfa
5d08ae3
 
ba95477
e3d7dfa
5d08ae3
 
 
 
 
 
7b77fd5
e3d7dfa
 
5d08ae3
 
 
 
 
 
 
e3d7dfa
7b77fd5
 
 
f68734c
 
 
7b77fd5
 
 
 
 
f68734c
 
 
7b77fd5
 
 
74ca2c9
 
 
 
7b77fd5
74ca2c9
7b77fd5
 
61144f6
ba95477
61144f6
7b77fd5
61144f6
7b77fd5
 
 
 
 
 
61144f6
ba95477
7b77fd5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
import os
import sys
import logging
import gradio as gr
import torch
import numpy as np
import librosa
import requests
from io import BytesIO
from urllib.request import urlopen, Request
import gc
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration

# Configure logging: everything goes to stdout so container/platform logs
# (e.g. Hugging Face Spaces) capture it without extra handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)

# Hugging Face Hub model ID used for both the processor and the model.
MODEL_ID = "mclemcrew/Qwen-Audio-Mix-Instruct"

# Process-wide model/processor cache; populated lazily by load_model() so
# the app can start before the (large) weights are downloaded.
model = None
processor = None

def log_gpu_memory(message=""):
    """Log current CUDA memory usage (allocated/reserved, in GiB).

    No-op on CPU-only hosts, so callers may invoke it unconditionally.
    """
    if not torch.cuda.is_available():
        return
    gib = 1024**3
    allocated = torch.cuda.memory_allocated() / gib
    reserved = torch.cuda.memory_reserved() / gib
    logger.info(f"{message} - GPU: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")

def load_model():
    """Lazily load and cache the global model/processor pair.

    Returns the cached (model, processor) tuple on repeat calls. On CUDA
    hosts the model is loaded in FP16 with device_map="auto" and without
    8-bit quantization; otherwise it is loaded on CPU. Any loading failure
    is logged and re-raised.
    """
    global model, processor

    # Fast path: both pieces already initialized.
    if model is not None and processor is not None:
        return model, processor

    try:
        # Release any cached GPU memory before pulling in the weights.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # The processor is lightweight; load it ahead of the model so a
        # processor failure surfaces before the expensive download.
        logger.info(f"Loading processor from {MODEL_ID}")
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        logger.info("Processor loaded successfully")

        # Keep bitsandbytes out of the loading path entirely.
        os.environ["DISABLE_BITSANDBYTES_CUDA_SETUP"] = "TRUE"

        if torch.cuda.is_available():
            total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
            logger.info(f"GPU memory: {total_gb:.2f} GB")

            # FP16 on GPU, explicitly without 8-bit quantization.
            logger.info("Loading model with FP16 precision")
            model = Qwen2AudioForConditionalGeneration.from_pretrained(
                MODEL_ID,
                torch_dtype=torch.float16,
                device_map="auto",
                quantization_config=None,  # explicitly disable quantization
                low_cpu_mem_usage=True,
            )
        else:
            # No GPU available: plain CPU load with default dtype.
            logger.info("Loading model on CPU")
            model = Qwen2AudioForConditionalGeneration.from_pretrained(MODEL_ID)

        model.eval()
        log_gpu_memory("After model loading")
        return model, processor
    except Exception as e:
        logger.error(f"Error loading model or processor: {e}")
        raise

def process_audio_file(audio_path):
    """Load an audio file, resampled to the processor's rate, capped at 30 s.

    Lazily initializes the global processor if needed (its feature extractor
    defines the target sampling rate). Returns a 1-D float array of samples,
    or None if loading/resampling fails.
    """
    logger.info(f"Processing audio file: {audio_path}")

    global processor
    try:
        # The processor's feature extractor dictates the sampling rate the
        # model expects; make sure we have one before touching the audio.
        if processor is None:
            processor = AutoProcessor.from_pretrained(MODEL_ID)
        target_sr = processor.feature_extractor.sampling_rate

        # librosa resamples to target_sr on load.
        samples, _ = librosa.load(audio_path, sr=target_sr)

        # Cap at 30 seconds to keep generation latency reasonable.
        limit = 30 * target_sr
        if len(samples) > limit:
            logger.info("Truncating audio to 30 seconds")
            samples = samples[:limit]

        return samples
    except Exception as e:
        logger.error(f"Error processing audio file: {e}")
        return None

def generate_response(audio_path, message, chat_history=None):
    """Generate an assistant reply for *message*, optionally grounded in audio.

    Builds a Qwen-style conversation (system prompt + prior turns + current
    user turn), gathers the waveform for every structured audio entry, runs
    the model, and returns only the newly generated text.

    Args:
        audio_path: Path to the current audio file, or falsy for text-only.
        message: Plain string, or an already-structured content list such as
            [{"type": "audio", "audio_url": ...}, {"type": "text", ...}].
        chat_history: Optional list of (user_msg, bot_msg) tuples from prior
            turns; the transient "Generating..." placeholder is skipped.

    Returns:
        The decoded response string, or a user-facing error message on
        failure — this function never raises to its caller.
    """
    global model, processor

    try:
        if model is None or processor is None:
            model, processor = load_model()

        # Build the conversation following the Qwen example format.
        conversation = []

        # System prompt framing the assistant as an audio engineer.
        system_prompt = "You are an expert audio engineer assisting with music production and mixing. Provide clear, specific advice on audio engineering techniques, mixing adjustments, and production decisions based on the audio samples and the user's questions. Focus on practical, actionable guidance."
        conversation.append({"role": "system", "content": system_prompt})

        # Replay chat history — user entries may be structured lists
        # (audio + text) and are passed through unchanged.
        if chat_history:
            for user_msg, bot_msg in chat_history:
                if user_msg is not None:
                    conversation.append({"role": "user", "content": user_msg})

                # Skip the in-flight placeholder so it never reaches the model.
                if bot_msg and bot_msg != "⏳ Generating response, please wait...":
                    conversation.append({"role": "assistant", "content": bot_msg})

        # Append the current turn; attach the audio reference when a plain
        # string message arrives together with an audio path.
        if isinstance(message, str):
            if audio_path:
                conversation.append({
                    "role": "user",
                    "content": [
                        {"type": "audio", "audio_url": audio_path},
                        {"type": "text", "text": message}
                    ]
                })
            else:
                conversation.append({"role": "user", "content": message})
        else:
            # Message is already structured.
            conversation.append({"role": "user", "content": message})

        # Log the conversation for debugging
        logger.info(f"Conversation structure being sent to model: {conversation}")

        # Collect the waveform for every audio entry, in conversation order —
        # the processor pairs them positionally with the <audio> tokens the
        # chat template emits.
        audios = []
        for msg in conversation:
            if isinstance(msg["content"], list):
                for ele in msg["content"]:
                    if ele["type"] == "audio":
                        logger.info(f"Processing audio from path: {ele['audio_url']}")
                        audio_data = process_audio_file(ele['audio_url'])
                        if audio_data is not None:
                            audios.append(audio_data)
                            logger.info(f"Added audio: length={len(audio_data)}")
                        else:
                            logger.error(f"Failed to process audio from {ele['audio_url']}")

        logger.info(f"Number of audio inputs collected: {len(audios)}")

        # Render the conversation to the model's prompt format.
        text = processor.apply_chat_template(
            conversation, add_generation_prompt=True, tokenize=False
        )

        logger.info(f"Templated conversation (truncated): {text[:300]}...")

        # Tokenize text and extract audio features in one pass.
        inputs = processor(
            text=text,
            audios=audios if audios else None,
            return_tensors="pt",
            padding=True
        )

        logger.info(f"Input shape: {inputs.input_ids.shape}")

        # FIX: the Qwen2-Audio processor returns audio under "input_features"
        # (older releases used "audio_features"); the previous check looked
        # only for the old key, so the warning fired even when audio was
        # present. Check both, new key first.
        feature_key = next(
            (k for k in ("input_features", "audio_features") if k in inputs), None
        )
        if feature_key is not None:
            logger.info(f"Audio features shape: {inputs[feature_key].shape}")
        else:
            logger.warning("No audio features in inputs!")

        # Move every input tensor onto the model's device.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        log_gpu_memory("Before generation")

        # Sampled decoding, capped at 300 new tokens.
        with torch.no_grad():
            generate_ids = model.generate(
                **inputs,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=processor.tokenizer.pad_token_id
            )

        # Keep only the newly generated tokens (drop the echoed prompt).
        generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]

        response = processor.batch_decode(
            generate_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False
        )[0]

        # Release tensors promptly to keep GPU memory headroom between turns.
        del inputs, generate_ids
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        log_gpu_memory("After generation")

        return response

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        import traceback
        logger.error(traceback.format_exc())
        return f"I encountered an error while processing your request: {str(e)}"

# Create Gradio Interface
def create_interface():
    """Build and return the Gradio Blocks UI.

    Layout: a tuple-format Chatbot plus message box on the left, an audio
    uploader with a "Set Audio Track" button and a status line on the right.
    The message box and Send button stay disabled until an audio file has
    been set. Chat turns are handled by a generator so a loading placeholder
    can be shown while the model runs.

    NOTE(review): the Chatbot uses the legacy list-of-(user, bot)-tuples
    history format — confirm this matches the installed Gradio version.
    """
    with gr.Blocks(title="Music Mixing Assistant") as demo:
        gr.Markdown("# 🎧 Music Mixing Assistant")
        
        # Session state: the audio file path and whether one has been set.
        audio_path_state = gr.State("")
        audio_loaded_state = gr.State(False)
        
        with gr.Row():
            with gr.Column(scale=2):
                # Chat interface
                chatbot = gr.Chatbot(height=400, label="Conversation")
                
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask about your mix...",
                        show_label=False,
                        container=False,
                        interactive=False  # Disabled until audio is loaded
                    )
                    submit_btn = gr.Button("Send", variant="primary", interactive=False)  # Disabled until audio is loaded
                
                with gr.Row():
                    clear_btn = gr.Button("Clear Chat")
            
            with gr.Column(scale=1):
                # Audio file upload (file picker or microphone recording).
                audio_input = gr.Audio(
                    label="Upload Audio",
                    type="filepath",
                    sources=["upload", "microphone"]
                )
                gr.Markdown("*Upload an audio file (WAV or MP3, 30 seconds will be analyzed)*")
                
                # Set audio button
                set_audio_btn = gr.Button("Set Audio Track", variant="primary")
                
                # Status display
                status = gr.Markdown("*⚠️ Please load an audio file before chatting*")
        
        # Set audio handler
        def set_audio(filepath):
            """Store the audio filepath and enable/disable the chat inputs.

            Returns (path, loaded-flag, status markdown, msg update,
            submit-button update) matching the outputs wired below.
            """
            if not filepath:
                return "", False, "*⚠️ Please upload an audio file*", gr.update(interactive=False), gr.update(interactive=False)
            
            try:
                # Return success and enable chat input
                return filepath, True, "*✅ Audio loaded successfully! You can start chatting now.*", gr.update(interactive=True), gr.update(interactive=True)
            except Exception as e:
                return "", False, f"*❌ Error: {str(e)}*", gr.update(interactive=False), gr.update(interactive=False)
        
        set_audio_btn.click(
            set_audio,
            inputs=[audio_input],
            outputs=[audio_path_state, audio_loaded_state, status, msg, submit_btn]
        )
        
        # Chat response handler
        def chat_response(message, chat_history, audio_path, audio_loaded):
            """Handle one chat turn as a generator.

            Yields twice: first the history with a loading placeholder,
            then the history with the model's response (or an error).
            `chat_history` keeps structured user entries (audio + text)
            for the model, while a parallel display history shows only
            the text part in the UI.
            """
            if not message or not message.strip():
                return chat_history, "", gr.update()
            
            if not audio_loaded:
                chat_history.append((message, "Please load an audio file before chatting."))
                return chat_history, "", gr.update()
            
            # Store current user message format for use in the next turn
            current_user_message = None
            if audio_path and audio_loaded:
                # Format with audio
                current_user_message = [
                    {"type": "audio", "audio_url": audio_path},
                    {"type": "text", "text": message}
                ]
            else:
                # Text only
                current_user_message = message
            
            # Add user message to chat history
            chat_history.append((current_user_message, None))
            
            # Create a display version of chat history for UI
            display_history = []
            for user_msg, bot_msg in chat_history:
                # For display purposes, convert structured messages to text
                display_user_msg = user_msg
                if isinstance(user_msg, list):
                    # Extract just the text part for display
                    for item in user_msg:
                        if item.get("type") == "text":
                            display_user_msg = item.get("text", "")
                            break
                
                display_history.append((display_user_msg, bot_msg))
            
            # Add loading message to display history
            loading_display = display_history.copy()
            loading_display[-1] = (loading_display[-1][0], "⏳ Generating response, please wait...")
            
            # First yield: show the placeholder while the model runs.
            yield loading_display, "", gr.update(value="*Processing audio analysis...*")
            
            try:
                # Log what we're sending to generate_response
                logger.info(f"Sending to generate_response - message: {message}, chat_history: {chat_history[:-1]}")
                
                # Pass history minus the just-appended (unanswered) turn.
                response = generate_response(audio_path, message, chat_history[:-1])
                
                # Update the real chat history with the response
                chat_history[-1] = (current_user_message, response)
                
                # Update display history
                display_history[-1] = (display_history[-1][0], response)
                
                yield display_history, "", gr.update(value="*✅ Audio loaded successfully! You can start chatting now.*")
            except Exception as e:
                # Update with error
                logger.error(f"Error in chat_response: {str(e)}")
                import traceback
                logger.error(traceback.format_exc())
                
                chat_history[-1] = (current_user_message, f"Error: {str(e)}")
                display_history[-1] = (display_history[-1][0], f"Error: {str(e)}")
                
                yield display_history, "", gr.update(value="*✅ Audio loaded successfully! You can start chatting now.*")

        # Connect submit button
        submit_btn.click(
            chat_response,
            inputs=[msg, chatbot, audio_path_state, audio_loaded_state],
            outputs=[chatbot, msg, status],
            show_progress="full"  # Show loading indicator during processing
        )
        
        # Connect message box submit (Enter key) to the same handler.
        msg.submit(
            chat_response,
            inputs=[msg, chatbot, audio_path_state, audio_loaded_state],
            outputs=[chatbot, msg, status],
            show_progress="full"  # Show loading indicator during processing
        )
        
        # Clear button
        def clear_chat():
            """Clear the chat and keep audio loaded state"""
            return [], "", "*✅ Chat cleared. Audio file remains loaded.*"
        
        clear_btn.click(
            clear_chat,
            outputs=[chatbot, msg, status]
        )
    
    return demo

# Script entry point: log environment details, then serve the UI.
if __name__ == "__main__":
    logger.info(f"Starting Music Mixing Assistant with model {MODEL_ID}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")

    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0)
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
        logger.info(f"GPU: {device_name}")
        logger.info(f"GPU memory: {total_gb:.2f} GB")

    app = create_interface()
    app.queue(max_size=5).launch(show_error=True)