# wazobiaa / app.py
# Author: Pyrexx611
# Commit: "Update app.py" (145e354, verified)
import base64
import logging
import os
import tempfile
from contextlib import asynccontextmanager
from io import BytesIO

import gradio as gr
import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from gtts import gTTS
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
# Configure root logging once at import time; every handler below logs
# through this module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("VeriChat-API")
class TranslationEngine:
    """Wraps the Whisper ASR pipeline and the NLLB translation model.

    The models are heavyweight, so they are loaded once via load_models()
    (invoked from the app's lifespan hook) rather than in __init__.
    """

    def __init__(self):
        # Prefer GPU when available; transformers accepts this device string.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.whisper_model = "openai/whisper-tiny"
        self.nllb_model = "facebook/nllb-200-distilled-600M"
        self.transcriber = None
        self.translator_model = None
        self.translator_tokenizer = None
        # Cache of translation pipelines keyed by (src_lang, tgt_lang), so a
        # pipeline object is not rebuilt on every request (the original code
        # constructed one per translate() call).
        self._translation_pipes = {}

    def load_models(self):
        """Load the ASR pipeline plus the translation model/tokenizer."""
        logger.info(f"Loading models on {self.device}...")
        self.transcriber = pipeline("automatic-speech-recognition", model=self.whisper_model, device=self.device)
        self.translator_tokenizer = AutoTokenizer.from_pretrained(self.nllb_model)
        self.translator_model = AutoModelForSeq2SeqLM.from_pretrained(self.nllb_model).to(self.device)

    def translate(self, text: str, src_lang: str, tgt_lang: str) -> str:
        """Translate `text` between NLLB language codes (e.g. 'eng_Latn')."""
        key = (src_lang, tgt_lang)
        pipe = self._translation_pipes.get(key)
        if pipe is None:
            # Reuses the already-loaded model/tokenizer; only the lightweight
            # pipeline wrapper is created here, and only once per language pair.
            pipe = pipeline("translation", model=self.translator_model, tokenizer=self.translator_tokenizer,
                            src_lang=src_lang, tgt_lang=tgt_lang, device=self.device, max_length=400)
            self._translation_pipes[key] = pipe
        return pipe(text)[0]['translation_text']

    def generate_tts(self, text: str, lang_code: str) -> str:
        """Generates a base64 audio string for the translated text.

        Returns "" on any failure so callers can degrade gracefully.
        """
        try:
            # Strip NLLB suffix for gTTS compatibility (e.g., 'fra_Latn' -> 'fr')
            # NOTE(review): truncating the ISO 639-3 code to two letters is a
            # heuristic — it is wrong for some languages (e.g. 'ibo' -> 'ib',
            # but gTTS expects 'ig'); verify for the languages this app serves.
            clean_lang = lang_code.split('_')[0][:2]
            tts = gTTS(text=text, lang=clean_lang)
            fp = BytesIO()
            tts.write_to_fp(fp)
            return base64.b64encode(fp.getvalue()).decode()
        except Exception as e:
            logger.warning(f"TTS Failed: {e}")
            return ""
# Single shared engine instance; model loading is deferred to app startup.
engine = TranslationEngine()

@asynccontextmanager
async def lifespan(app: FastAPI):
    # Load the heavy ML models once, before the server accepts requests.
    engine.load_models()
    yield
app = FastAPI(lifespan=lifespan)

# 1. Mount the static directory for JS/CSS assets
# This assumes your build files (assets folder) are inside /static
app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")

# 2. Add a route to serve the main index.html
@app.get("/")
async def read_index():
    # Serve the single-page-app entry point from the static build output.
    return FileResponse("static/index.html")

# NOTE(review): allow_origins=["*"] leaves CORS wide open — fine for a public
# demo Space, but tighten to the real frontend origin for production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
@app.post("/process-speech")
async def process_speech(audio: UploadFile = File(...), src_lang: str = Form(...), tgt_lang: str = Form(...)):
    """Transcribe an uploaded audio clip, translate it, and synthesize speech.

    Returns JSON with the translated text (`transcript`), the raw
    transcription (`original_transcript`), and a base64 MP3 (`audio_payload`).
    Raises HTTP 400 when no speech is detected in the upload.
    """
    # Write the upload to a temp file whose name is NOT derived from the
    # client-supplied filename: the original f"temp_{audio.filename}" allowed
    # path traversal (e.g. "../x"), collided across concurrent requests, and
    # broke when filename was None. Only the extension is kept so the ASR
    # pipeline can still sniff the container format.
    suffix = os.path.splitext(audio.filename or "")[1]
    fd, temp_filename = tempfile.mkstemp(prefix="temp_", suffix=suffix)
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(await audio.read())
        # 1. Transcribe
        ts_result = engine.transcriber(temp_filename)
        transcript = ts_result.get("text", "").strip()
        if not transcript:
            raise HTTPException(status_code=400, detail="No speech detected")
        # 2. Translate (skip when source and target languages match)
        translated_text = transcript
        if src_lang != tgt_lang:
            translated_text = engine.translate(transcript, src_lang, tgt_lang)
        # 3. Generate TTS for the translation
        audio_base64 = engine.generate_tts(translated_text, tgt_lang)
        return {
            "transcript": translated_text,
            "original_transcript": transcript,
            "audio_payload": audio_base64,
        }
    finally:
        # Always clean up the temp file, even on transcription/translation errors.
        if os.path.exists(temp_filename):
            os.remove(temp_filename)
if __name__ == "__main__":
    # Local dev entry point; port 7860 is the Hugging Face Spaces default.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)