"""Speech-to-text service backed by faster-whisper.

Exposes a JSON ``POST /predict`` endpoint (FastAPI) and a Gradio web UI
mounted at ``/`` on the same application.
"""

import base64
import binascii
import os
import tempfile
from typing import Optional

# --- 1. Configuration & Model Loading ---
# The cache location has to be exported BEFORE anything that pulls in
# huggingface_hub is imported: the hub library resolves its cache paths when
# it is first imported, so setting these variables afterwards has no effect.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
os.environ["HF_HUB_CACHE"] = "/tmp/huggingface_cache"

import gradio as gr  # noqa: E402
import uvicorn  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from faster_whisper import WhisperModel  # noqa: E402
from pydantic import BaseModel  # noqa: E402

# Loaded once at import time and shared by the API endpoint and the UI.
model = WhisperModel(
    "Systran/faster-whisper-small", device="cpu", compute_type="int8"
)

# --- 2. FastAPI Application Setup ---
app = FastAPI()

# Message returned (inside the normal response payload) when no audio arrives.
_NO_AUDIO_ERROR = "Error: No audio file provided."


class AudioInput(BaseModel):
    """Request body for ``/predict``.

    ``data`` is a list whose first element is a base64 data URI
    (``<header>,<base64>``), a bare base64 string, or ``null``. Elements
    after the first are ignored.
    """

    # Optional[...] so that an explicit JSON null reaches the handler instead
    # of being rejected by request validation.
    data: list[Optional[str]]


def transcribe_audio(audio_filepath, language):
    """Transcribe the audio file at *audio_filepath* and return its text.

    Args:
        audio_filepath: Path to an audio file, or ``None`` when nothing was
            supplied.
        language: Language code passed to the model; the value ``"auto"`` is
            translated to ``None`` before the call.

    Returns:
        The text of every segment joined with single spaces, or an error
        message string when *audio_filepath* is ``None``.
    """
    if audio_filepath is None:
        return _NO_AUDIO_ERROR

    lang = None if language == "auto" else language
    segments, _ = model.transcribe(
        audio_filepath, beam_size=5, language=lang, vad_filter=True
    )
    return " ".join(seg.text for seg in segments)


# --- 3. Create the API Endpoint ---
# Declared with plain ``def`` (not ``async def``): transcription is blocking,
# CPU-bound work, and FastAPI runs synchronous handlers in a worker thread so
# the event loop (and the mounted UI) stays responsive meanwhile.
@app.post("/predict")
def predict(audio_input: AudioInput):
    """Transcribe base64-encoded audio sent as ``{"data": [<payload>]}``.

    Returns:
        ``{"data": [text]}`` where ``text`` is the transcription, or an
        ``"Error: ..."`` message when the payload is missing, empty, or not
        valid base64. Failures raised by the transcription itself propagate.
    """
    # Only the first list entry is used; tolerate an empty list, a null
    # entry, and an empty string alike.
    payload = audio_input.data[0] if audio_input.data else None
    if not payload:
        return {"data": [_NO_AUDIO_ERROR]}

    # A data URI carries its base64 body after the first comma; a payload
    # without a comma is treated as bare base64.
    _header, separator, encoded_data = payload.partition(",")
    if not separator:
        encoded_data = payload

    try:
        audio_data = base64.b64decode(encoded_data)
    except (binascii.Error, ValueError):
        return {"data": ["Error: Audio data is not valid base64."]}

    # delete=False because the model reopens the file by path; the finally
    # clause removes it even if writing or transcription fails.
    temp_audio_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with temp_audio_file:
            temp_audio_file.write(audio_data)
        transcription = transcribe_audio(temp_audio_file.name, "auto")
    finally:
        os.remove(temp_audio_file.name)

    return {"data": [transcription]}


# --- 4. Create the Gradio User Interface ---
iface = gr.Interface(
    fn=transcribe_audio,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio File"),
        gr.Radio(['en', 'bn', 'auto'], label="Select Language", value='auto'),
    ],
    outputs="text",
    title="⚡ Zen Speech-to-Text (API Fixed)",
    description="Upload audio → get transcription",
)

# --- 5. Mount the Gradio UI onto the FastAPI App ---
app = gr.mount_gradio_app(app, iface, path="/")

# --- 6. Run the Server ---
if __name__ == "__main__":
    # PORT comes from the environment when set; 7860 is the fallback.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))