# darkmesh / app.py
# Last commit: f4acf9c "Update app.py" by r0kaxmin (verified)
import base64
import os
import tempfile
from typing import Optional

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from faster_whisper import WhisperModel
import gradio as gr
from pydantic import BaseModel
import uvicorn  # ASGI server used to run the app in the __main__ block
# --- 1. Configuration & Model Loading ---
# Directory where downloaded model weights are cached.
HF_CACHE_DIR = "/tmp/huggingface_cache"
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["HF_HUB_CACHE"] = HF_CACHE_DIR

# NOTE: huggingface_hub resolves its cache location from the variables above
# when it is imported, and it has likely already been imported via the
# faster_whisper / gradio imports at the top of this file, so setting them
# here may not take effect. Passing download_root makes WhisperModel use
# HF_CACHE_DIR regardless of import order.
model = WhisperModel(
    "Systran/faster-whisper-small",
    device="cpu",
    compute_type="int8",
    download_root=HF_CACHE_DIR,
)
# --- 2. FastAPI Application Setup ---
app = FastAPI()


class AudioInput(BaseModel):
    """Request body for POST /predict, mirroring Gradio's API payload shape.

    Only ``data[0]`` is consumed by the endpoint; it should hold a base64
    (data-URI) encoded audio clip. ``None`` items are accepted so that a
    ``{"data": [null]}`` request reaches the endpoint's "no audio" branch
    instead of failing request validation.
    """

    data: list[Optional[str]]
def transcribe_audio(audio_filepath, language):
    """Transcribe an audio file with the module-level Whisper model.

    Args:
        audio_filepath: Path to the audio file, or None when nothing was
            uploaded.
        language: Language code passed to the model (e.g. "en", "bn");
            "auto" or an empty/None value lets the model detect the language.

    Returns:
        The transcription as a single space-separated string, or an error
        message string when no file was provided.
    """
    if audio_filepath is None:
        return "Error: No audio file provided."
    # An unset language (e.g. no radio selection) also means auto-detection.
    lang = None if not language or language == "auto" else language
    segments, _ = model.transcribe(
        audio_filepath, beam_size=5, language=lang, vad_filter=True
    )
    # Segment texts typically begin with a space; strip each one (and skip
    # empty segments) so the joined result has single spaces and no leading
    # whitespace.
    texts = (segment.text.strip() for segment in segments)
    return " ".join(text for text in texts if text)
# --- 3. Create the API Endpoint ---
def _decode_audio_payload(payload):
    """Decode a base64 audio payload into raw bytes.

    Accepts a data URI such as ``data:audio/wav;base64,<data>`` (everything
    after the first comma is decoded) or a bare base64 string with no comma.

    Raises:
        ValueError: if the payload is not valid base64 (``binascii.Error``
            is a subclass of ``ValueError``).
    """
    encoded_data = payload.split(",", 1)[-1]
    return base64.b64decode(encoded_data)


@app.post("/predict")
async def predict(audio_input: AudioInput):
    """Transcribe one base64-encoded audio clip sent in Gradio's API format.

    Expects ``{"data": ["<data URI or base64>", ...]}``; only the first item
    is used and the language is always auto-detected.

    Returns:
        ``{"data": [text]}`` where ``text`` is the transcription or an error
        message.
    """
    # The Gradio API sends data in a list, so we use the first item. An empty
    # list or a null item (e.g. the curl test case) means no audio was sent.
    base64_data_uri = audio_input.data[0] if audio_input.data else None
    if base64_data_uri is None:
        return {"data": ["Error: No audio file provided."]}

    try:
        audio_data = _decode_audio_payload(base64_data_uri)
    except ValueError:
        return {"data": ["Error: Audio data is not valid base64."]}
    if not audio_data:
        return {"data": ["Error: No audio file provided."]}

    # delete=False so the file can be reopened by path after it is closed;
    # the finally clause removes it even if writing or transcription fails.
    temp_audio_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with temp_audio_file:
            temp_audio_file.write(audio_data)
        # Transcription is blocking CPU work; run it in a worker thread so the
        # event loop (which also serves the mounted Gradio UI) stays responsive.
        transcription = await run_in_threadpool(
            transcribe_audio, temp_audio_file.name, "auto"
        )
    finally:
        os.remove(temp_audio_file.name)
    return {"data": [transcription]}
# --- 4. Create the Gradio User Interface ---
# Browser UI around the same transcription function used by /predict; the
# radio value is passed as transcribe_audio's `language` argument.
iface = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        gr.Radio(["en", "bn", "auto"], label="Select Language", value="auto"),
    ],
    outputs="text",
    title="⚑ Zen Speech-to-Text (API Fixed)",
    # Restored "→" (the text previously held the mis-decoded "β†’").
    description="Upload audio → get transcription",
)
# --- 5. Mount the Gradio UI onto the FastAPI App ---
# The Gradio UI is served from "/" on the same FastAPI app that already
# registered the /predict route above.
app = gr.mount_gradio_app(app, iface, path="/")


def _serve():
    """Run the combined FastAPI + Gradio app with uvicorn.

    Listens on all interfaces; the port is read from the PORT environment
    variable and defaults to 7860.
    """
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)


# --- 6. Run the Server when executed as a script ---
if __name__ == "__main__":
    _serve()