# snare_scout / api.py
# Committed by: john221113
# Last commit: "Copy PANNs to multiple locations" (94bde3d)
"""
Snare Scout API - Minimal Gradio
"""
import os
import sys
import gradio as gr
# Put this script's own directory first on the module search path, presumably
# so the sibling `scout` module imported further down resolves no matter what
# the current working directory is.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _APP_DIR)
# Download PANNs checkpoint to current directory
PANNS_CHECKPOINT = "Cnn14_mAP=0.431.pth"
if not os.path.exists(PANNS_CHECKPOINT):
    print("πŸ“₯ Downloading PANNs checkpoint (~320MB)...")
    import urllib.request
    # Download under a temporary name and rename only on success. Writing
    # straight to PANNS_CHECKPOINT meant an interrupted or short transfer
    # (urlretrieve raises ContentTooShortError) left a truncated file that the
    # os.path.exists() check above would accept on every later start.
    _panns_tmp = PANNS_CHECKPOINT + ".part"
    try:
        urllib.request.urlretrieve(
            "https://zenodo.org/record/3987831/files/Cnn14_mAP%3D0.431.pth?download=1",
            _panns_tmp,
        )
        os.replace(_panns_tmp, PANNS_CHECKPOINT)
    finally:
        # Only present if the download or rename failed; the error still propagates.
        if os.path.exists(_panns_tmp):
            os.remove(_panns_tmp)
    print(f"βœ… PANNs checkpoint ready! Size: {os.path.getsize(PANNS_CHECKPOINT)}")
else:
    print(f"βœ… PANNs checkpoint exists: {os.path.getsize(PANNS_CHECKPOINT)} bytes")
# Also copy to where panns_inference might look
# NOTE(review): hard-coded /root path; requires /root to be writable.
_PANNS_DATA_COPY = "/root/panns_data/Cnn14_mAP=0.431.pth"
os.makedirs("/root/panns_data", exist_ok=True)
if not os.path.exists(_PANNS_DATA_COPY):
    import shutil
    # Same temp-then-rename pattern, so an interrupted copy is retried next
    # start instead of leaving a truncated checkpoint behind.
    _copy_tmp = _PANNS_DATA_COPY + ".part"
    try:
        shutil.copy(PANNS_CHECKPOINT, _copy_tmp)
        os.replace(_copy_tmp, _PANNS_DATA_COPY)
    finally:
        if os.path.exists(_copy_tmp):
            os.remove(_copy_tmp)
    print("βœ… Copied to /root/panns_data/")
# Fetch the SQLite library database on first start; later starts reuse it.
DB_PATH = "library/snare_scout.sqlite"
if not os.path.exists(DB_PATH):
    import requests
    os.makedirs("library", exist_ok=True)
    print("πŸ“₯ Downloading database...")
    url = "https://www.dropbox.com/scl/fi/x6ybzdfak4m6q30k3xkec/snare_scout.sqlite?rlkey=cku45y0c7q06hf7y0ca0ktv3e&dl=1"
    # Stream into a temporary file and rename only after a successful status
    # and a complete body. Previously an HTTP error page or an interrupted
    # transfer was written to DB_PATH and, because of the os.path.exists()
    # guard above, never re-downloaded.
    _db_tmp = DB_PATH + ".part"
    try:
        # timeout bounds the connect and each stalled read (not total time);
        # the context manager releases the streamed connection.
        with requests.get(url, stream=True, timeout=60) as response:
            response.raise_for_status()
            with open(_db_tmp, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        os.replace(_db_tmp, DB_PATH)
    finally:
        # Only present on failure; the original exception still propagates.
        if os.path.exists(_db_tmp):
            os.remove(_db_tmp)
    print("βœ… Database ready!")
import scout

# Override scout's module-level default database path with the local copy.
scout.DEFAULT_DB_PATH = DB_PATH

# One-time start-up work shared by every request: build the embedder, then
# load the clip library from DB_PATH. The second positional argument (False)
# is a scout-defined flag — NOTE(review): confirm its meaning in scout.
print("πŸ”„ Loading AI models...")
embedder = scout.get_embedder()
print("πŸ”„ Loading library...")
lib = scout.load_library_matrices(DB_PATH, False)
clip_count = len(lib["ids"])
print(f"βœ… Ready! {clip_count} clips loaded")
def search(audio_file):
    """Search the loaded library for clips similar to an uploaded audio file.

    Args:
        audio_file: Filesystem path supplied by ``gr.Audio(type="filepath")``,
            or ``None`` when nothing was uploaded.

    Returns:
        One ``id|url|t0|title`` row per match joined by newlines;
        ``"No matches found"`` if the search yields nothing;
        ``"No audio uploaded"`` for a missing upload; or ``"Error: <msg>"``
        if reading, searching, or formatting raises (the traceback is printed).
    """
    print(f"[API] Got request: {audio_file}")
    if audio_file is None:
        return "No audio uploaded"
    try:
        with open(audio_file, 'rb') as handle:
            payload = handle.read()
        print(f"[API] Audio size: {len(payload)} bytes")
        matches = scout.search_library(
            embedder,
            payload,
            lib,
            top_k=10,
            debug=True,
            db_path=DB_PATH,
        )
        print(f"[API] Found {len(matches)} results")
        rows = [f"{m['id']}|{m['url']}|{m['t0']}|{m['title']}" for m in matches]
        if not rows:
            return "No matches found"
        return "\n".join(rows)
    except Exception as exc:
        # Request boundary: report the failure to the client instead of
        # crashing the handler, and keep the full traceback in the logs.
        print(f"[API] ERROR: {exc}")
        import traceback
        traceback.print_exc()
        return f"Error: {exc}"
# Single-endpoint UI: an uploaded audio file (passed to `search` as a path)
# in, the pipe-delimited result text out.
audio_input = gr.Audio(type="filepath")
results_output = gr.Textbox()
demo = gr.Interface(
    fn=search,
    inputs=audio_input,
    outputs=results_output,
    title="Snare Scout",
)
# Listen on all interfaces at the fixed port 7860.
demo.launch(server_name="0.0.0.0", server_port=7860)