# MPD-demo / app.py
# mineeuk
# fix: revert to sdk_version 5.6.0 with monkey-patch for bool schema bug
# b90e79e
import spaces
# Monkey-patch gradio_client bug: bool schema not iterable
import gradio_client.utils as _gc_utils
# JSON Schema allows a bare boolean (true/false) in place of a schema object,
# but gradio_client's private converter treats every schema as a mapping and
# fails with "argument of type 'bool' is not iterable". Wrap it so boolean
# schemas become "Any". Rebinding the module attribute means lookups made
# through gradio_client.utils (presumably including its own recursive calls)
# reach the wrapper. NOTE(review): drop once the pinned gradio_client fixes it.
_original_json_schema_to_python_type = _gc_utils._json_schema_to_python_type


def _patched_json_schema_to_python_type(schema, defs=None):
    """Return "Any" for boolean schemas, otherwise defer to the original converter."""
    return (
        "Any"
        if isinstance(schema, bool)
        else _original_json_schema_to_python_type(schema, defs)
    )


_gc_utils._json_schema_to_python_type = _patched_json_schema_to_python_type
import gradio as gr
import torch
import librosa
import numpy as np
import subprocess
import sys
import os
import glob
from pathlib import Path
from huggingface_hub import snapshot_download
import shutil
import tempfile
# Hugging Face access token (None when the variable is unset); forwarded to
# snapshot_download when fetching the dataset repo.
token = os.environ.get("HF_TOKEN")
# Install madmom from its GitHub repository at import time.
# NOTE(review): presumably done here instead of in requirements because the
# released package does not work in this runtime — confirm before changing.
def install_madmom():
    """Install madmom from GitHub using the running interpreter's pip.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
            Because this runs at import time, that aborts application start-up.
    """
    pip_command = [sys.executable, "-m", "pip", "install"]
    pip_command += ["git+https://github.com/CPJKU/madmom", "--no-cache-dir"]
    subprocess.check_call(pip_command)
    print("madmom installed from GitHub")


install_madmom()
# Make modules in the working directory and in ./ml_models importable (the
# pipeline modules are imported lazily by inference()). Resulting search order:
# "./ml_models" first, then ".", then the previous sys.path entries.
sys.path[0:0] = ["./ml_models", "."]
def _has_content(path):
    """Return True if *path* is a directory with at least one entry."""
    return path.is_dir() and any(path.iterdir())


def _count_files(path):
    """Return how many regular files exist anywhere under *path*."""
    return sum(1 for entry in path.rglob("*") if entry.is_file())


def _remove_path(path):
    """Delete *path*, whether it is a directory, a regular file or a symlink."""
    if path.is_symlink() or (path.exists() and not path.is_dir()):
        # rmtree refuses symlinks and cannot remove plain files.
        path.unlink()
    elif path.is_dir():
        shutil.rmtree(path)


def _download_to_temp(base_dir, data_repo_id):
    """Snapshot the dataset repo into ``base_dir/temp_download`` and log its contents.

    Returns:
        The scratch directory (a Path).
    """
    temp_dir = base_dir / "temp_download"
    print(f"Creating temp directory: {temp_dir}")
    temp_dir.mkdir(exist_ok=True)
    print("Calling snapshot_download...")
    downloaded_path = snapshot_download(
        repo_id=data_repo_id,
        repo_type="dataset",
        local_dir=str(temp_dir),
        # Deprecated and ignored by recent huggingface_hub releases; kept so
        # older releases copy real files instead of symlinking into the cache.
        local_dir_use_symlinks=False,
        token=token,
        ignore_patterns=["*.md", "*.txt", ".gitattributes", "README.md"],
    )
    print(f"Download completed to: {downloaded_path}")

    print("=== CHECKING TEMP DOWNLOAD CONTENTS ===")
    print("Temp directory contents:")
    for item in temp_dir.iterdir():
        item_type = "DIR" if item.is_dir() else "FILE"
        print(f" {item.name} ({item_type})")
        if item.is_dir():
            print(f" Contains {_count_files(item)} files")
    return temp_dir


