"""NLLB real-time translation service (FastAPI + CTranslate2, CPU float32).

Streams token-by-token translations over SSE and WebSocket, chunks long
documents at sentence boundaries, and bounds load with a concurrency
semaphore plus a per-request timeout.
"""

import asyncio
import json
import logging
import os
import queue
import shutil
import subprocess
import warnings
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path
from typing import AsyncGenerator, List, Optional

import ctranslate2 as ct2
import nltk
import sentencepiece as spm
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download, snapshot_download
from nltk.tokenize import sent_tokenize
from pydantic import BaseModel, Field, field_validator

# Suppress HF resume_download warning
warnings.filterwarnings("ignore", category=FutureWarning)

# ---------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------
MAX_CONCURRENT_TRANSLATIONS = int(os.getenv("MAX_CONCURRENT", "4"))
TRANSLATION_TIMEOUT_SECONDS = int(os.getenv("TRANSLATION_TIMEOUT", "300"))  # 5 minutes default
MAX_TOKENS_PER_CHUNK = 400  # Conservative limit below NLLB's 512 token max

# NLLB / SentencePiece special tokens.
# NOTE(review): in the pasted source these literals were stripped to "" (the
# markup around </s>, <pad>, <unk> was eaten). NLLB expects an explicit </s>
# EOS appended to the encoder input — confirm against the original file.
EOS_TOKEN = "</s>"
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"

# Ensure NLTK sentence-tokenizer data is present (downloaded once, cached).
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab', quiet=True)

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)

# Use Float32 as requested
MODEL_CACHE_DIR = Path(os.getenv("MODEL_CACHE_DIR", "./models/nllb-600m-f32"))
SP_MODEL_FILE = Path("./models/sentencepiece.bpe.model")
MODEL_HF_ID = "facebook/nllb-200-distilled-600M"

# Global objects, populated by the lifespan handler at startup.
translator: Optional[ct2.Translator] = None
sp: Optional[spm.SentencePieceProcessor] = None
# Concurrency control - semaphore limits concurrent translations
translation_semaphore: Optional[asyncio.Semaphore] = None


class TranslateRequest(BaseModel):
    """Translation request: text plus FLORES-200 language codes (e.g. eng_Latn)."""

    text: str = Field(..., description="Text to translate")
    # '$' anchor added so trailing junk after a valid code is rejected;
    # every well-formed code still matches.
    src_lang: str = Field(default="eng_Latn", pattern=r'^[a-z]{3}_[A-Z][a-z]{3}$')
    tgt_lang: str = Field(default="hin_Deva", pattern=r'^[a-z]{3}_[A-Z][a-z]{3}$')

    @field_validator('text')
    @classmethod
    def check_word_count(cls, v):
        """Reject pathological inputs before they reach the model."""
        if len(v.split()) > 10000:
            raise ValueError('Text exceeds 10k words')
        return v


async def ensure_model() -> None:
    """HF download + CT2 convert (float32). Cache check first.

    Raises:
        RuntimeError: if the download or the ct2 conversion fails.
    """
    model_bin = MODEL_CACHE_DIR / "model.bin"
    sp_file = SP_MODEL_FILE
    if model_bin.exists() and sp_file.exists():
        log.info("CT2 model (float32) & SP cached.")
        return

    MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    SP_MODEL_FILE.parent.mkdir(parents=True, exist_ok=True)
    log.info(f"Downloading/converting {MODEL_HF_ID}...")
    temp_dir = Path("./temp_hf")
    ct2_cmd: List[str] = []
    try:
        snapshot_download(repo_id=MODEL_HF_ID, local_dir=temp_dir)
        sp_path = hf_hub_download(repo_id=MODEL_HF_ID, filename="sentencepiece.bpe.model")
        shutil.copy(sp_path, sp_file)
        ct2_cmd = [
            "ct2-transformers-converter",
            "--model", str(temp_dir),
            "--output_dir", str(MODEL_CACHE_DIR),
            "--quantization", "float32",
            "--force",
        ]
        # List form (shell=False) — no shell-injection surface.
        subprocess.check_call(ct2_cmd)
        shutil.rmtree(temp_dir, ignore_errors=True)
        log.info("Model ready: float32 CT2.")
    except subprocess.CalledProcessError as e:
        log.error(f"CT2 converter failed (cmd: {ct2_cmd}). {e}")
        raise RuntimeError("Conversion failed; verify ctranslate2 & disk space.") from e
    except Exception as e:
        log.error(f"Setup failed: {e}")
        raise RuntimeError("Model prep failed; check HF access/disk.") from e


