Spaces:
Sleeping
Sleeping
updates to gradio app
Browse files
app.py
CHANGED
|
@@ -1,64 +1,618 @@
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
| 8 |
|
|
|
|
|
|
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
system_message,
|
| 14 |
-
max_tokens,
|
| 15 |
-
temperature,
|
| 16 |
-
top_p,
|
| 17 |
-
):
|
| 18 |
-
messages = [{"role": "system", "content": system_message}]
|
| 19 |
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
| 25 |
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
)
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
"""
|
| 44 |
-
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
| 45 |
-
"""
|
| 46 |
-
demo = gr.ChatInterface(
|
| 47 |
-
respond,
|
| 48 |
-
additional_inputs=[
|
| 49 |
-
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
|
| 50 |
-
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
| 51 |
-
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
|
| 52 |
-
gr.Slider(
|
| 53 |
-
minimum=0.1,
|
| 54 |
-
maximum=1.0,
|
| 55 |
-
value=0.95,
|
| 56 |
-
step=0.05,
|
| 57 |
-
label="Top-p (nucleus sampling)",
|
| 58 |
-
),
|
| 59 |
-
],
|
| 60 |
-
)
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
|
|
|
| 63 |
if __name__ == "__main__":
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
import gradio as gr
|
| 3 |
+
import spaces
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration, BitsAndBytesConfig
|
| 6 |
+
import numpy as np
|
| 7 |
+
import librosa
|
| 8 |
+
from urllib.request import urlopen
|
| 9 |
+
from io import BytesIO
|
| 10 |
+
import logging
|
| 11 |
+
import sys
|
| 12 |
+
import gc
|
| 13 |
|
| 14 |
+
# Configure logging
|
| 15 |
+
logging.basicConfig(
|
| 16 |
+
level=logging.INFO,
|
| 17 |
+
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
| 18 |
+
handlers=[logging.StreamHandler(sys.stdout)]
|
| 19 |
+
)
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
+
# Update to use the merged model
|
| 23 |
+
MODEL_ID = "mclemcrew/Qwen-Audio-Mix-Instruct"
|
| 24 |
|
| 25 |
+
# Cache for model and processor
|
| 26 |
+
model_cache = None
|
| 27 |
+
processor_cache = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
|
| 29 |
+
# Memory tracking
|
| 30 |
+
def log_gpu_memory(message=""):
|
| 31 |
+
if torch.cuda.is_available():
|
| 32 |
+
allocated = torch.cuda.memory_allocated() / 1024**3
|
| 33 |
+
reserved = torch.cuda.memory_reserved() / 1024**3
|
| 34 |
+
logger.info(f"{message} - GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
|
| 35 |
|
| 36 |
+
def load_model():
|
| 37 |
+
"""Load the fine-tuned model with optimized memory usage"""
|
| 38 |
+
global model_cache, processor_cache
|
| 39 |
+
|
| 40 |
+
# Return cached model if available
|
| 41 |
+
if model_cache is not None and processor_cache is not None:
|
| 42 |
+
logger.info("Using cached model and processor")
|
| 43 |
+
return model_cache, processor_cache
|
| 44 |
+
|
| 45 |
+
# Log initial GPU state
|
| 46 |
+
log_gpu_memory("Before model loading")
|
| 47 |
+
|
| 48 |
+
# Load processor
|
| 49 |
+
logger.info(f"Loading processor from {MODEL_ID}")
|
| 50 |
+
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
| 51 |
+
|
| 52 |
+
# Clean up memory
|
| 53 |
+
gc.collect()
|
| 54 |
+
if torch.cuda.is_available():
|
| 55 |
+
torch.cuda.empty_cache()
|
| 56 |
+
|
| 57 |
+
# Define proper quantization config - using 4-bit quantization
|
| 58 |
+
quant_config = BitsAndBytesConfig(
|
| 59 |
+
load_in_4bit=True,
|
| 60 |
+
bnb_4bit_compute_dtype=torch.float16, # Match training dtype
|
| 61 |
+
bnb_4bit_use_double_quant=True,
|
| 62 |
+
bnb_4bit_quant_type="nf4"
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
logger.info("Loading model with optimized 4-bit quantization")
|
| 67 |
+
|
| 68 |
+
# Load with quantization and offloading for memory efficiency
|
| 69 |
+
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
| 70 |
+
MODEL_ID,
|
| 71 |
+
quantization_config=quant_config,
|
| 72 |
+
device_map="auto",
|
| 73 |
+
torch_dtype=torch.float16,
|
| 74 |
+
offload_folder="offload",
|
| 75 |
+
offload_state_dict=True,
|
| 76 |
+
low_cpu_mem_usage=True
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
log_gpu_memory("After optimized model loading")
|
| 80 |
+
logger.info("Model loaded successfully with optimized approach")
|
| 81 |
+
except Exception as e:
|
| 82 |
+
logger.error(f"Error loading model with optimized approach: {e}")
|
| 83 |
+
gc.collect()
|
| 84 |
+
if torch.cuda.is_available():
|
| 85 |
+
torch.cuda.empty_cache()
|
| 86 |
+
|
| 87 |
+
try:
|
| 88 |
+
# Fallback to 8-bit quantization (more stable but less compression)
|
| 89 |
+
logger.info("Attempting 8-bit quantization fallback")
|
| 90 |
+
|
| 91 |
+
from transformers import BitsAndBytesConfig
|
| 92 |
+
quant_config_8bit = BitsAndBytesConfig(
|
| 93 |
+
load_in_8bit=True,
|
| 94 |
+
llm_int8_threshold=6.0
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
model = Qwen2AudioForConditionalGeneration.from_pretrained(
|
| 98 |
+
MODEL_ID,
|
| 99 |
+
quantization_config=quant_config_8bit,
|
| 100 |
+
device_map="auto",
|
| 101 |
+
torch_dtype=torch.float16
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
log_gpu_memory("After 8-bit fallback loading")
|
| 105 |
+
logger.info("Model loaded successfully with 8-bit quantization")
|
| 106 |
+
except Exception as e2:
|
| 107 |
+
logger.error(f"Error loading with 8-bit quantization: {e2}")
|
| 108 |
+
gc.collect()
|
| 109 |
+
if torch.cuda.is_available():
|
| 110 |
+
torch.cuda.empty_cache()
|
| 111 |
+
|
| 112 |
+
try:
|
| 113 |
+
# Fallback to dummy model as last resort
|
| 114 |
+
logger.warning("Loading dummy placeholder model")
|
| 115 |
+
class DummyAudioModel:
|
| 116 |
+
def __init__(self):
|
| 117 |
+
self.device = torch.device("cpu")
|
| 118 |
+
self.dummy_parameters = [torch.tensor([0.0])]
|
| 119 |
+
|
| 120 |
+
def generate(self, **kwargs):
|
| 121 |
+
input_ids = kwargs.get("input_ids", None)
|
| 122 |
+
if input_ids is not None:
|
| 123 |
+
batch_size, seq_len = input_ids.shape
|
| 124 |
+
dummy_output = torch.ones((batch_size, seq_len + 20), dtype=torch.long)
|
| 125 |
+
dummy_output[:, :seq_len] = input_ids
|
| 126 |
+
dummy_output[:, seq_len:] = 100
|
| 127 |
+
return dummy_output
|
| 128 |
+
else:
|
| 129 |
+
return torch.ones((1, 30), dtype=torch.long)
|
| 130 |
+
|
| 131 |
+
def parameters(self):
|
| 132 |
+
return iter(self.dummy_parameters)
|
| 133 |
+
|
| 134 |
+
def to(self, device):
|
| 135 |
+
self.device = device
|
| 136 |
+
return self
|
| 137 |
+
|
| 138 |
+
model = DummyAudioModel()
|
| 139 |
+
logger.warning("Created dummy model placeholder - no real functionality available")
|
| 140 |
+
except Exception as e3:
|
| 141 |
+
logger.error(f"Failed to create dummy model: {e3}")
|
| 142 |
+
raise RuntimeError(f"Could not load any model version after multiple attempts")
|
| 143 |
+
|
| 144 |
+
# Cache the model and processor
|
| 145 |
+
model_cache = model
|
| 146 |
+
processor_cache = processor
|
| 147 |
+
|
| 148 |
+
return model, processor
|
| 149 |
|
| 150 |
+
def process_audio_from_url(audio_url, processor):
|
| 151 |
+
"""Process audio file from URL for model input with optimized memory usage"""
|
| 152 |
+
try:
|
| 153 |
+
logger.info(f"Processing audio from URL: {audio_url}")
|
| 154 |
+
# Get processor's sampling rate
|
| 155 |
+
target_sr = int(processor.feature_extractor.sampling_rate)
|
| 156 |
+
logger.info(f"Target sampling rate: {target_sr}")
|
| 157 |
+
|
| 158 |
+
# Audio bytes container
|
| 159 |
+
audio_bytes = None
|
| 160 |
+
|
| 161 |
+
# Handle various URL formats
|
| 162 |
+
if audio_url.startswith(('http://', 'https://')):
|
| 163 |
+
# For web URLs
|
| 164 |
+
try:
|
| 165 |
+
import requests
|
| 166 |
+
response = requests.get(audio_url)
|
| 167 |
+
response.raise_for_status()
|
| 168 |
+
audio_bytes = BytesIO(response.content)
|
| 169 |
+
# Free memory
|
| 170 |
+
del response
|
| 171 |
+
except Exception as req_error:
|
| 172 |
+
logger.info(f"Requests failed, falling back to urlopen: {req_error}")
|
| 173 |
+
audio_bytes = BytesIO(urlopen(audio_url).read())
|
| 174 |
+
elif audio_url.startswith('file://'):
|
| 175 |
+
# For local file URLs
|
| 176 |
+
file_path = audio_url[7:] # Remove 'file://' prefix
|
| 177 |
+
with open(file_path, 'rb') as f:
|
| 178 |
+
audio_bytes = BytesIO(f.read())
|
| 179 |
+
else:
|
| 180 |
+
# Try as a local file path
|
| 181 |
+
with open(audio_url, 'rb') as f:
|
| 182 |
+
audio_bytes = BytesIO(f.read())
|
| 183 |
+
|
| 184 |
+
# Load and resample audio
|
| 185 |
+
audio_data, sr_loaded = librosa.load(audio_bytes, sr=None)
|
| 186 |
+
logger.info(f"Audio loaded with shape: {audio_data.shape}, original SR: {sr_loaded}")
|
| 187 |
+
|
| 188 |
+
# Free memory
|
| 189 |
+
del audio_bytes
|
| 190 |
+
gc.collect()
|
| 191 |
+
|
| 192 |
+
# Resample if needed
|
| 193 |
+
if sr_loaded != target_sr:
|
| 194 |
+
logger.info(f"Resampling from {sr_loaded} Hz to {target_sr} Hz")
|
| 195 |
+
audio_data = librosa.resample(audio_data, orig_sr=sr_loaded, target_sr=target_sr)
|
| 196 |
+
|
| 197 |
+
# Reduce to 15 seconds maximum (was 30 seconds before)
|
| 198 |
+
max_seconds = 15
|
| 199 |
+
max_samples = max_seconds * target_sr
|
| 200 |
+
if len(audio_data) > max_samples:
|
| 201 |
+
logger.info(f"Limiting audio to {max_seconds} seconds for memory efficiency")
|
| 202 |
+
audio_data = audio_data[:max_samples]
|
| 203 |
+
|
| 204 |
+
# Ensure audio is float32
|
| 205 |
+
audio_data = audio_data.astype(np.float32)
|
| 206 |
+
|
| 207 |
+
return audio_data
|
| 208 |
+
except Exception as e:
|
| 209 |
+
logger.error(f"Error processing audio from URL {audio_url}: {e}", exc_info=True)
|
| 210 |
+
return None
|
| 211 |
+
finally:
|
| 212 |
+
# Clean up any lingering memory
|
| 213 |
+
if 'audio_bytes' in locals() and audio_bytes is not None:
|
| 214 |
+
del audio_bytes
|
| 215 |
+
gc.collect()
|
| 216 |
|
| 217 |
+
@spaces.GPU(duration=120)
|
| 218 |
+
def chat_with_model(audio_url, message, chat_history):
|
| 219 |
+
"""Generate response from the model using an audio URL"""
|
| 220 |
+
logger.info(f"Starting chat_with_model with audio_url: {audio_url}, message: {message}")
|
| 221 |
+
|
| 222 |
+
# Log initial memory state
|
| 223 |
+
log_gpu_memory("At start of chat_with_model")
|
| 224 |
+
|
| 225 |
+
# Validate that audio URL is provided
|
| 226 |
+
if not audio_url or not audio_url.strip():
|
| 227 |
+
return "⚠️ Please set an audio track URL first before chatting."
|
| 228 |
+
|
| 229 |
+
try:
|
| 230 |
+
# Load model and processor on demand
|
| 231 |
+
model, processor = load_model()
|
| 232 |
+
|
| 233 |
+
# Log memory after model load
|
| 234 |
+
log_gpu_memory("After model load")
|
| 235 |
+
|
| 236 |
+
# Check if we're using a dummy model
|
| 237 |
+
is_dummy = hasattr(model, '__class__') and model.__class__.__name__ == 'DummyAudioModel'
|
| 238 |
+
|
| 239 |
+
if is_dummy:
|
| 240 |
+
logger.warning("Using dummy model - providing generic response")
|
| 241 |
+
return (
|
| 242 |
+
"⚠️ I'm currently having trouble analyzing your audio due to technical limitations "
|
| 243 |
+
"in this environment. The model requires more GPU memory than is available. "
|
| 244 |
+
"Please try a different audio file or contact the developer for assistance."
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
# Process audio
|
| 248 |
+
audios = []
|
| 249 |
+
audio_data = process_audio_from_url(audio_url, processor)
|
| 250 |
+
if audio_data is not None:
|
| 251 |
+
audios.append(audio_data)
|
| 252 |
+
else:
|
| 253 |
+
return "⚠️ Failed to process audio from the provided URL. Please check that the URL is valid and accessible."
|
| 254 |
+
|
| 255 |
+
# Log memory after audio processing
|
| 256 |
+
log_gpu_memory("After audio processing")
|
| 257 |
+
|
| 258 |
+
# System prompt for the model
|
| 259 |
+
SYSTEM_PROMPT = "You are an expert audio engineer assisting with music production and mixing. Provide clear, specific advice on audio engineering techniques, mixing adjustments, and production decisions based on the audio samples and the user's questions. Focus on practical, actionable guidance. Be as specific as possible when answering the user's questions about the mix."
|
| 260 |
+
|
| 261 |
+
# Start with system prompt
|
| 262 |
+
conversation = [
|
| 263 |
+
{"role": "system", "content": SYSTEM_PROMPT}
|
| 264 |
+
]
|
| 265 |
+
|
| 266 |
+
# Add chat history - limited to last 5 exchanges to save memory
|
| 267 |
+
history_limit = min(len(chat_history), 5)
|
| 268 |
+
for user_msg, bot_msg in chat_history[-history_limit:]:
|
| 269 |
+
if user_msg:
|
| 270 |
+
conversation.append({"role": "user", "content": user_msg})
|
| 271 |
+
if bot_msg:
|
| 272 |
+
conversation.append({"role": "assistant", "content": bot_msg})
|
| 273 |
+
|
| 274 |
+
# Determine if this is the first message with a new audio
|
| 275 |
+
is_first_message_with_audio = len(chat_history) == 0
|
| 276 |
+
|
| 277 |
+
# Format user message based on whether it's the first message with audio
|
| 278 |
+
if is_first_message_with_audio:
|
| 279 |
+
# First message includes audio
|
| 280 |
+
logger.info("First message with audio, including audio in content")
|
| 281 |
+
conversation.append({
|
| 282 |
+
"role": "user",
|
| 283 |
+
"content": [
|
| 284 |
+
{"type": "audio", "audio_url": audio_url},
|
| 285 |
+
{"type": "text", "text": message}
|
| 286 |
+
]
|
| 287 |
+
})
|
| 288 |
+
else:
|
| 289 |
+
# Follow-up message about the same audio
|
| 290 |
+
logger.info("Follow-up message, including only text in content")
|
| 291 |
+
conversation.append({
|
| 292 |
+
"role": "user",
|
| 293 |
+
"content": message
|
| 294 |
+
})
|
| 295 |
+
|
| 296 |
+
# Apply chat template with error handling
|
| 297 |
+
try:
|
| 298 |
+
text = processor.apply_chat_template(
|
| 299 |
+
conversation,
|
| 300 |
+
add_generation_prompt=True,
|
| 301 |
+
tokenize=False
|
| 302 |
+
)
|
| 303 |
+
logger.info(f"Chat template applied successfully")
|
| 304 |
+
except Exception as e:
|
| 305 |
+
logger.error(f"Error applying chat template: {e}")
|
| 306 |
+
# Use a simplified approach if template fails
|
| 307 |
+
text = f"{SYSTEM_PROMPT}\n\nUser: {message}\n\nAssistant:"
|
| 308 |
+
|
| 309 |
+
# Generate model inputs
|
| 310 |
+
try:
|
| 311 |
+
inputs = processor(
|
| 312 |
+
text=text,
|
| 313 |
+
audios=audios,
|
| 314 |
+
return_tensors="pt",
|
| 315 |
+
padding=True,
|
| 316 |
+
truncation=True
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
# Move inputs to the appropriate device
|
| 320 |
+
if hasattr(model, 'device'):
|
| 321 |
+
device = model.device
|
| 322 |
+
else:
|
| 323 |
+
device = next(model.parameters()).device
|
| 324 |
+
logger.info(f"Using device: {device}")
|
| 325 |
+
inputs = {k: v.to(device) for k, v in inputs.items()}
|
| 326 |
+
|
| 327 |
+
logger.info(f"Model inputs generated")
|
| 328 |
+
log_gpu_memory("After input preparation")
|
| 329 |
+
except Exception as e:
|
| 330 |
+
logger.error(f"Error generating model inputs: {e}")
|
| 331 |
+
return f"⚠️ Error generating model inputs: {str(e)}"
|
| 332 |
+
|
| 333 |
+
# Generate response from model
|
| 334 |
+
with torch.no_grad():
|
| 335 |
+
try:
|
| 336 |
+
generate_ids = model.generate(
|
| 337 |
+
**inputs,
|
| 338 |
+
max_new_tokens=128, # Reduced from 256
|
| 339 |
+
temperature=0.7,
|
| 340 |
+
do_sample=True,
|
| 341 |
+
top_p=0.9,
|
| 342 |
+
use_cache=True # Ensure KV cache is used
|
| 343 |
+
)
|
| 344 |
+
logger.info(f"Response generated successfully")
|
| 345 |
+
log_gpu_memory("After generation")
|
| 346 |
+
except Exception as e:
|
| 347 |
+
logger.error(f"Error during model.generate: {e}")
|
| 348 |
+
return f"⚠️ Model generation error: {str(e)}"
|
| 349 |
+
|
| 350 |
+
# Decode the response
|
| 351 |
+
try:
|
| 352 |
+
generate_ids = generate_ids[:, inputs["input_ids"].size(1):]
|
| 353 |
+
response = processor.batch_decode(
|
| 354 |
+
generate_ids,
|
| 355 |
+
skip_special_tokens=True,
|
| 356 |
+
clean_up_tokenization_spaces=False
|
| 357 |
+
)[0]
|
| 358 |
+
logger.info(f"Response decoded successfully, length: {len(response)}")
|
| 359 |
+
|
| 360 |
+
# Quick validation of response
|
| 361 |
+
if not response or response.isspace():
|
| 362 |
+
logger.error("Empty response received from model")
|
| 363 |
+
return "⚠️ Model returned an empty response. Please try again."
|
| 364 |
+
|
| 365 |
+
# Clean up memory
|
| 366 |
+
del inputs, generate_ids
|
| 367 |
+
gc.collect()
|
| 368 |
+
if torch.cuda.is_available():
|
| 369 |
+
torch.cuda.empty_cache()
|
| 370 |
+
|
| 371 |
+
return response
|
| 372 |
+
except Exception as e:
|
| 373 |
+
logger.error(f"Error decoding response: {e}")
|
| 374 |
+
return f"⚠️ Error decoding response: {str(e)}"
|
| 375 |
+
|
| 376 |
+
except Exception as e:
|
| 377 |
+
logger.error(f"Unexpected error in chat_with_model: {e}", exc_info=True)
|
| 378 |
+
return f"⚠️ An error occurred: {str(e)}"
|
| 379 |
+
finally:
|
| 380 |
+
# Final memory cleanup
|
| 381 |
+
gc.collect()
|
| 382 |
+
if torch.cuda.is_available():
|
| 383 |
+
torch.cuda.empty_cache()
|
| 384 |
+
log_gpu_memory("End of chat_with_model")
|
| 385 |
|
| 386 |
+
# Function to check if URL is a valid audio file
|
| 387 |
+
def is_valid_audio_url(url):
|
| 388 |
+
if not url or not url.strip():
|
| 389 |
+
return False
|
| 390 |
+
|
| 391 |
+
url = url.strip().lower()
|
| 392 |
+
return url.endswith(('.wav', '.mp3', '.ogg', '.flac', '.m4a'))
|
| 393 |
|
| 394 |
+
# Custom theme with orange primary color and dark background
|
| 395 |
+
orange_black_theme = gr.themes.Base(
|
| 396 |
+
primary_hue="orange",
|
| 397 |
+
secondary_hue="gray",
|
| 398 |
+
neutral_hue="gray",
|
| 399 |
+
font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
|
| 400 |
+
)
|
| 401 |
|
| 402 |
+
# Custom CSS for darker theme with orange accents
|
| 403 |
+
custom_css = """
|
| 404 |
+
:root {
|
| 405 |
+
--orange-primary: #ff7700;
|
| 406 |
+
--dark-bg: #1a1a1a;
|
| 407 |
+
--darker-bg: #121212;
|
| 408 |
+
--lightest-gray: #e0e0e0;
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
body {
|
| 412 |
+
background-color: var(--darker-bg) !important;
|
| 413 |
+
color: var(--lightest-gray) !important;
|
| 414 |
+
font-family: 'Poppins', sans-serif !important;
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
.gradio-container {
|
| 418 |
+
background-color: var(--darker-bg) !important;
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
button.primary {
|
| 422 |
+
background-color: var(--orange-primary) !important;
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
.message.bot {
|
| 426 |
+
background-color: var(--dark-bg) !important;
|
| 427 |
+
}
|
| 428 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
+
# Gradio interface
|
| 431 |
+
with gr.Blocks(theme=orange_black_theme, css=custom_css) as demo:
|
| 432 |
+
gr.Markdown(
|
| 433 |
+
"""
|
| 434 |
+
# 🎧 Music Mixing Assistant
|
| 435 |
+
Enter an audio URL (.wav format recommended) and chat with your co-creative mixing agent!
|
| 436 |
+
|
| 437 |
+
Set your audio track once, then have an extended conversation about mixing and improving that specific track.
|
| 438 |
+
*(Note: Audio samples are limited to 15 seconds for optimal performance)*
|
| 439 |
+
"""
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
# Create states for chat history and audio URL
|
| 443 |
+
audio_url_state = gr.State("")
|
| 444 |
+
|
| 445 |
+
with gr.Row():
|
| 446 |
+
with gr.Column(scale=3):
|
| 447 |
+
# Chat interface with customized settings
|
| 448 |
+
chatbot = gr.Chatbot(
|
| 449 |
+
height=500,
|
| 450 |
+
avatar_images=(None, "🎧"), # Removed user icon
|
| 451 |
+
show_label=False,
|
| 452 |
+
container=True,
|
| 453 |
+
bubble_full_width=False,
|
| 454 |
+
show_copy_button=False, # Removed copy button
|
| 455 |
+
show_share_button=False, # Removed share button
|
| 456 |
+
render_markdown=True
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
# Input area
|
| 460 |
+
with gr.Row():
|
| 461 |
+
message = gr.Textbox(
|
| 462 |
+
placeholder="Ask about your mix...",
|
| 463 |
+
show_label=False,
|
| 464 |
+
container=False,
|
| 465 |
+
scale=10
|
| 466 |
+
)
|
| 467 |
+
submit_btn = gr.Button("Send", variant="primary", scale=1)
|
| 468 |
+
|
| 469 |
+
# Control buttons
|
| 470 |
+
with gr.Row():
|
| 471 |
+
clear_btn = gr.Button("Clear Chat", variant="secondary")
|
| 472 |
+
|
| 473 |
+
with gr.Column(scale=1):
|
| 474 |
+
# Audio URL input
|
| 475 |
+
audio_input = gr.Textbox(
|
| 476 |
+
label="Audio URL (.wav format)",
|
| 477 |
+
placeholder="https://example.com/your-audio-file.wav",
|
| 478 |
+
info="Enter URL to a WAV audio file - first 15 seconds will be analyzed"
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
# Add a button to set the URL
|
| 482 |
+
set_url_btn = gr.Button("Set Audio Track", variant="primary")
|
| 483 |
+
|
| 484 |
+
# Preview player (optional)
|
| 485 |
+
audio_preview = gr.Audio(
|
| 486 |
+
label="Audio Preview (if available)",
|
| 487 |
+
interactive=False,
|
| 488 |
+
visible=True
|
| 489 |
+
)
|
| 490 |
+
|
| 491 |
+
# Memory usage indicator
|
| 492 |
+
if torch.cuda.is_available():
|
| 493 |
+
memory_status = gr.Markdown("*GPU Memory: Initializing...*")
|
| 494 |
+
def update_memory_status():
|
| 495 |
+
if torch.cuda.is_available():
|
| 496 |
+
allocated = torch.cuda.memory_allocated() / 1024**3
|
| 497 |
+
reserved = torch.cuda.memory_reserved() / 1024**3
|
| 498 |
+
return f"*GPU Memory: {allocated:.2f}GB allocated / {reserved:.2f}GB reserved*"
|
| 499 |
+
return "*GPU Memory: Not available*"
|
| 500 |
+
else:
|
| 501 |
+
memory_status = gr.Markdown("*GPU Memory: Not available*")
|
| 502 |
+
def update_memory_status():
|
| 503 |
+
return "*GPU Memory: Not available*"
|
| 504 |
+
|
| 505 |
+
# Display status
|
| 506 |
+
status = gr.Markdown("*Status: Ready to assist with your mix!*")
|
| 507 |
+
|
| 508 |
+
# Function to update the audio URL state and preview
|
| 509 |
+
def update_audio_url(url):
|
| 510 |
+
# Basic validation
|
| 511 |
+
if not is_valid_audio_url(url):
|
| 512 |
+
return "", gr.update(value=None), "*Status: Invalid audio URL. Please use .wav, .mp3, .ogg, .flac, or .m4a format*", update_memory_status()
|
| 513 |
+
|
| 514 |
+
# Try to provide a preview if possible
|
| 515 |
+
try:
|
| 516 |
+
return url, gr.update(value=url), "*Status: Audio track set! First 15 seconds will be analyzed.*", update_memory_status()
|
| 517 |
+
except Exception as e:
|
| 518 |
+
# If preview fails, still set the URL but show warning
|
| 519 |
+
return url, gr.update(value=None), f"*Status: Audio track set, but preview failed: {str(e)}*", update_memory_status()
|
| 520 |
+
|
| 521 |
+
# Function to clear chat
|
| 522 |
+
def clear_chat():
|
| 523 |
+
return []
|
| 524 |
+
|
| 525 |
+
# Set URL button logic - Combined update and clear in one function
|
| 526 |
+
def update_and_clear_chat(url):
|
| 527 |
+
# First update the URL
|
| 528 |
+
result = update_audio_url(url)
|
| 529 |
+
# Then return the values including an empty chat
|
| 530 |
+
return result[0], result[1], [], result[2], result[3]
|
| 531 |
+
|
| 532 |
+
# Set URL button
|
| 533 |
+
set_url_btn.click(
|
| 534 |
+
update_and_clear_chat,
|
| 535 |
+
inputs=[audio_input],
|
| 536 |
+
outputs=[audio_url_state, audio_preview, chatbot, status, memory_status]
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
# Handle submit button
|
| 540 |
+
def respond(audio_url, message, chat_history):
|
| 541 |
+
if not message.strip():
|
| 542 |
+
return chat_history, "*Status: Please enter a message*", update_memory_status()
|
| 543 |
+
|
| 544 |
+
# Check if audio URL is set
|
| 545 |
+
if not audio_url or not audio_url.strip():
|
| 546 |
+
error_msg = "No audio track set. Please set an audio URL first."
|
| 547 |
+
chat_history.append((message, f"⚠️ {error_msg}"))
|
| 548 |
+
return chat_history, f"*Status: {error_msg}*", update_memory_status()
|
| 549 |
+
|
| 550 |
+
# Update chat history with user message immediately
|
| 551 |
+
chat_history.append((message, None))
|
| 552 |
+
yield chat_history, "🎵 *Analyzing your mix...*", update_memory_status()
|
| 553 |
+
|
| 554 |
+
try:
|
| 555 |
+
# Process and get response
|
| 556 |
+
bot_message = chat_with_model(audio_url, message, chat_history[:-1])
|
| 557 |
+
|
| 558 |
+
# Update the last message with the bot's response
|
| 559 |
+
chat_history[-1] = (message, bot_message)
|
| 560 |
+
|
| 561 |
+
# Return updated chat history
|
| 562 |
+
yield chat_history, "*Status: Ready to assist with your mix!*", update_memory_status()
|
| 563 |
+
except Exception as e:
|
| 564 |
+
error_msg = f"Error generating response: {str(e)}"
|
| 565 |
+
chat_history[-1] = (message, f"⚠️ {error_msg}")
|
| 566 |
+
yield chat_history, f"*Status: {error_msg}*", update_memory_status()
|
| 567 |
+
|
| 568 |
+
# Handle submit with clear input
|
| 569 |
+
def respond_and_clear_input(audio_url, message, chat_history):
|
| 570 |
+
# First get response updates
|
| 571 |
+
for result in respond(audio_url, message, chat_history):
|
| 572 |
+
# Yield each result with empty message input
|
| 573 |
+
yield result[0], result[1], result[2], ""
|
| 574 |
+
|
| 575 |
+
# Connect UI components
|
| 576 |
+
submit_btn.click(
|
| 577 |
+
respond_and_clear_input,
|
| 578 |
+
inputs=[audio_url_state, message, chatbot],
|
| 579 |
+
outputs=[chatbot, status, memory_status, message],
|
| 580 |
+
queue=True
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
message.submit(
|
| 584 |
+
respond_and_clear_input,
|
| 585 |
+
inputs=[audio_url_state, message, chatbot],
|
| 586 |
+
outputs=[chatbot, status, memory_status, message],
|
| 587 |
+
queue=True
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
# Clear button functionality to reset everything
|
| 591 |
+
def clear_all():
|
| 592 |
+
gc.collect()
|
| 593 |
+
if torch.cuda.is_available():
|
| 594 |
+
torch.cuda.empty_cache()
|
| 595 |
+
return [], "", None, "*Status: Chat cleared!*", update_memory_status(), ""
|
| 596 |
+
|
| 597 |
+
clear_btn.click(
|
| 598 |
+
clear_all,
|
| 599 |
+
None,
|
| 600 |
+
[chatbot, audio_input, audio_preview, status, memory_status, audio_url_state],
|
| 601 |
+
queue=False
|
| 602 |
+
)
|
| 603 |
|
| 604 |
+
# Launch the interface
|
| 605 |
if __name__ == "__main__":
|
| 606 |
+
# Display version warning at startup
|
| 607 |
+
try:
|
| 608 |
+
import pkg_resources
|
| 609 |
+
gradio_version = pkg_resources.get_distribution("gradio").version
|
| 610 |
+
recommended_version = "4.44.1" # Update this as needed
|
| 611 |
+
if gradio_version != recommended_version:
|
| 612 |
+
print(f"⚠️ WARNING: You are using gradio version {gradio_version}, however version {recommended_version} is available.")
|
| 613 |
+
print(f"⚠️ Please upgrade: pip install gradio=={recommended_version}")
|
| 614 |
+
except:
|
| 615 |
+
pass
|
| 616 |
+
|
| 617 |
+
# Launch with optimized settings
|
| 618 |
+
demo.launch(share=False, debug=False)
|