def _install_folder(temp_dir, base_dir, folder_name):
    """Move *folder_name* out of the scratch download into *base_dir*.

    Returns:
        The installed folder path as a string. If the snapshot lacked the
        folder, returns the untouched local copy when it has content, or
        None otherwise.
    """
    temp_folder_path = temp_dir / folder_name
    target_folder_path = base_dir / folder_name
    print(f"Processing {folder_name}:")
    print(f" Source: {temp_folder_path}")
    print(f" Target: {target_folder_path}")
    print(f" Source exists: {temp_folder_path.exists()}")

    if not temp_folder_path.exists():
        print(f" ERROR: {folder_name} not found in temp download")
        # Nothing replaced the local copy, so a populated one is still usable.
        return str(target_folder_path) if _has_content(target_folder_path) else None

    if target_folder_path.exists() or target_folder_path.is_symlink():
        print(" Removing existing target directory")
        _remove_path(target_folder_path)

    print(" Moving folder...")
    shutil.move(str(temp_folder_path), str(target_folder_path))
    if not target_folder_path.exists():
        print(f" ERROR: Move failed for {folder_name}")
        return None
    file_count = _count_files(target_folder_path)
    print(f" SUCCESS: {folder_name} moved with {file_count:,} files")
    return str(target_folder_path)


def download_data_from_hub():
    """Ensure ./covers80 and ./ml_models exist, downloading them when needed.

    When either folder is missing or empty, the whole ``mippia/music-data``
    dataset repo is snapshot-downloaded (with the module-level ``token``) into
    ``./temp_download``. Both folders are then moved into the working
    directory, replacing existing copies, and the scratch directory is
    deleted. The ``1005_e_4`` file is only checked for, never downloaded.

    Returns:
        dict mapping "1005_e_4", "covers80" and "ml_models" to a local path
        string, or to None when unavailable.
    """
    print("=== DOWNLOAD FUNCTION START ===")
    base_dir = Path(".")
    data_repo_id = "mippia/music-data"
    print(f"Base directory: {base_dir.absolute()}")
    print(f"Repository: {data_repo_id}")
    folders_to_check = ["covers80", "ml_models"]
    downloaded_folders = {}

    # Check LFS file
    lfs_file = base_dir / "1005_e_4"
    print(f"Checking LFS file: {lfs_file}")
    if lfs_file.exists():
        file_size = lfs_file.stat().st_size / (1024 * 1024)
        print(f"LFS file found: {file_size:.1f} MB")
        downloaded_folders["1005_e_4"] = str(lfs_file)
    else:
        print("LFS file not found")
        downloaded_folders["1005_e_4"] = None

    # Check existing folders. is_dir() (not exists()) so a stray regular file
    # carrying a folder's name counts as missing instead of crashing iterdir().
    print("=== CHECKING EXISTING FOLDERS ===")
    for folder in folders_to_check:
        folder_path = base_dir / folder
        print(f"Checking {folder} at {folder_path}")
        if _has_content(folder_path):
            print(f" {folder} exists and has content")
        elif folder_path.is_dir():
            print(f" {folder} exists but is empty")
        elif folder_path.exists():
            print(f" {folder} exists but is not a directory")
        else:
            print(f" {folder} does not exist")

    all_folders_exist = all(
        _has_content(base_dir / folder) for folder in folders_to_check
    )
    print(f"All folders exist: {all_folders_exist}")

    if all_folders_exist:
        print("=== USING EXISTING FOLDERS ===")
        for folder_name in folders_to_check:
            folder_path = base_dir / folder_name
            print(f"{folder_name}: {_count_files(folder_path):,} files")
            downloaded_folders[folder_name] = str(folder_path)
    else:
        print("=== STARTING DOWNLOAD ===")
        temp_dir = _download_to_temp(base_dir, data_repo_id)
        print("=== MOVING FOLDERS ===")
        for folder_name in folders_to_check:
            downloaded_folders[folder_name] = _install_folder(
                temp_dir, base_dir, folder_name
            )
        print("=== CLEANING UP TEMP DIRECTORY ===")
        if temp_dir.exists():
            shutil.rmtree(temp_dir)
            print("Temp directory removed")

    print("=== FINAL STATUS ===")
    for key, value in downloaded_folders.items():
        print(f"{key}: {value}")
    print("=== DOWNLOAD FUNCTION END ===")
    return downloaded_folders
