mateo496 commited on
Commit
126f215
·
1 Parent(s): e451667

Server changes for deployment

Browse files
Dockerfile CHANGED
@@ -33,6 +33,6 @@ COPY . .
33
  ENV PYTHONPATH=/app
34
  ENV PYTHONUNBUFFERED=1
35
 
36
- EXPOSE 8000
37
 
38
- CMD ["uvicorn", "src.app.server:app", "--host", "0.0.0.0", "--port", "8000"]
 
33
  ENV PYTHONPATH=/app
34
  ENV PYTHONUNBUFFERED=1
35
 
36
+ EXPOSE 7860
37
 
38
+ CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "7860"]
src/app/index.html → index.html RENAMED
@@ -293,7 +293,7 @@
293
  let selectedFile = null;
294
 
295
  // API endpoint - update this to your server URL
296
- const API_URL = 'http://192.168.1.12:8082';
297
 
298
  // Click to upload
299
  uploadArea.addEventListener('click', () => fileInput.click());
 
293
  let selectedFile = null;
294
 
295
  // API endpoint - update this to your server URL
296
+ const API_URL = 'https://mateo496-esc50-model.hf.space/';
297
 
298
  // Click to upload
299
  uploadArea.addEventListener('click', () => fileInput.click());
src/app/server.py → server.py RENAMED
@@ -1,16 +1,18 @@
1
- import uvicorn
2
- import torch
3
  import tempfile
4
- import os
5
 
6
- from pydub import AudioSegment
7
- from fastapi import FastAPI, File, UploadFile, HTTPException
 
8
  from fastapi.responses import FileResponse
9
  from fastapi.staticfiles import StaticFiles
 
 
10
 
11
- from src.models.predict import load_model, predict_file
12
- from src.config.config import esc50_labels
13
 
 
14
 
15
  app = FastAPI(
16
  title="ESC50 Audio Classifier API",
@@ -18,17 +20,20 @@ app = FastAPI(
18
  version="1.0.0",
19
  )
20
 
 
 
 
 
 
 
21
 
22
-
23
- model = None
24
- device = None
25
- model_path="models/cnn/saved/final_model.pt"
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
- model = load_model(model_path, device)
 
28
 
29
  @app.get("/")
30
  async def root():
31
- return FileResponse("src/app/index.html")
32
 
33
  @app.get("/api/status")
34
  async def status():
@@ -38,40 +43,40 @@ async def status():
38
 
39
  @app.post("/predict-top-k")
40
  async def predict_top_k(file: UploadFile = File(...), k: int = 5):
41
- if model is None:
42
  raise HTTPException(status_code=503, detail="Model not loaded")
43
-
 
 
 
 
 
 
44
  try:
45
- with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp:
46
- content = await file.read()
47
- tmp.write(content)
48
- tmp_path = tmp.name
49
 
50
- tmp_wav_path = tempfile.mktemp(suffix=".wav")
51
- audio = AudioSegment.from_file(tmp_path)
52
- audio.export(tmp_wav_path, format="wav")
 
 
 
53
 
54
- predicted_class, top_probs, top_indices = predict_file(
55
- model, tmp_path, device=device, top_k=k
56
- )
57
 
58
- os.unlink(tmp_path)
59
-
60
  return {
61
- 'top_predictions': [
62
- {'class': esc50_labels[idx], 'confidence': float(prob)}
 
 
63
  for prob, idx in zip(top_probs, top_indices)
64
  ],
65
- 'predicted_class': esc50_labels[predicted_class],
66
- 'confidence': float(top_probs[0])
67
  }
68
-
69
- except Exception as e:
70
- raise HTTPException(status_code=400, detail=str(e))
71
-
72
- # uvicorn.run(
73
- # app,
74
- # host="0.0.0.0",
75
- # port=8000,
76
- # log_level="info"
77
- # )
 
1
+ import os
 
2
  import tempfile
 
3
 
4
+ import torch
5
+ import uvicorn
6
+ from fastapi import FastAPI, File, HTTPException, UploadFile
7
  from fastapi.responses import FileResponse
8
  from fastapi.staticfiles import StaticFiles
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from pydub import AudioSegment
11
 
12
+ from src.config.config import DatasetConfig
13
+ from src.models.predict import AudioPredictor
14
 
15
+ dataset_cfg = DatasetConfig()
16
 
17
  app = FastAPI(
18
  title="ESC50 Audio Classifier API",
 
20
  version="1.0.0",
21
  )
22
 
23
+ app.add_middleware(
24
+ CORSMiddleware,
25
+ allow_origins=["*"],
26
+ allow_methods=["GET", "POST"],
27
+ allow_headers=["*"],
28
+ )
29
 
 
 
 
 
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
31
+ predictor = AudioPredictor("models/cnn/saved/final_model.pt", device=device)
32
+
33
 
34
  @app.get("/")
35
  async def root():
36
+ return FileResponse("index.html")
37
 
38
  @app.get("/api/status")
39
  async def status():
 
43
 
44
  @app.post("/predict-top-k")
45
  async def predict_top_k(file: UploadFile = File(...), k: int = 5):
46
+ if predictor is None:
47
  raise HTTPException(status_code=503, detail="Model not loaded")
48
+
49
+ suffix = os.path.splitext(file.filename)[1]
50
+
51
+ with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
52
+ tmp.write(await file.read())
53
+ tmp_path = tmp.name
54
+
55
  try:
56
+ wav_path = tempfile.mktemp(suffix=".wav")
 
 
 
57
 
58
+ print("[1] Converting to wav...")
59
+ AudioSegment.from_file(tmp_path).export(wav_path, format="wav")
60
+ print("[2] Running inference...")
61
+ predicted_class, top_probs, top_indices = predictor.predict_file(wav_path, top_k=k)
62
+ print(f"[3] Done: {predicted_class} = {dataset_cfg.esc50_labels[predicted_class]}")
63
+ AudioSegment.from_file(tmp_path).export(wav_path, format="wav")
64
 
65
+ predicted_class, top_probs, top_indices = predictor.predict_file(wav_path, top_k=k)
 
 
66
 
 
 
67
  return {
68
+ "predicted_class": dataset_cfg.esc50_labels[predicted_class],
69
+ "confidence": float(top_probs[0]),
70
+ "top_predictions": [
71
+ {"class": dataset_cfg.esc50_labels[idx], "confidence": float(prob)}
72
  for prob, idx in zip(top_probs, top_indices)
73
  ],
 
 
74
  }
75
+ finally:
76
+ os.unlink(tmp_path)
77
+ if os.path.exists(wav_path):
78
+ os.unlink(wav_path)
79
+
80
+
81
+ if __name__ == "__main__":
82
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")