mineeuk commited on
Commit
5fb536f
·
1 Parent(s): 6e66940

fix: upgrade to gradio 5.6.0 to fix gradio_client json_schema_to_python_type bug

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +241 -163
  3. requirements.txt +1 -2
README.md CHANGED
@@ -5,7 +5,7 @@ emoji: 🎵
5
  colorFrom: blue
6
  colorTo: purple
7
  sdk: gradio
8
- sdk_version: "4.44.0"
9
  python_version: "3.10"
10
  app_file: app.py
11
  pinned: false
 
5
  colorFrom: blue
6
  colorTo: purple
7
  sdk: gradio
8
+ sdk_version: "5.6.0"
9
  python_version: "3.10"
10
  app_file: app.py
11
  pinned: false
app.py CHANGED
@@ -14,42 +14,51 @@ import tempfile
14
 
15
  token = os.getenv("HF_TOKEN")
16
 
 
17
  # Install madmom from GitHub
18
  def install_madmom():
19
- subprocess.check_call([
20
- sys.executable, "-m", "pip", "install",
21
- "git+https://github.com/CPJKU/madmom", "--no-cache-dir"
22
- ])
 
 
 
 
 
 
23
  print("madmom installed from GitHub")
24
 
 
25
  install_madmom()
26
 
27
  # Add current directory to Python path for ml_models
28
- sys.path.insert(0, '.')
29
- sys.path.insert(0, './ml_models')
 
30
 
31
  def download_data_from_hub():
32
  print("=== DOWNLOAD FUNCTION START ===")
33
  base_dir = Path(".")
34
  data_repo_id = "mippia/music-data"
35
-
36
  print(f"Base directory: {base_dir.absolute()}")
37
  print(f"Repository: {data_repo_id}")
38
-
39
  folders_to_check = ["covers80", "ml_models"]
40
  downloaded_folders = {}
41
-
42
  # Check LFS file
43
  lfs_file = base_dir / "1005_e_4"
44
  print(f"Checking LFS file: {lfs_file}")
45
  if lfs_file.exists():
46
- file_size = lfs_file.stat().st_size / (1024*1024)
47
  print(f"LFS file found: {file_size:.1f} MB")
48
  downloaded_folders["1005_e_4"] = str(lfs_file)
49
  else:
50
  print("LFS file not found")
51
  downloaded_folders["1005_e_4"] = None
52
-
53
  # Check existing folders
54
  print("=== CHECKING EXISTING FOLDERS ===")
55
  for folder in folders_to_check:
@@ -62,19 +71,21 @@ def download_data_from_hub():
62
  print(f" {folder} exists but is empty")
63
  else:
64
  print(f" {folder} does not exist")
65
-
66
- all_folders_exist = all((base_dir / folder).exists() and any((base_dir / folder).iterdir())
67
- for folder in folders_to_check)
 
 
68
  print(f"All folders exist: {all_folders_exist}")
69
-
70
  if not all_folders_exist:
71
  print("=== STARTING DOWNLOAD ===")
72
-
73
  # Download to a temporary directory first
74
  temp_dir = base_dir / "temp_download"
75
  print(f"Creating temp directory: {temp_dir}")
76
  temp_dir.mkdir(exist_ok=True)
77
-
78
  print("Calling snapshot_download...")
79
  downloaded_path = snapshot_download(
80
  repo_id=data_repo_id,
@@ -82,11 +93,11 @@ def download_data_from_hub():
82
  local_dir=str(temp_dir),
83
  local_dir_use_symlinks=False,
84
  token=token,
85
- ignore_patterns=["*.md", "*.txt", ".gitattributes", "README.md"]
86
  )
87
-
88
  print(f"Download completed to: {downloaded_path}")
89
-
90
  # Check what was downloaded
91
  print("=== CHECKING TEMP DOWNLOAD CONTENTS ===")
92
  print(f"Temp directory contents:")
@@ -96,31 +107,33 @@ def download_data_from_hub():
96
  if item.is_dir():
97
  file_count = len([f for f in item.rglob("*") if f.is_file()])
98
  print(f" Contains {file_count} files")
99
-
100
  # Move folders from temp to current directory
101
  print("=== MOVING FOLDERS ===")
102
  for folder_name in folders_to_check:
103
  temp_folder_path = temp_dir / folder_name
104
  target_folder_path = base_dir / folder_name
105
-
106
  print(f"Processing {folder_name}:")
107
  print(f" Source: {temp_folder_path}")
108
  print(f" Target: {target_folder_path}")
109
  print(f" Source exists: {temp_folder_path.exists()}")
110
-
111
  if temp_folder_path.exists():
112
  # Remove existing target if it exists
113
  if target_folder_path.exists():
114
  print(f" Removing existing target directory")
115
  shutil.rmtree(target_folder_path)
116
-
117
  # Move folder
118
  print(f" Moving folder...")
119
  shutil.move(str(temp_folder_path), str(target_folder_path))
120
-
121
  # Verify move
122
  if target_folder_path.exists():
