Spaces:
Sleeping
Sleeping
| import os | |
| import shutil | |
| import uuid | |
| import logging | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, BackgroundTasks, Depends | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
# Configure logging once at import time: INFO and above, timestamped records,
# emitted through a single stream handler.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    handlers=[logging.StreamHandler()],
)

# Shared logger used throughout this module.
logger = logging.getLogger("baif-api")
# langdetect is an optional dependency. When it is importable, its seed is
# fixed to 0; when it is missing, `detect` degrades to a stub that always
# reports English ("en").
try:
    from langdetect import DetectorFactory, detect
except ImportError:
    def detect(text: str) -> str:
        """Fallback language detector used when langdetect is not installed."""
        return "en"
else:
    DetectorFactory.seed = 0
| import sys | |
# Put this file's own directory at the front of sys.path so that the sibling
# modules imported below resolve by bare name.
backend_dir = os.path.dirname(os.path.abspath(__file__))
if backend_dir not in sys.path:
    sys.path.insert(0, backend_dir)
| import models | |
| import document_utils | |
| import auth as auth_mod | |
| import jobs | |
| import subtitles | |
| from contextlib import asynccontextmanager | |
def clean_temp_folder():
    """Reset the temporary workspace: remove TEMP_DIR if it exists, then recreate it empty."""
    workspace = TEMP_DIR
    if os.path.exists(workspace):
        # Discard everything left behind by a previous run.
        shutil.rmtree(workspace)
    os.makedirs(workspace, exist_ok=True)
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook passed to ``FastAPI(lifespan=...)``.

    Code before the ``yield`` runs once at startup (the temp workspace is
    wiped); code after it runs once at shutdown.

    The function must be wrapped with ``asynccontextmanager`` because FastAPI
    expects ``lifespan`` to return an async context manager, not a bare
    async generator.
    """
    # Startup
    logger.info("Application starting up...")
    clean_temp_folder()
    logger.info("Temporary folder cleared.")
    yield
    # Shutdown (nothing to release at present)
    logger.info("Application shutting down...")
# ASGI application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    title="Offline Translation API",
    version="1.0.0",
    lifespan=lifespan,
)
| # Auth Endpoints | |
class LoginRequest(BaseModel):
    """Credentials submitted to the login handler."""

    # Account name used to look up the row in the `users` table.
    username: str
    # Password as typed by the user; checked with auth_mod.verify_password.
    password: str
def login(req: LoginRequest):
    """Verify a username/password pair and issue a bearer token.

    Args:
        req: The submitted credentials.

    Returns:
        A dict with ``access_token``, ``token_type`` (always ``"bearer"``)
        and the user's ``role``.

    Raises:
        HTTPException: 401 when the user does not exist or the password does
            not verify. The same message is used for both cases.
    """
    conn = auth_mod.get_db_conn()
    try:
        cursor = conn.cursor()
        # Parameterised query: the username is never interpolated into SQL.
        cursor.execute(
            "SELECT password, role FROM users WHERE username=?",
            (req.username,),
        )
        row = cursor.fetchone()
    finally:
        # Always release the connection, including when the query raises.
        conn.close()

    if not row or not auth_mod.verify_password(req.password, row[0]):
        raise HTTPException(status_code=401, detail="Invalid username or password")

    role = row[1]
    token = auth_mod.create_access_token(data={"sub": req.username, "role": role})
    return {"access_token": token, "token_type": "bearer", "role": role}
# Reusable dependency lists for route declarations:
#   protected  -> resolves the caller through auth_mod.get_current_user
#   admin_only -> goes through auth_mod.require_admin
protected = [Depends(auth_mod.get_current_user)]
admin_only = [Depends(auth_mod.require_admin)]
# Enable CORS for frontend local development.
# NOTE(review): the FastAPI CORS documentation states that allow_origins,
# allow_methods and allow_headers cannot be ["*"] when allow_credentials is
# True. Consider listing explicit origins or dropping allow_credentials
# before exposing this beyond local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Directories, all resolved relative to this file:
#   MODELS_DIR - cache directory handed to the ModelManager
#   TEMP_DIR   - per-session scratch space (one sub-folder per request/job)
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_DIR, "models")
TEMP_DIR = os.path.join(BASE_DIR, "temp")
os.makedirs(TEMP_DIR, exist_ok=True)

# Single shared model manager backed by the local cache directory.
model_manager = models.ModelManager(cache_dir=MODELS_DIR)
| # Request schemas | |
class TranslationRequest(BaseModel):
    """JSON body for plain-text translation."""

    # Text to translate.
    text: str
    # Source language name; the literal "auto" (any case) triggers detection.
    src_lang: str
    # Target language name, passed through to the model manager.
    tgt_lang: str
class TTSRequest(BaseModel):
    """JSON body for text-to-speech synthesis."""

    # Text to synthesise.
    text: str
    # Language passed to the model manager; also used (lower-cased) in the
    # download file name.
    lang: str
def health_check():
    """Liveness probe: log the hit and report a healthy status."""
    logger.info("Health check hit")
    status_payload = {"status": "healthy"}
    return status_payload
def ping():
    """Return the constant string "pong"; does no other work."""
    reply = "pong"
    return reply
def get_models_status(user=Depends(auth_mod.get_current_user)):
    """Report whether anything is present in the local models directory.

    Returns a dict with ``is_cached`` (True when MODELS_DIR exists and is not
    empty), one flag per model family, and the directory path.

    Each per-model flag is OR-ed with ``is_cached`` (fallback for models
    fetched via the downloader), so in practice all three flags equal
    ``is_cached``.
    """
    entries = os.listdir(MODELS_DIR) if os.path.exists(MODELS_DIR) else []
    is_cached = bool(entries)

    # Substring match of each family name against the directory entries.
    found = {
        family: any(family in entry for entry in entries)
        for family in ("whisper", "nllb", "tts")
    }

    return {
        "is_cached": is_cached,
        "whisper_cached": found["whisper"] or is_cached,
        "nllb_cached": found["nllb"] or is_cached,
        "tts_cached": found["tts"] or is_cached,
        "models_dir": MODELS_DIR,
    }
def detect_language_safe(text: str) -> str:
    """Detect the language of *text* and return a supported language name.

    Only English, Hindi and Marathi are recognised. Any other detected code,
    and any exception raised during detection, yields "English".
    """
    supported = {"en": "English", "hi": "Hindi", "mr": "Marathi"}
    try:
        return supported.get(detect(text), "English")
    except Exception:
        # Detection is best-effort; never let it fail the caller.
        return "English"
def api_detect_language(req: dict):
    """Return ``{"language": <name>}`` for the ``text`` field of *req*.

    A missing or empty ``text`` short-circuits to English without running
    the detector.
    """
    text = req.get("text", "")
    language = detect_language_safe(text) if text else "English"
    return {"language": language}
def get_job_status(job_id: str, user=Depends(auth_mod.get_current_user)):
    """Return the job record for *job_id* as stored by the job manager.

    Raises:
        HTTPException: 404 when the job manager has no (truthy) record for
            the given id.
    """
    record = jobs.job_manager.get_job(job_id)
    if record:
        return record
    raise HTTPException(status_code=404, detail="Job not found")
def translate_text(req: TranslationRequest, user=Depends(auth_mod.get_current_user)):
    """Translate a piece of text with the offline model.

    When ``req.src_lang`` is "auto" (case-insensitive) the source language is
    detected from the text first. The language actually used is echoed back
    as ``detected_src_lang``.

    Raises:
        HTTPException: 500 carrying the underlying error message on any failure.
    """
    try:
        wants_detection = req.src_lang.lower() == "auto"
        source_language = detect_language_safe(req.text) if wants_detection else req.src_lang
        output = model_manager.translate(req.text, source_language, req.tgt_lang)
        return {"translated_text": output, "detected_src_lang": source_language}
    except Exception as e:
        logger.error("Error in /api/translate-text: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
async def transcribe_audio(
    file: UploadFile = File(...),
    model_size: str = Form("base"),
    language: str = Form("English"),
    user=Depends(auth_mod.get_current_user)
):
    """Transcribe an uploaded audio or video file with the offline Whisper model.

    The upload is stored in a fresh per-request session folder under
    TEMP_DIR. Video containers (by extension) first have their audio track
    extracted with ``subtitles.extract_audio``.

    Returns:
        A dict with the full ``text``, the ``segments``, subtitle renderings
        (``srt`` and ``vtt``) and the ``session_id``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
            The session folder is removed in that case, since its id is
            never returned to the client.
    """
    video_extensions = {".mp4", ".mov", ".avi", ".wmv", ".mkv", ".flv", ".webm"}

    # Create unique session folder
    session_id = str(uuid.uuid4())
    session_dir = os.path.join(TEMP_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)

    # Only the extension of the client-supplied name is reused on disk.
    file_ext = os.path.splitext(file.filename or "unknown")[1].lower()
    input_path = os.path.join(session_dir, f"input{file_ext}")

    try:
        # Save uploaded file
        with open(input_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        audio_path = input_path
        if file_ext in video_extensions:
            logger.info("Input is a video file. Extracting audio stream...")
            audio_path = os.path.join(session_dir, "extracted_audio.wav")
            subtitles.extract_audio(input_path, audio_path)

        transcription_result = model_manager.transcribe(
            audio_path=audio_path,
            size=model_size,
            language=language
        )
        segments = transcription_result["segments"]

        return {
            "text": transcription_result["text"],
            "segments": segments,
            "srt": subtitles.generate_srt(segments),
            "vtt": subtitles.generate_vtt(segments),
            "session_id": session_id
        }
    except Exception as e:
        logger.error("Error in /api/transcribe-audio: %s", e)
        # Do not leave orphaned uploads behind for failed requests.
        shutil.rmtree(session_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def translate_audio(
    file: UploadFile = File(...),
    model_size: str = Form("base"),
    src_lang: str = Form("English"),
    tgt_lang: str = Form("Hindi"),
    user=Depends(auth_mod.get_current_user)
):
    """Transcribe an audio/video upload, translate it and synthesise speech.

    Steps: save the upload, extract audio for video containers, transcribe,
    batch-translate the segments, render translated SRT/VTT, and build a
    merged audio track through ``subtitles.merge_audio_segments``.

    Returns:
        A dict with source and translated text/segments, translated
        subtitles, a download URL for the synthesised audio and the
        ``session_id``.

    Raises:
        HTTPException: 500 with the underlying error message on any failure.
            The session folder is removed in that case.
    """
    video_extensions = {".mp4", ".mov", ".avi", ".wmv", ".mkv", ".flv", ".webm"}

    session_id = str(uuid.uuid4())
    session_dir = os.path.join(TEMP_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)

    # Only the extension of the client-supplied name is reused on disk.
    file_ext = os.path.splitext(file.filename or "unknown")[1].lower()
    input_path = os.path.join(session_dir, f"input{file_ext}")

    try:
        with open(input_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        audio_path = input_path
        if file_ext in video_extensions:
            audio_path = os.path.join(session_dir, "extracted_audio.wav")
            subtitles.extract_audio(input_path, audio_path)

        # Transcribe in source language
        transcription_result = model_manager.transcribe(
            audio_path=audio_path,
            size=model_size,
            language=src_lang
        )
        source_segments = transcription_result["segments"]

        # Prefer the language reported by the transcriber; fall back to English.
        actual_src = transcription_result.get("detected_language")
        if not actual_src or actual_src.lower() == "auto":
            actual_src = "English"

        # Translate segment by segment in one batch. The full translated text
        # is assembled from these segments, so no separate whole-text
        # translation pass is needed.
        translated_texts = model_manager.translate_batch(
            texts=[seg["text"] for seg in source_segments],
            src_lang=actual_src,
            tgt_lang=tgt_lang
        )
        translated_segments = [
            {"start": seg["start"], "end": seg["end"], "text": trans_text}
            for seg, trans_text in zip(source_segments, translated_texts)
        ]

        # Generate translated SRT & VTT
        translated_srt = subtitles.generate_srt(translated_segments)
        translated_vtt = subtitles.generate_vtt(translated_segments)

        # Full text for UI results
        full_translated_text = " ".join(seg["text"] for seg in translated_segments)

        # Synthesise each segment and merge on the original timestamps, then
        # rename the merged track (combined_audio.wav) to the advertised name.
        tts_output_filename = f"tts_{tgt_lang.lower()}.wav"
        tts_output_path = os.path.join(session_dir, tts_output_filename)
        subtitles.merge_audio_segments(translated_segments, session_dir, model_manager, tgt_lang)
        combined_path = os.path.join(session_dir, "combined_audio.wav")
        if os.path.exists(combined_path):
            shutil.move(combined_path, tts_output_path)

        return {
            "source_text": transcription_result["text"],
            "translated_text": full_translated_text,
            "source_segments": source_segments,
            "translated_segments": translated_segments,
            "translated_srt": translated_srt,
            "translated_vtt": translated_vtt,
            "translated_audio_url": f"/api/download-file?session_id={session_id}&filename={tts_output_filename}",
            "session_id": session_id
        }
    except Exception as e:
        logger.error("Error in /api/translate-audio: %s", e)
        # The session id is never returned on failure, so drop its files.
        shutil.rmtree(session_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
def text_to_speech(req: TTSRequest, user=Depends(auth_mod.get_current_user)):
    """Synthesise ``req.text`` to speech and return it as a WAV download.

    The audio is written to a throw-away session folder. Because the
    session id is not returned to the client, the folder is deleted once the
    response has been sent (or immediately if synthesis fails).

    Raises:
        HTTPException: 500 with the underlying error message when synthesis fails.
    """
    session_dir = os.path.join(TEMP_DIR, str(uuid.uuid4()))
    os.makedirs(session_dir, exist_ok=True)
    output_path = os.path.join(session_dir, "tts.wav")

    try:
        model_manager.text_to_speech(req.text, req.lang, output_path)
    except Exception as e:
        logger.error("Error in /api/text-to-speech: %s", e)
        shutil.rmtree(session_dir, ignore_errors=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Runs after the response body has been streamed to the client.
    cleanup = BackgroundTasks()
    cleanup.add_task(shutil.rmtree, session_dir, ignore_errors=True)

    return FileResponse(
        output_path,
        media_type="audio/wav",
        filename=f"translated_speech_{req.lang.lower()}.wav",
        background=cleanup
    )
def run_video_processing(job_id, session_id, input_path, session_dir, file_ext, model_size, src_lang, tgt_lang, burn_subtitles_option, overlay_voice_option):
    """Background worker for the video translation pipeline.

    Progress is reported through ``jobs.job_manager.update_job``; the function
    itself returns nothing. Steps: extract audio, transcribe, batch-translate
    segments, write the translated SRT, optionally burn subtitles, optionally
    overlay a synthesised voice track.

    Args:
        job_id: Job record to update.
        session_id: Session folder name, used to build download URLs.
        input_path: Path of the saved upload.
        session_dir: Folder in which all intermediate/output files are written.
        file_ext: Extension of the upload (used when the input is returned as is).
        model_size: Transcription model size.
        src_lang: Language hint for transcription.
        tgt_lang: Translation / synthesis target language.
        burn_subtitles_option: Burn translated subtitles into the video.
        overlay_voice_option: Replace/overlay audio with synthesised speech.

    Any exception marks the job as ``failed`` with the error text; nothing is
    re-raised.
    """
    try:
        jobs.job_manager.update_job(job_id, status="processing", progress=5)

        # 1. Extract audio for ASR
        audio_path = os.path.join(session_dir, "extracted_audio.wav")
        subtitles.extract_audio(input_path, audio_path)
        jobs.job_manager.update_job(job_id, progress=15)

        # 2. Transcribe
        transcription_result = model_manager.transcribe(
            audio_path=audio_path,
            size=model_size,
            language=src_lang
        )
        jobs.job_manager.update_job(job_id, progress=40)
        source_segments = transcription_result["segments"]

        # Use detected language for translation steps
        actual_src = transcription_result.get("detected_language")
        if not actual_src or actual_src.lower() == "auto":
            actual_src = "English"  # Final fallback

        # 3. Batch translate all segments & construct SRT
        translated_texts = model_manager.translate_batch(
            texts=[seg["text"] for seg in source_segments],
            src_lang=actual_src,
            tgt_lang=tgt_lang
        )
        jobs.job_manager.update_job(job_id, progress=60)

        translated_segments = [
            {"start": seg["start"], "end": seg["end"], "text": trans_text}
            for seg, trans_text in zip(source_segments, translated_texts)
        ]
        # Always expose the translated text in the result, not only when a
        # voice-over was requested.
        full_translated_text = " ".join(seg["text"] for seg in translated_segments)

        translated_srt = subtitles.generate_srt(translated_segments)
        translated_srt_path = os.path.join(session_dir, "translated_subs.srt")
        with open(translated_srt_path, "w", encoding="utf-8") as f:
            f.write(translated_srt)

        # 4. Generate subtitles-burned video
        current_video_stream = input_path
        if burn_subtitles_option:
            logger.info("Burning subtitles into video stream...")
            burned_video_path = os.path.join(session_dir, "burned_subtitles.mp4")
            subtitles.burn_subtitles(input_path, translated_srt_path, burned_video_path)
            current_video_stream = burned_video_path
        jobs.job_manager.update_job(job_id, progress=80)

        # 5. Handle audio overlay if selected
        if overlay_voice_option:
            logger.info("Generating voiceover segments and overlaying...")
            # Synthesise each segment and merge them at the original timestamps.
            tts_audio_path = subtitles.merge_audio_segments(translated_segments, session_dir, model_manager, tgt_lang)
            final_video_path = os.path.join(session_dir, "translated_final.mp4")
            subtitles.overlay_audio_on_video(current_video_stream, tts_audio_path, final_video_path)
            output_video_filename = "translated_final.mp4"
        elif burn_subtitles_option:
            output_video_filename = "burned_subtitles.mp4"
        else:
            # Nothing was rendered; hand back the original upload.
            output_video_filename = f"input{file_ext}"

        jobs.job_manager.update_job(
            job_id,
            status="completed",
            progress=100,
            result={
                "source_text": transcription_result["text"],
                "translated_text": full_translated_text,
                "translated_srt": translated_srt,
                "video_url": f"/api/download-file?session_id={session_id}&filename={output_video_filename}",
                "srt_url": f"/api/download-file?session_id={session_id}&filename=translated_subs.srt",
                "session_id": session_id
            }
        )
    except Exception as e:
        logger.error("Error in job %s: %s", job_id, e)
        jobs.job_manager.update_job(job_id, status="failed", error=str(e))
async def api_process_video(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    model_size: str = Form("base"),
    src_lang: str = Form("English"),
    tgt_lang: str = Form("Hindi"),
    burn_subtitles_option: bool = Form(True),
    overlay_voice_option: bool = Form(False),
    user=Depends(auth_mod.get_current_user)
):
    """Accept a video upload and queue the translation pipeline.

    The upload is saved into a new session folder, a "video_processing" job
    is created, and ``run_video_processing`` is scheduled as a background
    task. Returns ``{"job_id": ...}`` immediately; progress is read from the
    job record.
    """
    session_id = str(uuid.uuid4())
    session_dir = os.path.join(TEMP_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)

    # Keep only the extension of the client-supplied file name.
    original_name = file.filename or "unknown"
    _, raw_ext = os.path.splitext(original_name)
    file_ext = raw_ext.lower()
    input_path = os.path.join(session_dir, f"input{file_ext}")

    with open(input_path, "wb") as destination:
        shutil.copyfileobj(file.file, destination)

    job_id = jobs.job_manager.create_job("video_processing")
    worker_args = (
        job_id, session_id, input_path, session_dir, file_ext,
        model_size, src_lang, tgt_lang,
        burn_subtitles_option, overlay_voice_option,
    )
    background_tasks.add_task(run_video_processing, *worker_args)
    return {"job_id": job_id}
def download_file(session_id: str, filename: str):
    """Serve a file from a session folder under TEMP_DIR.

    Both ``session_id`` and ``filename`` come from the query string, so the
    resolved path is checked to lie inside TEMP_DIR before anything is read.
    Attempts to escape (``..``, absolute paths, symlinks) are answered with
    the same 404 as a missing file.

    Raises:
        HTTPException: 404 when the path is outside TEMP_DIR or is not an
            existing regular file.
    """
    temp_root = os.path.realpath(TEMP_DIR)
    file_path = os.path.realpath(os.path.join(temp_root, session_id, filename))

    try:
        inside_temp = os.path.commonpath([temp_root, file_path]) == temp_root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives).
        inside_temp = False

    if not inside_temp or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="Requested file not found.")

    # Extension -> media type; anything unknown is sent as a generic binary.
    media_types = {
        ".mp4": "video/mp4",
        ".wav": "audio/wav",
        ".srt": "text/plain",
        ".vtt": "text/plain",
        ".docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        ".pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
        ".xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        ".pdf": "application/pdf",
    }
    download_name = os.path.basename(file_path)
    ext = os.path.splitext(download_name)[1].lower()
    media_type = media_types.get(ext, "application/octet-stream")

    return FileResponse(file_path, media_type=media_type, filename=download_name)
def run_document_translation(job_id, session_id, input_path, output_path, file_ext, src_lang, tgt_lang, output_filename):
    """Background worker that translates one office/PDF document.

    Resolves the source language (detecting it from a text preview when
    ``src_lang`` is "auto"), dispatches to the matching ``document_utils``
    translator by extension, and records the outcome on the job. An
    extension with no translator performs no translation but still marks
    the job completed. Any exception marks the job as ``failed``; nothing is
    re-raised.
    """
    # Extension -> name of the document_utils function that handles it.
    translator_names = {
        ".docx": "translate_docx",
        ".pptx": "translate_pptx",
        ".xlsx": "translate_xlsx",
        ".pdf": "translate_pdf",
    }
    try:
        jobs.job_manager.update_job(job_id, status="processing", progress=10)

        # Auto-detect if requested
        actual_src = src_lang
        if src_lang.lower() == "auto":
            sample = document_utils.extract_preview_text(input_path, file_ext)
            actual_src = detect_language_safe(sample)
            logger.info("Auto-detected document language: %s", actual_src)

        handler_name = translator_names.get(file_ext)
        if handler_name is not None:
            translate = getattr(document_utils, handler_name)
            translate(input_path, output_path, model_manager, actual_src, tgt_lang)

        # NOTE(review): output_filename is placed in the URL unescaped; names
        # containing "&", "#" or spaces may need quoting — confirm with the frontend.
        outcome = {
            "output_url": f"/api/download-file?session_id={session_id}&filename={output_filename}",
            "session_id": session_id,
            "filename": output_filename,
            "detected_src_lang": actual_src
        }
        jobs.job_manager.update_job(job_id, status="completed", progress=100, result=outcome)
    except Exception as e:
        logger.error("Error in job %s: %s", job_id, e)
        jobs.job_manager.update_job(job_id, status="failed", error=str(e))
async def api_translate_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    src_lang: str = Form("English"),
    tgt_lang: str = Form("Hindi"),
    user=Depends(auth_mod.get_current_user)
):
    """Accept a document upload (docx, pptx, xlsx, pdf) and queue its translation.

    The client-supplied file name is reduced to its final path component
    before it is used to build the output path, so it cannot point outside
    the session folder. The extension is validated before any folder is
    created, so rejected uploads leave nothing on disk.

    Returns:
        ``{"job_id": ...}``; progress and the download URL are read from the
        job record.

    Raises:
        HTTPException: 400 for an unsupported extension.
    """
    # Strip any directory part, accepting both "/" and "\" as separators.
    raw_name = (file.filename or "unknown").replace("\\", "/")
    filename = os.path.basename(raw_name) or "unknown"
    file_ext = os.path.splitext(filename)[1].lower()

    if file_ext not in (".docx", ".pptx", ".xlsx", ".pdf"):
        raise HTTPException(status_code=400, detail=f"Unsupported file format: {file_ext}")

    session_id = str(uuid.uuid4())
    session_dir = os.path.join(TEMP_DIR, session_id)
    os.makedirs(session_dir, exist_ok=True)

    input_path = os.path.join(session_dir, f"input{file_ext}")
    output_filename = f"translated_{filename}"
    output_path = os.path.join(session_dir, output_filename)

    with open(input_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    job_id = jobs.job_manager.create_job("document_translation")
    background_tasks.add_task(
        run_document_translation,
        job_id, session_id, input_path, output_path, file_ext, src_lang, tgt_lang, output_filename
    )
    return {"job_id": job_id}
# Serve the built frontend from "/" when the build folder exists next to the
# backend. The mount is done at this point in the file so that routes already
# registered on `app` are matched before the static catch-all.
frontend_dist = os.path.abspath(os.path.join(BASE_DIR, "../frontend/dist"))
logger.info("Checking for frontend at: %s", frontend_dist)

if os.path.exists(frontend_dist):
    logger.info("Frontend found. Mounting at /")
    static_app = StaticFiles(directory=frontend_dist, html=True)
    app.mount("/", static_app, name="static")
else:
    logger.warning("Frontend NOT found at %s", frontend_dist)

    # NOTE(review): no route decorator is visible on this function in this
    # view of the file — confirm it is registered on `app`.
    def root_fallback():
        """Placeholder response used when no frontend build is available."""
        return {"message": "API is running. Frontend assets not found."}
# Direct execution (`python <this file>`): start uvicorn on all interfaces.
if __name__ == "__main__":
    import uvicorn

    # PORT is read from the environment, defaulting to 8000.
    port = int(os.environ.get("PORT", 8000))
    logger.info("Starting server on port %d...", port)
    uvicorn.run(app, host="0.0.0.0", port=port)