# Download data and check results
print("Starting Music Plagiarism Detection App...")
folders = download_data_from_hub()

# Final verification: log what ended up in the working directory.
print("=== FINAL VERIFICATION ===")
current_dir = Path(".")
print("Current directory contents after download:")
for item in current_dir.iterdir():
    item_type = "DIR" if item.is_dir() else "FILE"
    print(f" {item.name} ({item_type})")

# Check ml_models specifically
ml_models_path = Path("ml_models")
print("ml_models check:")
print(f" Exists: {ml_models_path.exists()}")
if ml_models_path.exists():
    print(f" Is directory: {ml_models_path.is_dir()}")
    # Only list a real directory; iterdir() on a regular file would raise and
    # abort start-up from inside what is meant to be a diagnostic.
    if ml_models_path.is_dir():
        print(" Contents:")
        for item in ml_models_path.iterdir():
            print(f" {item.name}")

# The pipeline modules themselves are imported lazily inside inference().
print("=== IMPORTING INFERENCE ===")
# Inference entry point used by the Gradio handler.
def inference(audio_path):
    """Run segmentation/transcription and library comparison on one audio file.

    The pipeline modules are imported on each call rather than at module load.
    NOTE(review): presumably so the app can start before ml_models is present
    on sys.path or downloaded — confirm.

    Returns:
        The dict produced by result_formatting().
    """
    from segment_transcription import segment_transcription
    from compare import get_one_result

    return result_formatting(get_one_result(segment_transcription(audio_path)))
def result_formatting(result):
    """Format the output of ``get_one_result`` into the dict the UI consumes.

    Args:
        result: A sorted list of CompareHelper-like objects whose ``data`` is
            indexable as ``(score, test_label, library_label, ...)``. It may
            instead be an error message: a string, or a list whose first item
            is a string (e.g. "there is no note for this song"). It may also
            be empty or None.

    Returns:
        ``{"matches": [...], "message": str}``. ``message`` is "success" when
        matches were formatted. Each match describes one of the top-3 results:
        rank, score (x100), title, and start/end times in the input and in the
        library song.
    """
    if not result:
        return {"matches": [], "message": "No matches found"}

    # Error messages, either bare or wrapped in a list.
    if isinstance(result, str):
        return {"matches": [], "message": result}
    if isinstance(result, list) and isinstance(result[0], str):
        return {"matches": [], "message": result[0]}

    top_3_results = []
    for rank, compare_helper in enumerate(result[:3], start=1):
        score = compare_helper.data[0]  # similarity score
        # Labels can be missing; treat them as empty so defaults apply.
        test_label = compare_helper.data[1] or {}  # info about the input song
        library_label = compare_helper.data[2] or {}  # info about the matched song

        song_title = library_label.get("title", "Unknown Song")
        library_time = float(library_label.get("time", 0))  # matched span start
        library_time2 = float(library_label.get("time2", 0))  # matched span end
        test_time = float(test_label.get("time", 0))  # input span start
        test_time2 = float(test_label.get("time2", 0))  # input span end
        score_pct = float(score * 100)

        top_3_results.append(
            {
                "rank": rank,
                "score": score_pct,
                "song_title": song_title,
                "test_time": test_time,
                "test_time2": test_time2,
                "library_time": library_time,
                "library_time2": library_time2,
                "confidence": f"{score_pct:.1f}%",
                "time_match": f"Input: {test_time:.1f}s ↔ Library: {library_time:.1f}s",
            }
        )
    return {"matches": top_3_results, "message": "success"}
