import base64
import logging
import os
import tempfile
from contextlib import asynccontextmanager
from io import BytesIO

import gradio as gr
import torch
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from gtts import gTTS
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger("VeriChat-API") |
|
|
class TranslationEngine:
    """Bundles Whisper ASR, NLLB translation, and gTTS speech synthesis.

    Models are loaded lazily via :meth:`load_models` so that importing this
    module stays cheap; call it once at application startup.
    """

    # NLLB language codes use ISO 639-3 prefixes ("spa_Latn") while gTTS
    # expects ISO 639-1 ("es"). Naively slicing the first two letters breaks
    # for many languages ("spa" -> "sp"), so map the common mismatches
    # explicitly and fall back to the slice for unknown codes.
    _GTTS_LANG_MAP = {
        "eng": "en", "spa": "es", "fra": "fr", "deu": "de", "ita": "it",
        "por": "pt", "rus": "ru", "zho": "zh", "jpn": "ja", "kor": "ko",
        "arb": "ar", "hin": "hi", "nld": "nl", "pol": "pl", "tur": "tr",
    }

    def __init__(self):
        # Prefer the first GPU when available; all models share this device.
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.whisper_model = "openai/whisper-tiny"
        self.nllb_model = "facebook/nllb-200-distilled-600M"
        self.transcriber = None
        self.translator_model = None
        self.translator_tokenizer = None
        # Translation pipelines cached per (src_lang, tgt_lang): building a
        # pipeline wrapper on every request is needlessly expensive.
        self._translation_pipes = {}

    def load_models(self):
        """Load ASR and translation models onto the configured device (call once)."""
        logger.info(f"Loading models on {self.device}...")
        self.transcriber = pipeline(
            "automatic-speech-recognition", model=self.whisper_model, device=self.device
        )
        self.translator_tokenizer = AutoTokenizer.from_pretrained(self.nllb_model)
        self.translator_model = AutoModelForSeq2SeqLM.from_pretrained(self.nllb_model).to(self.device)

    def translate(self, text: str, src_lang: str, tgt_lang: str) -> str:
        """Translate *text* between NLLB language codes (e.g. "eng_Latn" -> "fra_Latn")."""
        key = (src_lang, tgt_lang)
        pipe = self._translation_pipes.get(key)
        if pipe is None:
            # Hoisted out of the per-call path; reused for subsequent requests
            # with the same language pair.
            pipe = pipeline("translation", model=self.translator_model,
                            tokenizer=self.translator_tokenizer,
                            src_lang=src_lang, tgt_lang=tgt_lang,
                            device=self.device, max_length=400)
            self._translation_pipes[key] = pipe
        return pipe(text)[0]['translation_text']

    def generate_tts(self, text: str, lang_code: str) -> str:
        """Generates a base64 audio string for the translated text.

        Best-effort: returns "" on any failure so the API can still deliver
        the text-only response.
        """
        try:
            # "spa_Latn" -> "spa" -> "es"; unknown codes fall back to the
            # first two letters (the original heuristic).
            iso3 = lang_code.split('_')[0]
            clean_lang = self._GTTS_LANG_MAP.get(iso3, iso3[:2])
            tts = gTTS(text=text, lang=clean_lang)
            fp = BytesIO()
            tts.write_to_fp(fp)
            return base64.b64encode(fp.getvalue()).decode()
        except Exception as e:
            logger.warning(f"TTS Failed: {e}")
            return ""
|
|
# Module-level singleton; models are loaded later by the FastAPI lifespan hook.
engine = TranslationEngine()

|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: load all ML models once before serving traffic."""
    engine.load_models()
    yield
|
|
app = FastAPI(lifespan=lifespan)


# Serve the pre-built frontend assets. NOTE(review): StaticFiles raises at
# import time if "static/assets" is missing — the directory must ship with
# the deployment.
app.mount("/assets", StaticFiles(directory="static/assets"), name="assets")
|
|
| |
@app.get("/")
async def read_index():
    """Return the single-page-app entry document for the site root."""
    index_page = FileResponse("static/index.html")
    return index_page
|
|
# NOTE(review): wildcard CORS (allow_origins=["*"]) accepts requests from any
# origin — tighten to the actual frontend origin(s) before production.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
|
@app.post("/process-speech")
async def process_speech(audio: UploadFile = File(...), src_lang: str = Form(...), tgt_lang: str = Form(...)):
    """Transcribe an uploaded audio clip, translate it, and synthesize speech.

    Returns a dict with the translated transcript, the original transcript,
    and a base64 MP3 payload ("" when TTS fails).
    Raises HTTPException(400) when no speech is detected.
    """
    # Never build a filesystem path from the client-supplied filename
    # (path traversal / collision between concurrent uploads). Use a private
    # temp file, keeping only the extension so the decoder can sniff the
    # container format.
    suffix = os.path.splitext(os.path.basename(audio.filename or ""))[1]
    tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
    temp_filename = tmp.name
    try:
        with tmp:
            tmp.write(await audio.read())

        # NOTE(review): transcription/translation are blocking CPU/GPU work
        # inside an async handler — consider offloading to a threadpool.
        ts_result = engine.transcriber(temp_filename)
        transcript = ts_result.get("text", "").strip()
        if not transcript:
            raise HTTPException(status_code=400, detail="No speech detected")

        # Skip the translation pass when source and target languages match.
        translated_text = transcript
        if src_lang != tgt_lang:
            translated_text = engine.translate(transcript, src_lang, tgt_lang)

        audio_base64 = engine.generate_tts(translated_text, tgt_lang)

        return {
            "transcript": translated_text,
            "original_transcript": transcript,
            "audio_payload": audio_base64
        }
    finally:
        # Always clean up the temp file, even on transcription failure.
        if os.path.exists(temp_filename):
            os.remove(temp_filename)
|
|
if __name__ == "__main__":
    import uvicorn

    # Bind port is overridable via the PORT env var for container platforms;
    # the default preserves the original behavior.
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)