john221113 commited on
Commit
c858ef3
Β·
1 Parent(s): 0e8c5bf

Switch to Gradio

Browse files
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. api.py +29 -41
  3. requirements.txt +1 -3
Dockerfile CHANGED
@@ -17,4 +17,4 @@ COPY api.py ./
17
 
18
  EXPOSE 7860
19
 
20
- CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "7860"]
 
17
 
18
  EXPOSE 7860
19
 
20
+ CMD ["python", "api.py"]
api.py CHANGED
@@ -1,11 +1,11 @@
1
  """
2
- Snare Scout API - Runs on Hugging Face Spaces
3
  """
4
 
5
  import os
6
  import sys
7
- import base64
8
- import traceback
9
 
10
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
11
 
@@ -24,21 +24,6 @@ if not os.path.exists(DB_PATH):
24
  import scout
25
  scout.DEFAULT_DB_PATH = DB_PATH
26
 
27
- from fastapi import FastAPI
28
- from fastapi.middleware.cors import CORSMiddleware
29
- from fastapi.responses import JSONResponse
30
- from pydantic import BaseModel
31
-
32
- app = FastAPI(title="Snare Scout API")
33
-
34
- app.add_middleware(
35
- CORSMiddleware,
36
- allow_origins=["*"],
37
- allow_credentials=True,
38
- allow_methods=["*"],
39
- allow_headers=["*"],
40
- )
41
-
42
  print("πŸ”„ Loading AI models...")
43
  embedder = scout.get_embedder()
44
  print("πŸ”„ Loading library...")
@@ -46,37 +31,27 @@ lib = scout.load_library_matrices(DB_PATH, False)
46
  print(f"βœ… Ready! {len(lib['ids'])} clips loaded")
47
 
48
 
49
- class SearchRequest(BaseModel):
50
- audio_base64: str
51
- top_k: int = 15
52
-
53
-
54
- @app.get("/")
55
- def root():
56
- return {"status": "ok", "clips": len(lib['ids'])}
57
-
58
-
59
- @app.post("/search")
60
- async def search(req: SearchRequest):
61
- print(f"[API] Received search request, top_k={req.top_k}")
62
 
63
  try:
64
- audio_bytes = base64.b64decode(req.audio_base64)
65
- print(f"[API] Audio size: {len(audio_bytes)} bytes")
66
 
67
- if len(audio_bytes) < 100:
68
- return JSONResponse(content={"results": []})
69
 
70
- print("[API] Starting search...")
71
  results = scout.search_library(
72
  embedder,
73
  audio_bytes,
74
  lib,
75
- top_k=req.top_k,
76
  debug=False,
77
  db_path=DB_PATH
78
  )
79
- print(f"[API] Search complete, {len(results)} results")
80
 
81
  output = {
82
  "results": [
@@ -91,9 +66,22 @@ async def search(req: SearchRequest):
91
  ]
92
  }
93
 
94
- return JSONResponse(content=output)
95
 
96
  except Exception as e:
97
  print(f"[API] ERROR: {e}")
98
- print(traceback.format_exc())
99
- return JSONResponse(content={"error": str(e), "results": []})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Snare Scout API - Gradio version for HuggingFace
3
  """
4
 
5
  import os
6
  import sys
7
+ import json
8
+ import gradio as gr
9
 
10
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "src"))
11
 
 
24
  import scout
25
  scout.DEFAULT_DB_PATH = DB_PATH
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  print("πŸ”„ Loading AI models...")
28
  embedder = scout.get_embedder()
29
  print("πŸ”„ Loading library...")
 
31
  print(f"βœ… Ready! {len(lib['ids'])} clips loaded")
32
 
33
 
34
+ def search(audio_path, top_k=15):
35
+ print(f"[API] Search request: {audio_path}, top_k={top_k}")
36
+
37
+ if audio_path is None:
38
+ return json.dumps({"results": []})
 
 
 
 
 
 
 
 
39
 
40
  try:
41
+ with open(audio_path, 'rb') as f:
42
+ audio_bytes = f.read()
43
 
44
+ print(f"[API] Audio size: {len(audio_bytes)} bytes")
 
45
 
 
46
  results = scout.search_library(
47
  embedder,
48
  audio_bytes,
49
  lib,
50
+ top_k=int(top_k),
51
  debug=False,
52
  db_path=DB_PATH
53
  )
54
+ print(f"[API] Found {len(results)} results")
55
 
56
  output = {
57
  "results": [
 
66
  ]
67
  }
68
 
69
+ return json.dumps(output)
70
 
71
  except Exception as e:
72
  print(f"[API] ERROR: {e}")
73
+ return json.dumps({"error": str(e), "results": []})
74
+
75
+
76
+ demo = gr.Interface(
77
+ fn=search,
78
+ inputs=[
79
+ gr.Audio(type="filepath", label="Upload audio"),
80
+ gr.Number(value=15, label="Results")
81
+ ],
82
+ outputs=gr.Textbox(label="Results JSON"),
83
+ title="Snare Scout API",
84
+ api_name="search"
85
+ )
86
+
87
+ demo.launch(server_name="0.0.0.0", server_port=7860)
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
- fastapi
2
- uvicorn[standard]
3
- python-multipart
4
  numpy>=1.24.0,<2.0.0
5
  scipy>=1.12.0
6
  soundfile>=0.12.0
 
1
+ gradio
 
 
2
  numpy>=1.24.0,<2.0.0
3
  scipy>=1.12.0
4
  soundfile>=0.12.0