# speechnow / update_api.py
# Author: loveisgone -- commit fb4f813 ("Upload folder contents", verified)
import os

# New contents for app/api.py.
# This text is written verbatim to disk, so it must be a complete, correctly
# indented Python module on its own.
api_content = '''import glob
import io
import os
import shutil
import zipfile
from typing import List

import fastapi
import numpy as np
import soundfile as sf
import torch
import uvicorn
from fastapi.concurrency import run_in_threadpool


class ModelAPI:
    """FastAPI service that enhances (denoises) uploaded WAV files.

    Workflow: POST /prepare/ loads the model, POST /upload-audio/ stores one or
    more WAV files, POST /enhance/ processes them, and GET /download-enhanced/
    returns the results as a ZIP archive.
    """

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.base_path = os.path.join(os.path.expanduser("~"), ".modelapi")
        self.noisy_audio_path = os.path.join(self.base_path, "noisy_audio")
        self.enhanced_audio_path = os.path.join(self.base_path, "enhanced_audio")
        # Model state: None until _prepare() runs, then either a SpeechBrain
        # separator object or the string "simple" (high-pass filter fallback).
        self.model = None
        self.device = "cpu"  # Force CPU since no GPU is available
        # Start every session from empty working directories.
        for audio_path in (self.noisy_audio_path, self.enhanced_audio_path):
            os.makedirs(audio_path, exist_ok=True)
            self._clear_directory(audio_path)
        self.app = fastapi.FastAPI()
        self._setup_routes()

    @staticmethod
    def _clear_directory(path):
        """Delete every file, symlink and subdirectory inside *path*."""
        for entry in os.listdir(path):
            entry_path = os.path.join(path, entry)
            if os.path.isfile(entry_path) or os.path.islink(entry_path):
                os.unlink(entry_path)
            elif os.path.isdir(entry_path):
                shutil.rmtree(entry_path)

    def _prepare(self):
        """Load the SpeechBrain enhancement model, or fall back to a simple filter."""
        try:
            # Imported lazily so building the container does not load SpeechBrain.
            try:
                from speechbrain.inference.separation import SepformerSeparation
            except ImportError:
                # Location used by SpeechBrain releases before 1.0.
                from speechbrain.pretrained import SepformerSeparation
            print(f"Loading SpeechBrain model on {self.device}...")
            self.model = SepformerSeparation.from_hparams(
                source="loveisgone/sepformer-wham-enhancement-no1",
                savedir="pretrained_models/sepformer-wham",
                run_opts={"device": self.device},
            )
            print("Model loaded successfully!")
        except Exception as e:
            print(f"Error loading SpeechBrain: {e}")
            # Fallback to simple denoising
            print("Using simple denoising method")
            self.model = "simple"

    @staticmethod
    def _simple_denoise(noisy_file):
        """Return (audio, sample_rate) after a 100 Hz high-pass and peak normalization."""
        from scipy import signal

        audio, sr = sf.read(noisy_file)
        if audio.ndim > 1:  # Down-mix multi-channel input to mono
            audio = np.mean(audio, axis=1)
        # 4th-order Butterworth high-pass at 100 Hz removes low-frequency noise.
        b, a = signal.butter(4, 100 / (sr / 2), "high")
        enhanced = signal.filtfilt(b, a, audio)
        max_val = np.max(np.abs(enhanced))
        if max_val > 0:
            enhanced = enhanced / max_val * 0.95
        return enhanced, sr

    def _model_denoise(self, noisy_file):
        """Return (audio, sample_rate) for the model's first estimated source."""
        est_sources = self.model.separate_file(path=noisy_file)
        # separate_file returns a (batch, time, n_sources) tensor; keep the single
        # batch item's first source as a 1-D signal (indexing [:, 0] would keep
        # only the first time step).
        enhanced = est_sources[0, :, 0].detach().cpu().numpy()
        # separate_file resamples the input to the model's rate, so the output
        # must be written at that rate rather than the original file's.
        sr = getattr(getattr(self.model, "hparams", None), "sample_rate", None)
        if sr is None:
            sr = sf.info(noisy_file).samplerate
        return enhanced, sr

    def _enhance(self):
        """Enhance every noisy WAV file into the enhanced directory (same names).

        Raises the first per-file error after logging it.
        """
        if self.model is None:
            # /enhance/ called before /prepare/: load now instead of failing on None.
            self._prepare()
        noisy_files = sorted(glob.glob(os.path.join(self.noisy_audio_path, "*.wav")))
        for noisy_file in noisy_files:
            name = os.path.basename(noisy_file)
            try:
                if self.model == "simple":
                    enhanced, sr = self._simple_denoise(noisy_file)
                else:
                    enhanced, sr = self._model_denoise(noisy_file)
                sf.write(os.path.join(self.enhanced_audio_path, name), enhanced, sr)
                print(f"Enhanced: {name}")
            except Exception as e:
                print(f"Error processing {noisy_file}: {e}")
                raise

    def _setup_routes(self):
        """Setup API routes"""
        self.app.get("/status/")(self.get_status)
        self.app.post("/prepare/")(self.prepare)
        self.app.post("/upload-audio/")(self.upload_audio)
        self.app.post("/enhance/")(self.enhance_audio)
        self.app.get("/download-enhanced/")(self.download_enhanced)

    async def get_status(self):
        """Report that the service is running."""
        return {"container_running": True}

    async def prepare(self):
        """Load the model in a worker thread so the event loop stays responsive."""
        try:
            await run_in_threadpool(self._prepare)
            return {"preparations": True}
        except Exception as e:
            return {"preparations": False, "error": str(e)}

    async def upload_audio(self, files: List[fastapi.UploadFile] = fastapi.File(...)):
        """Save uploaded files into the noisy-audio directory.

        Only the base name of each client-supplied filename is used, so names
        such as "../../x" cannot write outside the upload directory.
        """
        uploaded_files = []
        for file in files:
            safe_name = os.path.basename(file.filename or "")
            if safe_name in ("", ".", ".."):
                raise fastapi.HTTPException(
                    status_code=400, detail=f"Invalid file name: {file.filename!r}"
                )
            try:
                content = await file.read()
                with open(os.path.join(self.noisy_audio_path, safe_name), "wb") as f:
                    f.write(content)
            except Exception as e:
                raise fastapi.HTTPException(
                    status_code=500, detail=f"An error occurred while uploading: {e}"
                )
            uploaded_files.append(safe_name)
        return {"uploaded_files": uploaded_files, "status": True}

    async def enhance_audio(self):
        """Enhance all uploaded files in a worker thread."""
        try:
            await run_in_threadpool(self._enhance)
            return {"status": True}
        except Exception as e:
            raise fastapi.HTTPException(
                status_code=500, detail=f"An error occurred while enhancing: {e}"
            )

    async def download_enhanced(self):
        """Return all enhanced WAV files as one ZIP attachment."""
        try:
            zip_buffer = io.BytesIO()
            with zipfile.ZipFile(zip_buffer, "w") as zip_file:
                for wav_file in glob.glob(os.path.join(self.enhanced_audio_path, "*.wav")):
                    zip_file.write(wav_file, arcname=os.path.basename(wav_file))
            return fastapi.responses.StreamingResponse(
                iter([zip_buffer.getvalue()]),
                media_type="application/zip",
                headers={"Content-Disposition": "attachment; filename=enhanced_audio_files.zip"},
            )
        except Exception as e:
            raise fastapi.HTTPException(status_code=500, detail=f"An error occurred: {str(e)}")

    def run(self):
        """Serve the API with uvicorn on the configured host and port."""
        uvicorn.run(self.app, host=self.host, port=self.port)
'''

# Create app/ if needed so the script also works from a fresh checkout.
os.makedirs("app", exist_ok=True)
with open('app/api.py', 'w', encoding='utf-8') as f:
    f.write(api_content)
print("API updated successfully!")