# ---------------------------------------------------------
# SENTENCE CHUNKING FOR LONG DOCUMENTS
# ---------------------------------------------------------
def split_into_sentences(text: str) -> List[str]:
    """
    Split text into sentences while preserving paragraph structure.
    Uses NLTK's sent_tokenize for accurate sentence boundary detection.

    Paragraph breaks are represented by a literal '\\n' element in the
    returned list so downstream chunking can preserve them.
    """
    if not text.strip():
        return []

    sentences: List[str] = []
    paragraphs = text.split('\n')
    for i, para in enumerate(paragraphs):
        if not para.strip():
            # Preserve empty lines as paragraph separators
            sentences.append('\n')
            continue
        # Split paragraph into sentences
        sentences.extend(sent_tokenize(para.strip()))
        # Add paragraph break after each paragraph (except last)
        if i < len(paragraphs) - 1:
            sentences.append('\n')
    return sentences


def estimate_token_count(text: str) -> int:
    """
    Estimate token count for a text segment.
    Uses SentencePiece tokenizer for accurate count.
    """
    if sp is None:
        # Fallback: rough estimate based on words (avg 1.3 tokens per word)
        return int(len(text.split()) * 1.3)
    return len(sp.encode(text, out_type=str))


def merge_short_sentences(sentences: List[str],
                          max_tokens: int = MAX_TOKENS_PER_CHUNK) -> List[str]:
    """
    Merge short consecutive sentences into chunks that fit within token limit.
    This improves translation quality by providing more context.

    '\\n' elements (paragraph breaks) always flush the current chunk and are
    passed through unchanged.
    """
    if not sentences:
        return []

    chunks: List[str] = []
    current_chunk: List[str] = []
    current_tokens = 0

    for sentence in sentences:
        if sentence == '\n':
            # Preserve paragraph breaks
            if current_chunk:
                chunks.append(' '.join(current_chunk))
                current_chunk = []
                current_tokens = 0
            chunks.append('\n')
            continue

        sentence_tokens = estimate_token_count(sentence)

        # If single sentence exceeds limit, add it as its own chunk
        if sentence_tokens > max_tokens:
            if current_chunk:
                chunks.append(' '.join(current_chunk))
                current_chunk = []
                current_tokens = 0
            chunks.append(sentence)
            continue

        # If adding this sentence would exceed limit, start new chunk
        if current_tokens + sentence_tokens > max_tokens and current_chunk:
            chunks.append(' '.join(current_chunk))
            current_chunk = [sentence]
            current_tokens = sentence_tokens
        else:
            current_chunk.append(sentence)
            current_tokens += sentence_tokens

    # Don't forget the last chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    return chunks


# ---------------------------------------------------------
# CORE TRANSLATION LOGIC (RAM ONLY - NO DISK WRITES)
# ---------------------------------------------------------
def _translate_single_chunk_sync(text: str, src: str, tgt: str) -> str:
    """
    Synchronous translation of a single chunk.
    Used for non-streaming translation of document chunks.

    Returns the translated text, or a bracketed error marker on failure
    (deliberate best-effort: one bad chunk must not abort the document).
    """
    if not text.strip() or text == '\n':
        return text
    try:
        tokens = sp.encode(text, out_type=str)
        tokens.insert(0, src)          # NLLB source language tag
        tokens.append(EOS_TOKEN)       # explicit </s> EOS on encoder input
        results = translator.translate_batch(
            [tokens],
            target_prefix=[[tgt]],
            max_decoding_length=512,
            beam_size=3,               # Higher quality
            repetition_penalty=1.2,
        )
        out_tokens = results[0].hypotheses[0]
        # The target language tag is echoed as the first hypothesis token.
        if out_tokens and out_tokens[0] == tgt:
            out_tokens = out_tokens[1:]
        decoded = sp.decode(out_tokens)
        # Normalize any leftover SentencePiece space markers, then collapse
        # runs of double spaces (the loop terminates once none remain).
        decoded = decoded.replace("▁", " ")
        while "  " in decoded:
            decoded = decoded.replace("  ", " ")
        return decoded.strip()
    except Exception as e:
        log.error(f"Chunk translation error: {e}")
        return f"[Translation Error: {str(e)[:50]}]"


def _run_translation_sync(text: str, src: str, tgt: str,
                          token_queue: queue.Queue) -> None:
    """
    Blocking worker running in thread. PURE IN-MEMORY OPERATION.

    Pushes generated tokens onto *token_queue*; always pushes a final None
    sentinel so the async consumer terminates.
    """
    try:
        # 1. Tokenize (In-Memory)
        source_tokens = sp.encode(text, out_type=str)
        source_tokens.insert(0, src)
        source_tokens.append(EOS_TOKEN)

        # 2. Callback for tokens
        def callback(step_result):
            token_queue.put(step_result.token)
            return False  # Continue generation (True would abort decoding)

        # 3. Inference (In-Memory via CTranslate2)
        translator.translate_batch(
            [source_tokens],
            target_prefix=[[tgt]],
            max_decoding_length=512,
            beam_size=1,               # Greedy for speed (callback needs beam 1)
            repetition_penalty=1.2,
            callback=callback,
        )
    except Exception as e:
        token_queue.put(f"ERROR: {e}")
    finally:
        token_queue.put(None)  # Sentinel to stop