def find_song_file_by_title(song_title):
    """Locate the covers80 MP3 for a matched song title.

    Searches directly inside ./covers80 (not recursively). It tries, in
    order:

    1. The exact file name.
    2. The title as a file-name suffix.
    3. The title as a prefix.
    4. Any word of the title longer than three characters, with "_" treated
       as a space.

    The title is glob-escaped so characters such as "[", "*" or "?" match
    literally. Candidates are sorted so the pick is deterministic.

    Returns:
        Path of the first matching file as a string, or None if covers80 is
        missing, the title is empty, or nothing matches.
    """
    covers80_path = Path("covers80")
    # An empty title would turn "*{title}.mp3" into "*.mp3" and match any song.
    if not song_title or not covers80_path.exists():
        return None

    def _first_match(pattern):
        candidates = sorted(covers80_path.glob(pattern))
        return str(candidates[0]) if candidates else None

    # Try exact match patterns
    escaped_title = glob.escape(song_title)
    for pattern in (
        f"{escaped_title}.mp3",
        f"*{escaped_title}.mp3",
        f"{escaped_title}*.mp3",
    ):
        found = _first_match(pattern)
        if found:
            return found

    # Try partial matches on individual words
    for part in song_title.replace("_", " ").split():
        if len(part) > 3:
            found = _first_match(f"*{glob.escape(part)}*.mp3")
            if found:
                return found
    return None
def extract_audio_segment(audio_file_path, start_time, end_time):
    """Cut ``[start_time, end_time)`` seconds out of an audio file into a temp WAV.

    Only the requested span is decoded (via librosa's ``offset``/``duration``)
    instead of loading the whole track and slicing it. Audio is loaded at its
    native sample rate (``sr=None``) with librosa's default channel handling.
    A negative start is clamped to 0.

    Returns:
        Path of a new ``.wav`` file, or None if the time range is empty or
        anything fails (the error is printed). The file is created with
        ``delete=False`` so it outlives this call; the caller hands the path
        to the UI and nothing here deletes it on success.
    """
    temp_path = None
    try:
        start_time = max(float(start_time), 0.0)
        end_time = float(end_time)
        if end_time <= start_time:
            # An empty slice would produce an empty WAV the player cannot use.
            print(f"Error extracting segment: empty time range {start_time}-{end_time}")
            return None

        # Decode just the requested window.
        segment, sr = librosa.load(
            audio_file_path,
            sr=None,
            offset=start_time,
            duration=end_time - start_time,
        )

        import soundfile as sf

        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            temp_path = temp_file.name
        sf.write(temp_path, segment, sr)
        return temp_path
    except Exception as e:
        print(f"Error extracting segment: {e}")
        # Don't leave a half-written temp file behind on failure.
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
        return None
def format_time(seconds):
    """Render a number of seconds as ``M:SS``.

    Minutes are not folded into hours (3600 -> "60:00"), and fractional
    seconds are truncated. None and negative values render as "0:00".
    """
    out_of_range = seconds is None or seconds < 0
    if out_of_range:
        return "0:00"
    whole_minutes, leftover = divmod(seconds, 60)
    return f"{int(whole_minutes)}:{int(leftover):02d}"
# Audio players filled by process_audio_for_matching: three full matched songs
# (slots 0-2), then an (input segment, library segment) pair per match (3-8).
_NUM_AUDIO_OUTPUTS = 9
# Segment length used when a match has no usable end time.
_DEFAULT_SEGMENT_SECONDS = 10
# Accent colour per rank in the results cards; other ranks use grey.
_RANK_COLORS = {1: "#dc2626", 2: "#ea580c", 3: "#16a34a"}


def _status_outputs(message_html):
    """Return the handler output for a run with no playable matches."""
    return [None] * _NUM_AUDIO_OUTPUTS + [message_html]


def _segment_end(start, end):
    """Return *end*, or ``start + _DEFAULT_SEGMENT_SECONDS`` if *end* is unusable.

    result_formatting() reports a missing end time as 0, so any end that does
    not lie after *start* is treated as missing.
    """
    return end if end > start else start + _DEFAULT_SEGMENT_SECONDS