123
- file_count = len([f for f in target_folder_path.rglob("*") if f.is_file()])
 
 
124
  print(f" SUCCESS: {folder_name} moved with {file_count:,} files")
125
  downloaded_folders[folder_name] = str(target_folder_path)
126
  else:
@@ -129,13 +142,13 @@ def download_data_from_hub():
129
  else:
130
  print(f" ERROR: {folder_name} not found in temp download")
131
  downloaded_folders[folder_name] = None
132
-
133
  # Clean up temp directory
134
  print("=== CLEANING UP TEMP DIRECTORY ===")
135
  if temp_dir.exists():
136
  shutil.rmtree(temp_dir)
137
  print("Temp directory removed")
138
-
139
  else:
140
  print("=== USING EXISTING FOLDERS ===")
141
  for folder_name in folders_to_check:
@@ -146,14 +159,15 @@ def download_data_from_hub():
146
  downloaded_folders[folder_name] = str(folder_path)
147
  else:
148
  downloaded_folders[folder_name] = None
149
-
150
  print("=== FINAL STATUS ===")
151
  for key, value in downloaded_folders.items():
152
  print(f"{key}: {value}")
153
-
154
  print("=== DOWNLOAD FUNCTION END ===")
155
  return downloaded_folders
156
 
 
157
  # Download data and check results
158
  print("Starting Music Plagiarism Detection App...")
159
  folders = download_data_from_hub()
@@ -179,96 +193,90 @@ if ml_models_path.exists():
179
  # Import updated inference
180
  print("=== IMPORTING INFERENCE ===")
181
 
 
182
  # Updated inference functions
183
  def inference(audio_path):
184
  from segment_transcription import segment_transcription
185
  from compare import get_one_result
186
-
187
  segment_datas = segment_transcription(audio_path)
188
  result = get_one_result(segment_datas)
189
  final_result = result_formatting(result)
190
  return final_result
191
 
 
192
  def result_formatting(result):
193
  """
194
  get_one_result에서 나온 결과를 포맷팅
195
  result: sorted list of CompareHelper objects
196
  """
197
  if not result or len(result) == 0:
198
- return {
199
- 'matches': [],
200
- 'message': 'No matches found'
201
- }
202
 
203
  # 에러 메시지 체크
204
  if isinstance(result, list) and len(result) > 0 and isinstance(result[0], str):
205
  return {
206
- 'matches': [],
207
- 'message': result[0] # "there is no note for this song"
208
  }
209
 
210
  # 상위 3개 결과 추출
211
  top_3_results = []
212
  for i, compare_helper in enumerate(result[:3]):
213
- score = compare_helper.data[0] # similarity score
214
- test_label = compare_helper.data[1] # test song info
215
- library_label = compare_helper.data[2] # matched song info
216
 
217
  # 라이브러리 레이블에서 정보 추출
218
- song_title = library_label.get('title', 'Unknown Song')
219
- library_time = library_label.get('time', 0) # 매치된 구간의 시간
220
- library_time2 = library_label.get('time2', 0)
221
 
222
  # 테스트 레이블에서 정보 추출
223
- test_time = test_label.get('time', 0) if test_label else 0 # 입력 곡의 시간
224
- test_time2 = test_label.get('time2', 0) if test_label else 0
225
 
226
  match_info = {
227
- 'rank': i + 1,
228
- 'score': float(score*100),
229
- 'song_title': song_title,
230
- 'test_time': float(test_time), # 입력 곡에서 매치된 시간
231
- 'test_time2' : float(test_time2),
232
- 'library_time': float(library_time), # 라이브러리 곡에서 매치된 시간
233
- 'library_time2': float(library_time2),
234
- 'confidence': f"{score * 100:.1f}%",
235
- 'time_match': f"Input: {test_time:.1f}s ↔ Library: {library_time:.1f}s"
236
  }
237
  top_3_results.append(match_info)
238
 
239
- return {
240
- 'matches': top_3_results,
241
- 'message': 'success'
242
- }
243
 
244
  def find_song_file_by_title(song_title):
245
  covers80_path = Path("covers80")
246
-
247
  if not covers80_path.exists():
248
  return None
249
-
250
  # Try exact match patterns
251
- exact_patterns = [
252
- f"{song_title}.mp3",
253
- f"*{song_title}.mp3",
254
- f"{song_title}*.mp3"
255
- ]
256
-
257
  for pattern in exact_patterns:
258
  matches = list(covers80_path.glob(pattern))
259
  if matches:
260
  return str(matches[0])
261
-
262
  # Try partial matches
263
- song_parts = song_title.replace('_', ' ').split()
264
  for part in song_parts:
265
  if len(part) > 3:
266
  matches = list(covers80_path.glob(f"*{part}*.mp3"))
267
  if matches:
268
  return str(matches[0])
269
-
270
  return None
271
 
 
272
  def extract_audio_segment(audio_file_path, start_time, end_time):
