# speechnow/app/api.py
# loveisgone's picture
# Upload folder contents
# fb4f813 verified
import fastapi
import shutil
import os
import zipfile
import io
import uvicorn
import glob
from typing import List
import torch
import numpy as np
import soundfile as sf
class ModelAPI:
    """FastAPI service that enhances (denoises) uploaded WAV files.

    Workflow exposed over HTTP:

    * ``GET  /status/``            -- liveness check.
    * ``POST /prepare/``           -- load the enhancement model.
    * ``POST /upload-audio/``      -- store noisy WAV files.
    * ``POST /enhance/``           -- enhance every stored ``*.wav`` file.
    * ``GET  /download-enhanced/`` -- download the enhanced files as a zip.

    Working files live under ``~/.modelapi``; both working directories are
    emptied whenever an instance is constructed.
    """

    def __init__(self, host, port):
        """Create the working directories, empty them and build the app.

        Args:
            host: Interface passed to ``uvicorn.run`` by :meth:`run`.
            port: Port passed to ``uvicorn.run`` by :meth:`run`.
        """
        self.host = host
        self.port = port
        self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
        self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
        self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")

        # Model parameters.  ``model`` stays None until ``_prepare`` runs; it
        # is then either a SpeechBrain separator or the string "simple".
        self.model = None
        self.device = "cpu"  # Force CPU since no GPU

        # Start every run from empty input/output directories.
        for audio_path in (self.noisy_audio_path, self.enhanced_audio_path):
            os.makedirs(audio_path, exist_ok=True)
            self._clear_directory(audio_path)

        self.app = fastapi.FastAPI()
        self._setup_routes()

    @staticmethod
    def _clear_directory(directory):
        """Delete every file, symlink and sub-directory inside *directory*."""
        for entry in os.listdir(directory):
            entry_path = os.path.join(directory, entry)
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)

    @staticmethod
    def _safe_filename(filename):
        """Return the bare file name of a client-supplied *filename*.

        Directory components (with either slash style) are stripped so the
        upload can never be written outside the upload directory.  Returns
        ``None`` when nothing usable is left (empty name, ``.`` or ``..``).
        """
        if not filename:
            return None
        name = os.path.basename(filename.replace("\\", "/"))
        if name in ("", ".", ".."):
            return None
        return name

    def _prepare(self):
        """Initialize the speech enhancement model.

        Tries to load the SpeechBrain SepFormer model; on any failure the
        error is printed and ``self.model`` is set to ``"simple"`` so that
        :meth:`_enhance` uses the high-pass-filter fallback instead.
        """
        try:
            # Import here to avoid loading during container build.
            try:
                # Module path used by newer SpeechBrain releases.
                from speechbrain.inference.separation import SepformerSeparation
            except ImportError:
                from speechbrain.pretrained import SepformerSeparation

            print(f"Loading SpeechBrain model on {self.device}...")
            self.model = SepformerSeparation.from_hparams(
                source="loveisgone/sepformer-wham-enhancement-no1",
                savedir="pretrained_models/sepformer-wham",
                run_opts={"device": self.device},
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading SpeechBrain: {e}")
            # Fallback to simple denoising
            print("Using simple denoising method")
            self.model = "simple"

    @staticmethod
    def _enhance_simple(noisy_file):
        """Fallback enhancement: mono mix-down, 100 Hz high-pass, normalize.

        Returns:
            ``(samples, sample_rate)`` for the processed audio.
        """
        from scipy import signal

        audio, sr = sf.read(noisy_file)
        # Ensure mono
        if audio.ndim > 1:
            audio = np.mean(audio, axis=1)
        # 4th-order Butterworth high-pass at 100 Hz removes low-frequency noise;
        # filtfilt applies it forward and backward (zero phase shift).
        b, a = signal.butter(4, 100 / (sr / 2), "high")
        enhanced = signal.filtfilt(b, a, audio)
        # Normalize the peak to 0.95; silent input is left untouched.
        max_val = np.max(np.abs(enhanced))
        if max_val > 0:
            enhanced = enhanced / max_val * 0.95
        return enhanced, sr

    def _enhance_with_model(self, noisy_file):
        """Enhance one file with the loaded SpeechBrain model.

        Returns:
            ``(samples, sample_rate)`` for the enhanced audio.
        """
        est_sources = self.model.separate_file(path=noisy_file)
        # Per SpeechBrain's documented usage the result is laid out as
        # (batch, time, n_sources): keep the whole time axis of the first
        # source of the first item.  Indexing ``[:, 0]`` would keep only the
        # first time step.
        enhanced = est_sources[0, :, 0].detach().cpu().numpy()
        # separate_file resamples its input to the model's sample rate, so the
        # output must be written at that rate rather than the input file's.
        hparams = getattr(self.model, "hparams", None)
        sr = getattr(hparams, "sample_rate", None)
        if sr is None:
            sr = sf.info(noisy_file).samplerate
        return enhanced, int(sr)

    def _enhance(self):
        """Enhance every ``*.wav`` file in the noisy-audio directory.

        Each result is written to the enhanced-audio directory under the same
        file name.

        Raises:
            RuntimeError: If there are files to process but no model has been
                prepared yet.
            Exception: Any error raised while processing a file is printed
                and re-raised.
        """
        noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
        if noisy_files and self.model is None:
            raise RuntimeError("Model is not prepared; call the /prepare/ endpoint first.")

        for noisy_file in noisy_files:
            name = os.path.basename(noisy_file)
            try:
                if self.model == "simple":
                    enhanced, sr = self._enhance_simple(noisy_file)
                else:
                    enhanced, sr = self._enhance_with_model(noisy_file)
                # Save enhanced audio
                sf.write(os.path.join(self.enhanced_audio_path, name), enhanced, sr)
                print(f"Enhanced: {name}")
            except Exception as e:
                print(f"Error processing {noisy_file}: {e}")
                raise

    def _setup_routes(self):
        """Setup API routes"""
        self.app.get("/status/")(self.get_status)
        self.app.post("/prepare/")(self.prepare)
        self.app.post("/upload-audio/")(self.upload_audio)
        self.app.post("/enhance/")(self.enhance_audio)
        self.app.get("/download-enhanced/")(self.download_enhanced)

    async def get_status(self):
        """Liveness probe: report that the service is running."""
        return {"container_running": True}

    async def prepare(self):
        """Load the enhancement model and report whether the call succeeded.

        ``_prepare`` handles model-loading errors itself by switching to the
        simple fallback, so ``{"preparations": False, ...}`` is only returned
        for errors raised outside that handling.
        """
        try:
            self._prepare()
            return {"preparations": True}
        except Exception as e:
            return {"preparations": False, "error": str(e)}

    async def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
        """Store the uploaded files in the noisy-audio directory.

        Only the base name of each client-supplied file name is used, so a
        name such as ``../../x.wav`` cannot escape the upload directory.

        Raises:
            fastapi.HTTPException: 400 for an unusable file name, 500 if a
                file cannot be read or written.
        """
        uploaded_files = []
        for file in files:
            safe_name = self._safe_filename(file.filename)
            if safe_name is None:
                raise fastapi.HTTPException(
                    status_code=400,
                    detail=f"Invalid filename: {file.filename!r}",
                )
            try:
                file_path = os.path.join(self.noisy_audio_path, safe_name)
                content = await file.read()
                with open(file_path, "wb") as f:
                    f.write(content)
                uploaded_files.append(safe_name)
            except Exception as e:
                raise fastapi.HTTPException(
                    status_code=500,
                    detail=f"An error occurred while uploading: {e}",
                ) from e
        return {"uploaded_files": uploaded_files, "status": True}

    async def enhance_audio(self):
        """Enhance all uploaded files; any failure becomes an HTTP 500."""
        try:
            self._enhance()
            return {"status": True}
        except Exception as e:
            raise fastapi.HTTPException(
                status_code=500,
                detail=f"An error occurred while enhancing: {e}",
            ) from e

    async def download_enhanced(self):
        """Return every enhanced ``*.wav`` file bundled into one zip archive.

        Raises:
            fastapi.HTTPException: 500 if the archive cannot be built.
        """
        try:
            zip_buffer = io.BytesIO()
            with zipfile.ZipFile(zip_buffer, "w") as zip_file:
                # Sorted so the archive's member order is deterministic.
                for wav_file in sorted(glob.glob(os.path.join(self.enhanced_audio_path, "*.wav"))):
                    zip_file.write(wav_file, arcname=os.path.basename(wav_file))
            return fastapi.responses.StreamingResponse(
                iter([zip_buffer.getvalue()]),
                media_type="application/zip",
                headers={"Content-Disposition": "attachment; filename=enhanced_audio_files.zip"},
            )
        except Exception as e:
            raise fastapi.HTTPException(
                status_code=500,
                detail=f"An error occurred: {str(e)}",
            ) from e

    def run(self):
        """Serve the application with uvicorn on the configured host/port."""
        uvicorn.run(self.app, host=self.host, port=self.port)