def _collect_audio_outputs(audio_file, matches):
    """Fill the audio slots for up to three matches.

    Slots i = 0..2 hold the full matched song. Slots 3 + 2i and 4 + 2i hold
    the input segment and the library segment for match i. A match whose song
    file cannot be found leaves all of its slots as None.
    """
    audio_outputs = [None] * _NUM_AUDIO_OUTPUTS
    for i, match in enumerate(matches[:3]):
        song_title = match.get("song_title", "Unknown Song")
        song_file_path = find_song_file_by_title(song_title)
        print(f"Match {i + 1}: {song_title}")
        print(f" File path: {song_file_path}")
        if not (song_file_path and os.path.exists(song_file_path)):
            continue

        audio_outputs[i] = song_file_path

        input_start = match.get("test_time", 0)
        input_end = _segment_end(input_start, match.get("test_time2", 0))
        audio_outputs[3 + i * 2] = extract_audio_segment(
            audio_file, input_start, input_end
        )

        library_start = match.get("library_time", 0)
        library_end = _segment_end(library_start, match.get("library_time2", 0))
        audio_outputs[4 + i * 2] = extract_audio_segment(
            song_file_path, library_start, library_end
        )
    return audio_outputs


def _match_card_html(match):
    """Render one ranked match as an HTML card.

    The displayed segment bounds use the same fallback end time as the audio
    extraction, so the card describes what is actually played.
    """
    import html

    rank = match.get("rank", 0)
    song_title = match.get("song_title", "Unknown Song")
    song_title = html.escape(song_title.replace("_", " ").replace(" temp", ""))
    score = match.get("score", 0)  # raw score instead of the confidence string
    test_time = match.get("test_time", 0)
    test_time2 = _segment_end(test_time, match.get("test_time2", 0))
    library_time = match.get("library_time", 0)
    library_time2 = _segment_end(library_time, match.get("library_time2", 0))
    rank_color = _RANK_COLORS.get(rank, "#6b7280")
    return f"""
<div style="background: #ffffff; border-radius: 8px; padding: 15px; margin: 10px 0;
border-left: 4px solid {rank_color}; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
<!-- Title -->
<div style="text-align: center; margin-bottom: 15px;">
<h4 style="color: #111827; margin: 0; font-size: 1.1em; word-wrap: break-word; overflow-wrap: break-word;">
<span style="background: {rank_color}; color: white; padding: 2px 6px; border-radius: 10px; font-size: 0.8em; margin-right: 8px;">
#{rank}
</span>
{song_title}
</h4>
</div>
<!-- Stats -->
<div style="display: flex; justify-content: space-around; text-align: center;">
<div>
<small style="color: #6b7280; display: block; margin-bottom: 2px;">Your Segment</small>
<div style="color: #dc2626; font-weight: 600; font-size: 0.9em;">
{format_time(test_time)} - {format_time(test_time2)}
</div>
</div>
<div>
<small style="color: #6b7280; display: block; margin-bottom: 2px;">Matched Segment</small>
<div style="color: #16a34a; font-weight: 600; font-size: 0.9em;">
{format_time(library_time)} - {format_time(library_time2)}
</div>
</div>
<div>
<small style="color: #6b7280; display: block; margin-bottom: 2px;">Score</small>
<div style="background: #f3f4f6; color: #111827; padding: 4px 10px; border-radius: 12px; font-weight: 600; font-size: 0.9em; display: inline-block;">
{score:.1f}
</div>
</div>
</div>
</div>
"""