273
  """
274
  오디오 파일에서 특정 구간을 추출하여 임시 파일로 저장
@@ -276,109 +284,124 @@ def extract_audio_segment(audio_file_path, start_time, end_time):
276
  try:
277
  # Load audio file
278
  y, sr = librosa.load(audio_file_path, sr=None)
279
-
280
  # Convert time to samples
281
  start_sample = int(start_time * sr)
282
  end_sample = int(end_time * sr)
283
-
284
  # Extract segment
285
  segment = y[start_sample:end_sample]
286
-
287
  # Create temporary file
288
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
289
  temp_file.close()
290
-
291
  # Save segment
292
  import soundfile as sf
 
293
  sf.write(temp_file.name, segment, sr)
294
-
295
  return temp_file.name
296
-
297
  except Exception as e:
298
  print(f"Error extracting segment: {e}")
299
  return None
300
 
 
301
  def format_time(seconds):
302
  """Convert seconds to MM:SS format"""
303
  if seconds is None or seconds < 0:
304
  return "0:00"
305
-
306
  minutes = int(seconds // 60)
307
  seconds = int(seconds % 60)
308
  return f"{minutes}:{seconds:02d}"
309
 
 
310
  @spaces.GPU(duration=300)
311
  def process_audio_for_matching(audio_file):
312
  if audio_file is None:
313
- return [None] * 9 + ["""
 
314
  <div style='text-align: center; color: #dc2626; padding: 20px; background: #fef2f2; border-radius: 8px;'>
315
  <h3>No Audio File</h3>
316
  <p>Please upload an audio file to get started!</p>
317
  </div>
318
- """]
 
319
 
320
  result = inference(audio_file)
321
-
322
- if result.get('message') != 'success':
323
- return [None] * 9 + [f"""
 
324
  <div style="text-align: center; padding: 20px; background: #fefce8; border-radius: 8px;">
325
  <h3 style="color: #a16207;">No Matches Found</h3>
326
- <p style="color: #a16207;">{result.get('message', 'Unknown error occurred')}</p>
327
  </div>
328
- """]
329
-
330
- matches = result.get('matches', [])
 
331
  if not matches:
332
- return [None] * 9 + ["""
 
333
  <div style="text-align: center; padding: 20px; background: #fefce8; border-radius: 8px;">
334
  <h3 style="color: #a16207;">No Matches Found</h3>
335
  <p style="color: #a16207;">No matching vocals found in the dataset.</p>
336
  </div>
337
- """]
338
-
 
339
  # Initialize audio outputs
340
  audio_outputs = [None] * 9 # Reduced from 10 to 9 (removed original audio)
341
-
342
  # Get full songs and segments for top 3 matches
343
  for i, match in enumerate(matches[:3]):
344
- song_title = match.get('song_title', 'Unknown Song')
345
  song_file_path = find_song_file_by_title(song_title)
346
-
347
- print(f"Match {i+1}: {song_title}")
348
  print(f" File path: {song_file_path}")
349
-
350
  if song_file_path and os.path.exists(song_file_path):
351
  # Full matched song (indices 0, 1, 2)
352
  audio_outputs[i] = song_file_path
353
-
354
  # Extract segments for input audio (indices 3, 5, 7)
355
- input_start = match.get('test_time', 0)
356
- input_end = match.get('test_time2', input_start + 10) # Default 10 seconds if no end time
 
 
357
  input_segment = extract_audio_segment(audio_file, input_start, input_end)
358
  audio_outputs[3 + i * 2] = input_segment
359
-
360
  # Extract segments for matched song (indices 4, 6, 8)
361
- library_start = match.get('library_time', 0)
362
- library_end = match.get('library_time2', library_start + 10) # Default 10 seconds if no end time
363
- library_segment = extract_audio_segment(song_file_path, library_start, library_end)
 
 
 
 
364
  audio_outputs[4 + i * 2] = library_segment
365
-
366
  # Generate results HTML
367
  matches_html = ""
368
  for i, match in enumerate(matches[:3]):
369
- rank = match.get('rank', 0)
370
- song_title = match.get('song_title', 'Unknown Song')
371
- song_title = song_title.replace('_', ' ').replace(' temp','')
372
- score = match.get('score', 0) # Raw score instead of confidence
373
- test_time = match.get('test_time', 0)
374
- test_time2 = match.get('test_time2', 0)
375
- library_time = match.get('library_time', 0)
376
- library_time2 = match.get('library_time2', 0)
377
-
378
  # Ranking colors
379
- rank_colors = {1: '#dc2626', 2: '#ea580c', 3: '#16a34a'}
380
- rank_color = rank_colors.get(rank, '#6b7280')
381
-
382
  matches_html += f"""
383
  <div style="background: #ffffff; border-radius: 8px; padding: 15px; margin: 10px 0;
384
  border-left: 4px solid {rank_color}; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
@@ -415,7 +438,7 @@ def process_audio_for_matching(audio_file):
415
  </div>
416
  </div>
417
  """
418
-
419
  results_html = f"""
420
  <div style="background: #ffffff; border-radius: 12px; padding: 20px;
421
  box-shadow: 0 4px 15px rgba(0,0,0,0.08); border: 1px solid #e5e7eb;">