async def translate_stream_chunk(text: str, src: str, tgt: str) -> AsyncGenerator[str, None]:
    """Async generator that consumes the threaded queue.

    Decodes one SentencePiece token at a time, using the U+2581 marker to
    decide where word boundaries (spaces) belong.
    """
    token_queue: queue.Queue = queue.Queue()
    loop = asyncio.get_running_loop()

    # Start blocking job in thread; its result is delivered via the queue,
    # and errors arrive as "ERROR: ..." tokens, so the future isn't awaited.
    loop.run_in_executor(
        None,
        partial(_run_translation_sync, text, src, tgt, token_queue),
    )

    # Track previous token to handle spaces properly
    prev_token_ended_with_space = False

    while True:
        # queue.get blocks, so it also runs in the executor.
        token = await loop.run_in_executor(None, token_queue.get)
        if token is None:
            break
        if token.startswith("ERROR:"):
            yield token
            break

        # --- SPACE FIX: Proper handling of SentencePiece space markers ---
        # SentencePiece uses U+2581 (▁) to denote spaces between words.

        # Skip language tags and special tokens.
        # NOTE(review): special-token literals were markup-stripped to "" in
        # the pasted source; restored as </s>/<pad>/<unk> — confirm.
        if token in (src, tgt, EOS_TOKEN, PAD_TOKEN, UNK_TOKEN):
            continue

        if token == "▁":
            # Standalone space token
            yield " "
            prev_token_ended_with_space = True
            continue

        if token.startswith("▁"):
            # Token begins a new word: strip the marker, decode the rest.
            clean_token = token[1:] if len(token) > 1 else ""
            if clean_token:
                decoded_word = sp.decode([clean_token])
                if not prev_token_ended_with_space:
                    yield " " + decoded_word
                else:
                    yield decoded_word
                prev_token_ended_with_space = False
            else:
                # Token was just the space marker
                yield " "
                prev_token_ended_with_space = True
        else:
            # Continuation piece (no marker): attaches directly to the
            # previous token. (Original had two identical branches here;
            # collapsed — behavior unchanged.)
            yield sp.decode([token])
            prev_token_ended_with_space = False


async def translate_long_document_stream(text: str, src: str, tgt: str) -> AsyncGenerator[str, None]:
    """
    Translate long documents by chunking into sentences.
    Streams results chunk by chunk for responsive UI.
    """
    # Split text into manageable chunks
    sentences = split_into_sentences(text)
    chunks = merge_short_sentences(sentences, MAX_TOKENS_PER_CHUNK)
    log.info(f"Long document: {len(text)} chars -> {len(chunks)} chunks")

    for chunk in chunks:
        if chunk == '\n':
            yield '\n'
            continue
        if not chunk.strip():
            continue
        # Translate this chunk using streaming
        async for token in translate_stream_chunk(chunk, src, tgt):
            yield token
        # Add space between chunks (but not after newlines)
        yield " "


async def translate_long_document_sync(text: str, src: str, tgt: str) -> str:
    """
    Translate long documents synchronously (for /translate_sync endpoint).
    Returns complete translated text.
    """
    loop = asyncio.get_running_loop()

    # Split text into manageable chunks
    sentences = split_into_sentences(text)
    chunks = merge_short_sentences(sentences, MAX_TOKENS_PER_CHUNK)
    log.info(f"Long document sync: {len(text)} chars -> {len(chunks)} chunks")

    translated_parts: List[str] = []
    for chunk in chunks:
        if chunk == '\n':
            translated_parts.append('\n')
            continue
        if not chunk.strip():
            continue
        # Run translation in thread to not block event loop
        translated = await loop.run_in_executor(
            None,
            partial(_translate_single_chunk_sync, chunk, src, tgt),
        )
        translated_parts.append(translated)

    # Join with spaces, but handle newlines properly
    result: List[str] = []
    for part in translated_parts:
        if part == '\n':
            result.append('\n')
        else:
            if result and result[-1] != '\n':
                result.append(' ')
            result.append(part)
    return ''.join(result).strip()