@spaces.GPU(duration=300)
def process_audio_for_matching(audio_file):
    """Gradio handler: find the top vocal matches for an uploaded audio file.

    Args:
        audio_file: Path of the uploaded audio, or None if nothing was uploaded.

    Returns:
        A 10-item list. Items 0-2 are the full matched songs. Items 3-8 are
        (input segment, library segment) pairs for matches 1-3. Item 9 is the
        results HTML. Audio items are file paths or None.
    """
    import html

    if audio_file is None:
        return _status_outputs(
            """
<div style='text-align: center; color: #dc2626; padding: 20px; background: #fef2f2; border-radius: 8px;'>
<h3>No Audio File</h3>
<p>Please upload an audio file to get started!</p>
</div>
"""
        )

    result = inference(audio_file)
    if result.get("message") != "success":
        # The message comes from the pipeline; escape it before embedding in HTML.
        message = html.escape(str(result.get("message") or "Unknown error occurred"))
        return _status_outputs(
            f"""
<div style="text-align: center; padding: 20px; background: #fefce8; border-radius: 8px;">
<h3 style="color: #a16207;">No Matches Found</h3>
<p style="color: #a16207;">{message}</p>
</div>
"""
        )

    matches = result.get("matches", [])
    if not matches:
        return _status_outputs(
            """
<div style="text-align: center; padding: 20px; background: #fefce8; border-radius: 8px;">
<h3 style="color: #a16207;">No Matches Found</h3>
<p style="color: #a16207;">No matching vocals found in the dataset.</p>
</div>
"""
        )

    audio_outputs = _collect_audio_outputs(audio_file, matches)
    matches_html = "".join(_match_card_html(match) for match in matches[:3])
    results_html = f"""
<div style="background: #ffffff; border-radius: 12px; padding: 20px;
box-shadow: 0 4px 15px rgba(0,0,0,0.08); border: 1px solid #e5e7eb;">
<div style="text-align: center; margin-bottom: 20px;">
<h3 style="color: #111827; margin: 0;">Vocal Matching Results</h3>
<p style="color: #6b7280; margin: 5px 0;">Found {len(matches)} similar vocals in Covers80 dataset</p>
<p style="color: #2563eb; margin: 5px 0; font-size: 0.9em;">🎵 Listen to original songs and extracted segments</p>
<p style="color: #9ca3af; margin: 5px 0; font-size: 0.85em;">💡 Scores above 50 generally indicate meaningful similarity for me haha..</p>
</div>
{matches_html}
</div>
"""
    return audio_outputs + [results_html]