@@ -428,9 +451,10 @@ def process_audio_for_matching(audio_file):
428
  {matches_html}
429
  </div>
430
  """
431
-
432
  return audio_outputs + [results_html]
433
 
 
434
  # CSS styles
435
  custom_css = """
436
  .gradio-container {
@@ -464,9 +488,11 @@ custom_css = """
464
  """
465
 
466
  # Gradio interface
467
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Music Plagiarism Detection") as demo:
468
-
469
- gr.Markdown("""
 
 
470
  <div style="text-align: center; margin-bottom: 20px;">
471
  <h1 style="color: #111827; font-size: 2.2em; margin-bottom: 10px;">Segment-level Detection Demo</h1>
472
  <p><strong>Music Plagiarism Detection: Problem Formulation and a Segment-based Solution</strong></p>
@@ -482,87 +508,139 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="Music Plagiarism D
482
  </p>
483
  <p style="color: #dc2626; font-weight: 600;">Processing can take up to 2 minutes per file</p>
484
  </div>
485
- """, elem_classes=["main-container"])
486
-
 
 
487
  # Input section
488
  with gr.Row():
489
- audio_input = gr.Audio(type="filepath", label="Upload Your Audio File", elem_id="audio_input")
490
-
 
 
491
  with gr.Row():
492
  submit_btn = gr.Button("Analyze Audio", variant="primary", size="lg")
493
-
494
  # Output section
495
  with gr.Row():
496
  # Left column - Full Songs
497
  with gr.Column(scale=2):
498
  gr.Markdown("### 🎵 Matched Songs", elem_classes=["audio-section"])
499
-
500
  with gr.Row():
501
- match1_full = gr.Audio(label="Match #1 - Full Song", show_label=True, elem_id="match1_full")
502
- match2_full = gr.Audio(label="Match #2 - Full Song", show_label=True, elem_id="match2_full")
503
- match3_full = gr.Audio(label="Match #3 - Full Song", show_label=True, elem_id="match3_full")
504
-
 
 
 
 
 
 
505
  # Right column - Results
506
  with gr.Column(scale=1):
507
  results = gr.HTML(label="Analysis Results")
508
-
509
  # Segments section
510
  with gr.Row():
511
  with gr.Column():
512
- gr.Markdown("### 🎯 Matched Segments Comparison", elem_classes=["audio-section"])
513
-
 
 
514
  # Match 1 segments
515
  with gr.Row():
516
  with gr.Column():
517
- gr.Markdown("**Match #1 - Your Segment**", elem_classes=["segment-container"])
518
- match1_input_segment = gr.Audio(label="Your Audio Segment", show_label=False, elem_id="match1_input_seg")
 
 
 
 
 
 
 
519
  with gr.Column():
520
- gr.Markdown("**Match #1 - Matched Segment**", elem_classes=["segment-container"])
521
- match1_library_segment = gr.Audio(label="Library Segment", show_label=False, elem_id="match1_lib_seg")
522
-
 
 
 
 
 
 
 
523
  # Match 2 segments
524
  with gr.Row():
525
  with gr.Column():
526
- gr.Markdown("**Match #2 - Your Segment**", elem_classes=["segment-container"])
527
- match2_input_segment = gr.Audio(label="Your Audio Segment", show_label=False, elem_id="match2_input_seg")
 
 
 
 
 
 
 
528
  with gr.Column():
529
- gr.Markdown("**Match #2 - Matched Segment**", elem_classes=["segment-container"])
530
- match2_library_segment = gr.Audio(label="Library Segment", show_label=False, elem_id="match2_lib_seg")
531
-
 
 
 
 
 
 
 
532
  # Match 3 segments
533
  with gr.Row():
534
  with gr.Column():
535
- gr.Markdown("**Match #3 - Your Segment**", elem_classes=["segment-container"])
536
- match3_input_segment = gr.Audio(label="Your Audio Segment", show_label=False, elem_id="match3_input_seg")
 
 
 
 
 
 
 
537
  with gr.Column():
538
- gr.Markdown("**Match #3 - Matched Segment**", elem_classes=["segment-container"])
539
- match3_library_segment = gr.Audio(label="Library Segment", show_label=False, elem_id="match3_lib_seg")
540
-
 
 
 
 
 
 
 
541
  # Define outputs list
542
  outputs = [
543
- match1_full, # 0
544
- match2_full, # 1
545
- match3_full, # 2
546
- match1_input_segment, # 3
547
  match1_library_segment, # 4
548
- match2_input_segment, # 5
549
  match2_library_segment, # 6
550
- match3_input_segment, # 7
551
  match3_library_segment, # 8
552
- results # 9
553
  ]
554
-
555
  submit_btn.click(
556
- fn=process_audio_for_matching,
557
- inputs=[audio_input],
558
- outputs=outputs
559
  )
560
 
561
  if __name__ == "__main__":
562
  demo.launch(
563
- server_name="0.0.0.0",
564
- server_port=7860,
565
- show_api=False,
566
  show_error=True,
567
- share=False
568
- )
 
 
14
 
15
  token = os.getenv("HF_TOKEN")
16
 
17
+
18
  # Install madmom from GitHub
19
  def install_madmom():
20
+ subprocess.check_call(
21
+ [
22
+ sys.executable,
23
+ "-m",
24
+ "pip",
25
+ "install",
26
+ "git+https://github.com/CPJKU/madmom",
27
+ "--no-cache-dir",
28
+ ]
29
+ )
30
  print("madmom installed from GitHub")
31
 
32
+
33
  install_madmom()
34
 
35
  # Add current directory to Python path for ml_models
36
+ sys.path.insert(0, ".")
37
+ sys.path.insert(0, "./ml_models")
38
+
39
 
40
  def download_data_from_hub():
41
  print("=== DOWNLOAD FUNCTION START ===")
42
  base_dir = Path(".")
43
  data_repo_id = "mippia/music-data"
44
+
45
  print(f"Base directory: {base_dir.absolute()}")
46
  print(f"Repository: {data_repo_id}")
47
+
48
  folders_to_check = ["covers80", "ml_models"]
49
  downloaded_folders = {}
50
+
51
  # Check LFS file
52
  lfs_file = base_dir / "1005_e_4"
53
  print(f"Checking LFS file: {lfs_file}")
54
  if lfs_file.exists():
55
+ file_size = lfs_file.stat().st_size / (1024 * 1024)
56
  print(f"LFS file found: {file_size:.1f} MB")
57
  downloaded_folders["1005_e_4"] = str(lfs_file)
58
  else:
59
  print("LFS file not found")
60
  downloaded_folders["1005_e_4"] = None
61
+
62
  # Check existing folders
63
  print("=== CHECKING EXISTING FOLDERS ===")
64
  for folder in folders_to_check:
 
71
  print(f" {folder} exists but is empty")
72
  else:
73
  print(f" {folder} does not exist")
74
+
75
+ all_folders_exist = all(
76
+ (base_dir / folder).exists() and any((base_dir / folder).iterdir())
77
+ for folder in folders_to_check
78
+ )
79
  print(f"All folders exist: {all_folders_exist}")
80
+
81
  if not all_folders_exist:
82
  print("=== STARTING DOWNLOAD ===")
83
+
84
  # Download to a temporary directory first
85
  temp_dir = base_dir / "temp_download"
86
  print(f"Creating temp directory: {temp_dir}")
87
  temp_dir.mkdir(exist_ok=True)
88
+
89
  print("Calling snapshot_download...")
90
  downloaded_path = snapshot_download(
91
  repo_id=data_repo_id,
 
93
  local_dir=str(temp_dir),
94
  local_dir_use_symlinks=False,
95
  token=token,
96
+ ignore_patterns=["*.md", "*.txt", ".gitattributes", "README.md"],
97
  )
98
+
99
  print(f"Download completed to: {downloaded_path}")
100
+
101
  # Check what was downloaded
102
  print("=== CHECKING TEMP DOWNLOAD CONTENTS ===")
103
  print(f"Temp directory contents:")
 
107
  if item.is_dir():
108
  file_count = len([f for f in item.rglob("*") if f.is_file()])
109
  print(f" Contains {file_count} files")
110
+
111
  # Move folders from temp to current directory
112
  print("=== MOVING FOLDERS ===")
113
  for folder_name in folders_to_check:
114
  temp_folder_path = temp_dir / folder_name
115
  target_folder_path = base_dir / folder_name
116
+
117
  print(f"Processing {folder_name}:")
118
  print(f" Source: {temp_folder_path}")
119
  print(f" Target: {target_folder_path}")
120
  print(f" Source exists: {temp_folder_path.exists()}")
121
+
122
  if temp_folder_path.exists():
123
  # Remove existing target if it exists
124
  if target_folder_path.exists():
125
  print(f" Removing existing target directory")
126
  shutil.rmtree(target_folder_path)
127
+
128
  # Move folder
129
  print(f" Moving folder...")
130
  shutil.move(str(temp_folder_path), str(target_folder_path))
131
+
132
  # Verify move
133
  if target_folder_path.exists():
134
+ file_count = len(
135
+ [f for f in target_folder_path.rglob("*") if f.is_file()]
136
+ )
137
  print(f" SUCCESS: {folder_name} moved with {file_count:,} files")
138
  downloaded_folders[folder_name] = str(target_folder_path)
139
  else:
 
142
  else:
143
  print(f" ERROR: {folder_name} not found in temp download")
144
  downloaded_folders[folder_name] = None
145
+
146
  # Clean up temp directory
147
  print("=== CLEANING UP TEMP DIRECTORY ===")
148
  if temp_dir.exists():
149
  shutil.rmtree(temp_dir)
150
  print("Temp directory removed")
151
+
152
  else:
153
  print("=== USING EXISTING FOLDERS ===")
154
  for folder_name in folders_to_check:
 
159
  downloaded_folders[folder_name] = str(folder_path)
160
  else:
161
  downloaded_folders[folder_name] = None
162
+
163
  print("=== FINAL STATUS ===")
164
  for key, value in downloaded_folders.items():
165
  print(f"{key}: {value}")
166
+
167
  print("=== DOWNLOAD FUNCTION END ===")
168
  return downloaded_folders
169
 
170
+
171
  # Download data and check results
172
  print("Starting Music Plagiarism Detection App...")
173
  folders = download_data_from_hub()
 
193
  # Import updated inference
194
  print("=== IMPORTING INFERENCE ===")
195
 
196
+
197
  # Updated inference functions
198
  def inference(audio_path):
199
  from segment_transcription import segment_transcription
200
  from compare import get_one_result
201
+
202
  segment_datas = segment_transcription(audio_path)
203
  result = get_one_result(segment_datas)
204
  final_result = result_formatting(result)
205
  return final_result
206
 
207
+
208
  def result_formatting(result):
209
  """
210
  get_one_result에서 나온 결과를 포맷팅
211
  result: sorted list of CompareHelper objects
212
  """
213
  if not result or len(result) == 0:
214
+ return {"matches": [], "message": "No matches found"}
 
 
 
215
 
216
  # 에러 메시지 체크
217
  if isinstance(result, list) and len(result) > 0 and isinstance(result[0], str):
218
  return {
219
+ "matches": [],
220
+ "message": result[0], # "there is no note for this song"
221
  }
222
 
223
  # 상위 3개 결과 추출
224
  top_3_results = []
225
  for i, compare_helper in enumerate(result[:3]):
226
+ score = compare_helper.data[0] # similarity score
227
+ test_label = compare_helper.data[1] # test song info
228
+ library_label = compare_helper.data[2] # matched song info
229
 
230
  # 라이브러리 레이블에서 정보 추출
231
+ song_title = library_label.get("title", "Unknown Song")
232
+ library_time = library_label.get("time", 0) # 매치된 구간의 시간
233
+ library_time2 = library_label.get("time2", 0)
234
 
235
  # 테스트 레이블에서 정보 추출
236
+ test_time = test_label.get("time", 0) if test_label else 0 # 입력 곡의 시간
237
+ test_time2 = test_label.get("time2", 0) if test_label else 0
238
 
239
  match_info = {
240
+ "rank": i + 1,
241
+ "score": float(score * 100),
242
+ "song_title": song_title,
243
+ "test_time": float(test_time), # 입력 곡에서 매치된 시간
244
+ "test_time2": float(test_time2),
245
+ "library_time": float(library_time), # 라이브러리 곡에서 매치된 시간
246
+ "library_time2": float(library_time2),
247
+ "confidence": f"{score * 100:.1f}%",
248
+ "time_match": f"Input: {test_time:.1f}s ↔ Library: {library_time:.1f}s",
249
  }
250
  top_3_results.append(match_info)
251
 
252
+ return {"matches": top_3_results, "message": "success"}
253
+
 
 
254
 
255
  def find_song_file_by_title(song_title):
256
  covers80_path = Path("covers80")
257
+
258
  if not covers80_path.exists():
259
  return None
260
+
261
  # Try exact match patterns
262
+ exact_patterns = [f"{song_title}.mp3", f"*{song_title}.mp3", f"{song_title}*.mp3"]
263
+
 
 
 
 
264
  for pattern in exact_patterns:
265
  matches = list(covers80_path.glob(pattern))
266
  if matches:
267
  return str(matches[0])
268
+
269
  # Try partial matches
270
+ song_parts = song_title.replace("_", " ").split()
271
  for part in song_parts:
272
  if len(part) > 3:
273
  matches = list(covers80_path.glob(f"*{part}*.mp3"))
274
  if matches:
275
  return str(matches[0])
276
+
277
  return None
278
 
279
+
280
  def extract_audio_segment(audio_file_path, start_time, end_time):
281
  """
282
  오디오 파일에서 특정 구간을 추출하여 임시 파일로 저장
 
284
  try:
285
  # Load audio file
286
  y, sr = librosa.load(audio_file_path, sr=None)
287
+
288
  # Convert time to samples
289
  start_sample = int(start_time * sr)
290
  end_sample = int(end_time * sr)
291
+
292
  # Extract segment
293
  segment = y[start_sample:end_sample]
294
+
295
  # Create temporary file
296
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
297
  temp_file.close()
298
+
299
  # Save segment
300
  import soundfile as sf
301
+
302
  sf.write(temp_file.name, segment, sr)
303
+
304
  return temp_file.name
305
+
306
  except Exception as e:
307
  print(f"Error extracting segment: {e}")
308
  return None
309
 
310
+
311
  def format_time(seconds):
312
  """Convert seconds to MM:SS format"""
313
  if seconds is None or seconds < 0:
314
  return "0:00"
315
+
316
  minutes = int(seconds // 60)
317
  seconds = int(seconds % 60)
318
  return f"{minutes}:{seconds:02d}"
319
 
320
+
321
  @spaces.GPU(duration=300)
322
  def process_audio_for_matching(audio_file):
323
  if audio_file is None:
324
+ return [None] * 9 + [
325
+ """
326
  <div style='text-align: center; color: #dc2626; padding: 20px; background: #fef2f2; border-radius: 8px;'>
327
  <h3>No Audio File</h3>
328
  <p>Please upload an audio file to get started!</p>
329
  </div>
330
+ """
331
+ ]
332
 
333
  result = inference(audio_file)
334
+
335
+ if result.get("message") != "success":
336
+ return [None] * 9 + [
337
+ f"""
338
  <div style="text-align: center; padding: 20px; background: #fefce8; border-radius: 8px;">
339
  <h3 style="color: #a16207;">No Matches Found</h3>
340
+ <p style="color: #a16207;">{result.get("message", "Unknown error occurred")}</p>
341
  </div>
342
+ """
343
+ ]
344
+
345
+ matches = result.get("matches", [])
346
  if not matches:
347
+ return [None] * 9 + [
348
+ """
349
  <div style="text-align: center; padding: 20px; background: #fefce8; border-radius: 8px;">
350
  <h3 style="color: #a16207;">No Matches Found</h3>
351
  <p style="color: #a16207;">No matching vocals found in the dataset.</p>
352
  </div>
353
+ """
354
+ ]
355
+
356
  # Initialize audio outputs
357
  audio_outputs = [None] * 9 # Reduced from 10 to 9 (removed original audio)
358
+
359
  # Get full songs and segments for top 3 matches
360
  for i, match in enumerate(matches[:3]):
361
+ song_title = match.get("song_title", "Unknown Song")
362
  song_file_path = find_song_file_by_title(song_title)
363
+
364
+ print(f"Match {i + 1}: {song_title}")
365
  print(f" File path: {song_file_path}")
366
+
367
  if song_file_path and os.path.exists(song_file_path):
368
  # Full matched song (indices 0, 1, 2)
369
  audio_outputs[i] = song_file_path
370
+
371
  # Extract segments for input audio (indices 3, 5, 7)
372
+ input_start = match.get("test_time", 0)
373
+ input_end = match.get(
374
+ "test_time2", input_start + 10
375
+ ) # Default 10 seconds if no end time
376
  input_segment = extract_audio_segment(audio_file, input_start, input_end)
377
  audio_outputs[3 + i * 2] = input_segment
378
+
379
  # Extract segments for matched song (indices 4, 6, 8)
380
+ library_start = match.get("library_time", 0)
381
+ library_end = match.get(
382
+ "library_time2", library_start + 10
383
+ ) # Default 10 seconds if no end time
384
+ library_segment = extract_audio_segment(
385
+ song_file_path, library_start, library_end
386
+ )
387
  audio_outputs[4 + i * 2] = library_segment
388
+
389
  # Generate results HTML
390
  matches_html = ""
391
  for i, match in enumerate(matches[:3]):
392
+ rank = match.get("rank", 0)
393
+ song_title = match.get("song_title", "Unknown Song")
394
+ song_title = song_title.replace("_", " ").replace(" temp", "")
395
+ score = match.get("score", 0) # Raw score instead of confidence
396
+ test_time = match.get("test_time", 0)
397
+ test_time2 = match.get("test_time2", 0)
398
+ library_time = match.get("library_time", 0)
399
+ library_time2 = match.get("library_time2", 0)
400
+
401
  # Ranking colors
402
+ rank_colors = {1: "#dc2626", 2: "#ea580c", 3: "#16a34a"}
403
+ rank_color = rank_colors.get(rank, "#6b7280")
404
+
405
  matches_html += f"""
406
  <div style="background: #ffffff; border-radius: 8px; padding: 15px; margin: 10px 0;
407
  border-left: 4px solid {rank_color}; box-shadow: 0 2px 8px rgba(0,0,0,0.1);">
 
438
  </div>
439
  </div>
440
  """
441
+
442
  results_html = f"""
443
  <div style="background: #ffffff; border-radius: 12px; padding: 20px;
444
  box-shadow: 0 4px 15px rgba(0,0,0,0.08); border: 1px solid #e5e7eb;">
 
451
  {matches_html}
452
  </div>
453
  """
454
+
455
  return audio_outputs + [results_html]
456
 
457
+
458
  # CSS styles
459
  custom_css = """
460
  .gradio-container {
 
488
  """
489
 
490
  # Gradio interface
491
+ with gr.Blocks(
492
+ css=custom_css, theme=gr.themes.Soft(), title="Music Plagiarism Detection"
493
+ ) as demo:
494
+ gr.Markdown(
495
+ """
496
  <div style="text-align: center; margin-bottom: 20px;">
497
  <h1 style="color: #111827; font-size: 2.2em; margin-bottom: 10px;">Segment-level Detection Demo</h1>
498
  <p><strong>Music Plagiarism Detection: Problem Formulation and a Segment-based Solution</strong></p>
 
508
  </p>
509
  <p style="color: #dc2626; font-weight: 600;">Processing can take up to 2 minutes per file</p>
510
  </div>
511
+ """,
512
+ elem_classes=["main-container"],
513
+ )
514
+
515
  # Input section
516
  with gr.Row():
517
+ audio_input = gr.Audio(
518
+ type="filepath", label="Upload Your Audio File", elem_id="audio_input"
519
+ )
520
+
521
  with gr.Row():
522
  submit_btn = gr.Button("Analyze Audio", variant="primary", size="lg")
523
+
524
  # Output section
525
  with gr.Row():
526
  # Left column - Full Songs
527
  with gr.Column(scale=2):
528
  gr.Markdown("### 🎵 Matched Songs", elem_classes=["audio-section"])
529
+
530
  with gr.Row():
531
+ match1_full = gr.Audio(
532
+ label="Match #1 - Full Song", show_label=True, elem_id="match1_full"
533
+ )
534
+ match2_full = gr.Audio(
535
+ label="Match #2 - Full Song", show_label=True, elem_id="match2_full"
536
+ )
537
+ match3_full = gr.Audio(
538
+ label="Match #3 - Full Song", show_label=True, elem_id="match3_full"
539
+ )
540
+
541
  # Right column - Results
542
  with gr.Column(scale=1):
543
  results = gr.HTML(label="Analysis Results")
544
+
545
  # Segments section
546
  with gr.Row():
547
  with gr.Column():
548
+ gr.Markdown(
549
+ "### 🎯 Matched Segments Comparison", elem_classes=["audio-section"]
550
+ )
551
+
552
  # Match 1 segments
553
  with gr.Row():
554
  with gr.Column():
555
+ gr.Markdown(
556
+ "**Match #1 - Your Segment**",
557
+ elem_classes=["segment-container"],
558
+ )
559
+ match1_input_segment = gr.Audio(
560
+ label="Your Audio Segment",
561
+ show_label=False,
562
+ elem_id="match1_input_seg",
563
+ )
564
  with gr.Column():
565
+ gr.Markdown(
566
+ "**Match #1 - Matched Segment**",
567
+ elem_classes=["segment-container"],
568
+ )
569
+ match1_library_segment = gr.Audio(
570
+ label="Library Segment",
571
+ show_label=False,
572
+ elem_id="match1_lib_seg",
573
+ )
574
+
575
  # Match 2 segments
576
  with gr.Row():
577
  with gr.Column():
578
+ gr.Markdown(
579
+ "**Match #2 - Your Segment**",
580
+ elem_classes=["segment-container"],
581
+ )
582
+ match2_input_segment = gr.Audio(
583
+ label="Your Audio Segment",
584
+ show_label=False,
585
+ elem_id="match2_input_seg",
586
+ )
587
  with gr.Column():
588
+ gr.Markdown(
589
+ "**Match #2 - Matched Segment**",
590
+ elem_classes=["segment-container"],
591
+ )
592
+ match2_library_segment = gr.Audio(
593
+ label="Library Segment",
594
+ show_label=False,
595
+ elem_id="match2_lib_seg",
596
+ )
597
+
598
  # Match 3 segments
599
  with gr.Row():
600
  with gr.Column():
601
+ gr.Markdown(
602
+ "**Match #3 - Your Segment**",
603
+ elem_classes=["segment-container"],
604
+ )
605
+ match3_input_segment = gr.Audio(
606
+ label="Your Audio Segment",
607
+ show_label=False,
608
+ elem_id="match3_input_seg",
609
+ )
610
  with gr.Column():
611
+ gr.Markdown(
612
+ "**Match #3 - Matched Segment**",
613
+ elem_classes=["segment-container"],
614
+ )
615
+ match3_library_segment = gr.Audio(
616
+ label="Library Segment",
617
+ show_label=False,
618
+ elem_id="match3_lib_seg",
619
+ )
620
+
621
  # Define outputs list
622
  outputs = [
623
+ match1_full, # 0
624
+ match2_full, # 1
625
+ match3_full, # 2
626
+ match1_input_segment, # 3
627
  match1_library_segment, # 4
628
+ match2_input_segment, # 5
629
  match2_library_segment, # 6
630
+ match3_input_segment, # 7
631
  match3_library_segment, # 8
632
+ results, # 9
633
  ]
634
+
635
  submit_btn.click(
636
+ fn=process_audio_for_matching, inputs=[audio_input], outputs=outputs
 
 
637
  )
638
 
639
  if __name__ == "__main__":
640
  demo.launch(
641
+ server_name="0.0.0.0",
642
+ server_port=7860,
 
643
  show_error=True,
644
+ share=False,
645
+ ssr_mode=False,
646
+ )
requirements.txt CHANGED
@@ -13,8 +13,7 @@ pretty_midi
13
  timm
14
  einops
15
  omegaconf
16
- huggingface_hub>=0.21.0,<0.25.0
17
- gradio_client==1.3.0
18
  soxr
19
  lameenc
20
  spaces
 
13
  timm
14
  einops
15
  omegaconf
16
+ huggingface_hub>=0.25.1
 
17
  soxr
18
  lameenc
19
  spaces