# ---------------------------------------------------------
# STARTUP & SHUTDOWN (Modern Lifespan Pattern)
# ---------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Modern lifespan context manager for startup/shutdown."""
    global translator, sp, translation_semaphore
    try:
        await ensure_model()

        cpu_count = os.cpu_count() or 4
        # Optimal threading configuration for concurrent requests:
        # - inter_threads: enables batch-level parallelism (process multiple requests)
        # - intra_threads: threads per single operation
        # Rule: inter_threads * intra_threads <= cpu_count
        inter_threads = max(1, cpu_count // 2)  # Parallel batch processing
        intra_threads = 2                        # Per-operation threads
        log.info(f"CPU cores: {cpu_count}, inter_threads: {inter_threads}, "
                 f"intra_threads: {intra_threads}")

        # Load Model with optimized threading
        translator = ct2.Translator(
            str(MODEL_CACHE_DIR),
            device="cpu",
            compute_type="float32",
            inter_threads=inter_threads,
            intra_threads=intra_threads,
        )
        sp = spm.SentencePieceProcessor(model_file=str(SP_MODEL_FILE))

        # Initialize concurrency semaphore
        translation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TRANSLATIONS)
        log.info(f"Concurrency limit: {MAX_CONCURRENT_TRANSLATIONS} simultaneous translations")

        # --- WARMUP START ---
        log.info("Warming up model...")
        warmup_tokens = sp.encode("Hello world", out_type=str)
        warmup_tokens.insert(0, "eng_Latn")
        warmup_tokens.append(EOS_TOKEN)
        # Run a silent inference to load weights into RAM
        translator.translate_batch(
            [warmup_tokens],
            target_prefix=[["hin_Deva"]],
            max_decoding_length=10,
            beam_size=1,
        )
        log.info("Model Warmed Up & Ready!")
        # --- WARMUP END ---
    except Exception as e:
        log.error(f"Fatal startup: {e}")
        raise

    yield  # Application runs here

    # Shutdown cleanup (if needed)
    log.info("Shutting down NLLB translator...")


app = FastAPI(
    title="NLLB Realtime Translator",
    version="2.2",  # Version bump for new features
    docs_url="/docs",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# ---------------------------------------------------------
# ENDPOINTS
# ---------------------------------------------------------
@app.post("/translate")
async def translate_stream_http(req: TranslateRequest):
    """
    Streaming translation endpoint.
    Handles long documents via sentence chunking.
    Limited by semaphore to prevent overload.
    Now uses Server-Sent Events (SSE) to bypass proxy buffering.
    """
    async def event_generator():
        try:
            async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
                async with translation_semaphore:
                    # Use long document handler for all requests
                    async for token in translate_long_document_stream(
                            req.text, req.src_lang, req.tgt_lang):
                        yield f"data: {json.dumps({'token': token})}\n\n"
        except asyncio.TimeoutError:
            yield ("data: "
                   + json.dumps({'error': f'Translation timed out after '
                                          f'{TRANSLATION_TIMEOUT_SECONDS} seconds'})
                   + "\n\n")
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")


@app.websocket("/ws/translate")
async def ws_translate(websocket: WebSocket):
    """
    WebSocket streaming translation.
    Handles long documents via sentence chunking.
    """
    await websocket.accept()
    try:
        data = await websocket.receive_json()
        req = TranslateRequest(**data)
        async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
            async with translation_semaphore:
                async for token in translate_long_document_stream(
                        req.text, req.src_lang, req.tgt_lang):
                    await websocket.send_text(token)
    except asyncio.TimeoutError:
        await websocket.send_text(
            f"\n[ERROR: Translation timed out after {TRANSLATION_TIMEOUT_SECONDS} seconds]")
    except (WebSocketDisconnect, ValueError) as e:
        log.info(f"WS Disconnect/Error: {e}")
    except Exception as e:
        await websocket.send_text(f"error:{str(e)}")


@app.post("/translate_sync")
async def translate_sync(req: TranslateRequest):
    """
    Synchronous translation endpoint (returns complete result).
    Handles long documents via sentence chunking.
    Limited by semaphore and timeout.
    """
    try:
        async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
            async with translation_semaphore:
                translated = await translate_long_document_sync(
                    req.text, req.src_lang, req.tgt_lang)
                return JSONResponse({"translation": translated})
    except asyncio.TimeoutError:
        raise HTTPException(
            408, f"Translation timed out after {TRANSLATION_TIMEOUT_SECONDS} seconds")
    except Exception as e:
        raise HTTPException(500, str(e))


@app.get("/health")
async def health():
    """Health check endpoint with detailed status."""
    healthy = translator is not None and sp is not None
    return {
        "status": "healthy" if healthy else "model unavailable",
        "type": "float32",
        "max_concurrent": MAX_CONCURRENT_TRANSLATIONS,
        "timeout_seconds": TRANSLATION_TIMEOUT_SECONDS,
        "version": "2.2",
    }


if __name__ == "__main__":
    import uvicorn
    port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)