# CSS styles
# Extra stylesheet passed to gr.Blocks(css=custom_css). The custom classes are
# attached to components through elem_classes=[...] in the interface:
# "main-container" (header Markdown), "audio-section" (section headings) and
# "segment-container" (per-segment labels). ".gradio-container" presumably
# targets Gradio's own outer app container. !important is used so these rules
# win over the theme's styles.
custom_css = """
.gradio-container {
background: #f9fafb !important;
min-height: 100vh;
padding: 20px;
}
.main-container {
background: #ffffff !important;
border-radius: 16px !important;
box-shadow: 0 4px 20px rgba(0,0,0,0.08) !important;
margin: 0 auto !important;
padding: 30px !important;
max-width: 1400px;
border: 1px solid #e5e7eb !important;
}
.audio-section {
background: #f8fafc !important;
border-radius: 12px !important;
padding: 15px !important;
margin: 10px 0 !important;
border: 1px solid #e2e8f0 !important;
}
.segment-container {
background: #fefefe !important;
border-radius: 8px !important;
padding: 12px !important;
border: 1px solid #e5e7eb !important;
margin: 5px 0 !important;
}
"""
# Gradio interface
# Layout, top to bottom: header, upload widget and Analyze button, the three
# full matched songs beside the HTML results panel, then one row per match
# pairing the user's segment with the matched library segment. Components
# render in creation order, so statement order here is the page order.
with gr.Blocks(
    css=custom_css, theme=gr.themes.Soft(), title="Music Plagiarism Detection"
) as demo:
    # Header: title, paper reference and demo caveats (styled by .main-container).
    gr.Markdown(
        """
<div style="text-align: center; margin-bottom: 20px;">
<h1 style="color: #111827; font-size: 2.2em; margin-bottom: 10px;">Segment-level Detection Demo</h1>
<p><strong>Music Plagiarism Detection: Problem Formulation and a Segment-based Solution</strong></p>
<p style="font-size: 0.9em; color: #6b7280; margin: 8px 0;">
Authors: Seonghyeon Go, Yumin Kim | MIPPIA Inc. | Submitted to ICASSP 2026
</p>
<hr style="border: none; border-top: 1px solid #e5e7eb; margin: 15px 0;">
<p><strong>Demo Version Notice:</strong> This demo differs from the paper version and focuses exclusively on vocal.</p>
<p> Please use this demo for only understanding the concept of segment-level matching!</p>
<p style="font-size: 0.9em; color: #6b7280; margin: 8px 0;">
Structure analysis has been excluded for optimization. Results are derived from all downbeats,
so segment boundaries may not align perfectly with musical phrases.
</p>
<p style="color: #dc2626; font-weight: 600;">Processing can take up to 2 minutes per file</p>
</div>
""",
        elem_classes=["main-container"],
    )
    # Input section: type="filepath" hands the handler a path string.
    with gr.Row():
        audio_input = gr.Audio(
            type="filepath", label="Upload Your Audio File", elem_id="audio_input"
        )
    with gr.Row():
        submit_btn = gr.Button("Analyze Audio", variant="primary", size="lg")
    # Output section
    with gr.Row():
        # Left column - Full Songs (twice the width of the results column)
        with gr.Column(scale=2):
            gr.Markdown("### 🎵 Matched Songs", elem_classes=["audio-section"])
            with gr.Row():
                match1_full = gr.Audio(
                    label="Match #1 - Full Song", show_label=True, elem_id="match1_full"
                )
                match2_full = gr.Audio(
                    label="Match #2 - Full Song", show_label=True, elem_id="match2_full"
                )
                match3_full = gr.Audio(
                    label="Match #3 - Full Song", show_label=True, elem_id="match3_full"
                )
        # Right column - Results (HTML built by process_audio_for_matching)
        with gr.Column(scale=1):
            results = gr.HTML(label="Analysis Results")
    # Segments section: for each match, the user's segment beside the library segment.
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "### 🎯 Matched Segments Comparison", elem_classes=["audio-section"]
            )
            # Match 1 segments
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        "**Match #1 - Your Segment**",
                        elem_classes=["segment-container"],
                    )
                    match1_input_segment = gr.Audio(
                        label="Your Audio Segment",
                        show_label=False,
                        elem_id="match1_input_seg",
                    )
                with gr.Column():
                    gr.Markdown(
                        "**Match #1 - Matched Segment**",
                        elem_classes=["segment-container"],
                    )
                    match1_library_segment = gr.Audio(
                        label="Library Segment",
                        show_label=False,
                        elem_id="match1_lib_seg",
                    )
            # Match 2 segments
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        "**Match #2 - Your Segment**",
                        elem_classes=["segment-container"],
                    )
                    match2_input_segment = gr.Audio(
                        label="Your Audio Segment",
                        show_label=False,
                        elem_id="match2_input_seg",
                    )
                with gr.Column():
                    gr.Markdown(
                        "**Match #2 - Matched Segment**",
                        elem_classes=["segment-container"],
                    )
                    match2_library_segment = gr.Audio(
                        label="Library Segment",
                        show_label=False,
                        elem_id="match2_lib_seg",
                    )
            # Match 3 segments
            with gr.Row():
                with gr.Column():
                    gr.Markdown(
                        "**Match #3 - Your Segment**",
                        elem_classes=["segment-container"],
                    )
                    match3_input_segment = gr.Audio(
                        label="Your Audio Segment",
                        show_label=False,
                        elem_id="match3_input_seg",
                    )
                with gr.Column():
                    gr.Markdown(
                        "**Match #3 - Matched Segment**",
                        elem_classes=["segment-container"],
                    )
                    match3_library_segment = gr.Audio(
                        label="Library Segment",
                        show_label=False,
                        elem_id="match3_lib_seg",
                    )
    # Define outputs list. The order must match the 10-item list returned by
    # process_audio_for_matching: full songs at 0-2, alternating input/library
    # segments at 3-8, and the results HTML at 9.
    outputs = [
        match1_full,  # 0
        match2_full,  # 1
        match3_full,  # 2
        match1_input_segment,  # 3
        match1_library_segment,  # 4
        match2_input_segment,  # 5
        match2_library_segment,  # 6
        match3_input_segment,  # 7
        match3_library_segment,  # 8
        results,  # 9
    ]
    submit_btn.click(
        fn=process_audio_for_matching, inputs=[audio_input], outputs=outputs
    )
if __name__ == "__main__":
    # Serve on all interfaces at port 7860. The API page is hidden, errors are
    # surfaced in the UI (show_error), and server-side rendering is disabled.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": False,
        "show_error": True,
        "ssr_mode": False,
    }
    demo.launch(**launch_options)