# NLLB Realtime Translator — FastAPI service (Hugging Face Space)
import asyncio
import json
import logging
import os
import queue
import shutil
import subprocess
import warnings
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path
from typing import AsyncGenerator, List, Optional

import ctranslate2 as ct2
import nltk
import sentencepiece as spm
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from huggingface_hub import hf_hub_download, snapshot_download
from nltk.tokenize import sent_tokenize
from pydantic import BaseModel, Field, field_validator
| # Suppress HF resume_download warning | |
| warnings.filterwarnings("ignore", category=FutureWarning) | |
| # --------------------------------------------------------- | |
| # CONFIGURATION | |
| # --------------------------------------------------------- | |
| MAX_CONCURRENT_TRANSLATIONS = int(os.getenv("MAX_CONCURRENT", "4")) | |
| TRANSLATION_TIMEOUT_SECONDS = int(os.getenv("TRANSLATION_TIMEOUT", "300")) # 5 minutes default | |
| MAX_TOKENS_PER_CHUNK = 400 # Conservative limit below NLLB's 512 token max | |
| # Globals | |
| try: | |
| nltk.data.find('tokenizers/punkt') | |
| except LookupError: | |
| nltk.download('punkt', quiet=True) | |
| try: | |
| nltk.data.find('tokenizers/punkt_tab') | |
| except LookupError: | |
| nltk.download('punkt_tab', quiet=True) | |
| logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
| log = logging.getLogger(__name__) | |
| # Use Float32 as requested | |
| MODEL_CACHE_DIR = Path(os.getenv("MODEL_CACHE_DIR", "./models/nllb-600m-f32")) | |
| SP_MODEL_FILE = Path("./models/sentencepiece.bpe.model") | |
| MODEL_HF_ID = "facebook/nllb-200-distilled-600M" | |
| # Global Objects | |
| translator: ct2.Translator = None | |
| sp: spm.SentencePieceProcessor = None | |
| # Concurrency control - semaphore limits concurrent translations | |
| translation_semaphore: asyncio.Semaphore = None | |
class TranslateRequest(BaseModel):
    """Request body shared by all translation endpoints."""

    text: str = Field(..., description="Text to translate")
    # NLLB language codes look like "eng_Latn". The pattern now ends with "$"
    # so trailing garbage after a valid code is rejected.
    src_lang: str = Field(default="eng_Latn", pattern=r'^[a-z]{3}_[A-Z][a-z]{3}$')
    tgt_lang: str = Field(default="hin_Deva", pattern=r'^[a-z]{3}_[A-Z][a-z]{3}$')

    @field_validator('text')
    @classmethod
    def check_word_count(cls, v: str) -> str:
        """Reject requests over 10k words.

        FIX: the @field_validator/@classmethod decorators were missing, so
        pydantic never registered this validator and the limit was never
        enforced (field_validator was imported but unused).
        """
        if len(v.split()) > 10000:
            raise ValueError('Text exceeds 10k words')
        return v
async def ensure_model() -> None:
    """HF download + CT2 convert (float32). Cache check first.

    On first run, downloads facebook/nllb-200-distilled-600M and converts it
    to a float32 CTranslate2 model; later runs find both artifacts on disk
    and return immediately.

    Raises:
        RuntimeError: if the download or the CT2 conversion fails.

    NOTE: snapshot_download / subprocess calls are blocking, but this runs
    once during lifespan startup before the server accepts traffic.
    """
    model_bin = MODEL_CACHE_DIR / "model.bin"
    if model_bin.exists() and SP_MODEL_FILE.exists():
        log.info("CT2 model (float32) & SP cached.")
        return

    MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    SP_MODEL_FILE.parent.mkdir(parents=True, exist_ok=True)
    log.info(f"Downloading/converting {MODEL_HF_ID}...")

    temp_dir = Path("./temp_hf")
    ct2_cmd = [
        "ct2-transformers-converter",
        "--model", str(temp_dir),
        "--output_dir", str(MODEL_CACHE_DIR),
        "--quantization", "float32",
        "--force"
    ]
    try:
        snapshot_download(repo_id=MODEL_HF_ID, local_dir=temp_dir)
        sp_path = hf_hub_download(repo_id=MODEL_HF_ID, filename="sentencepiece.bpe.model")
        shutil.copy(sp_path, SP_MODEL_FILE)
        subprocess.check_call(ct2_cmd)
        log.info("Model ready: float32 CT2.")
    except subprocess.CalledProcessError as e:
        log.error(f"CT2 converter failed (cmd: {ct2_cmd}). {e}")
        # Chain the cause so the converter's exit status survives in logs.
        raise RuntimeError("Conversion failed; verify ctranslate2 & disk space.") from e
    except Exception as e:
        log.error(f"Setup failed: {e}")
        raise RuntimeError("Model prep failed; check HF access/disk.") from e
    finally:
        # FIX: the temp snapshot was only removed on success, leaking a full
        # model download on every failed attempt. Always clean it up.
        shutil.rmtree(temp_dir, ignore_errors=True)
| # --------------------------------------------------------- | |
| # SENTENCE CHUNKING FOR LONG DOCUMENTS | |
| # --------------------------------------------------------- | |
def split_into_sentences(text: str) -> List[str]:
    """Break *text* into sentences, keeping paragraph structure.

    Paragraph boundaries (newlines) are represented as literal '\n' entries
    in the returned list; sentence boundaries come from NLTK's sent_tokenize.
    """
    if not text.strip():
        return []

    result: List[str] = []
    lines = text.split('\n')
    last_index = len(lines) - 1

    for index, line in enumerate(lines):
        stripped = line.strip()
        if not stripped:
            # Blank line: keep it as an explicit paragraph separator.
            result.append('\n')
            continue
        result.extend(sent_tokenize(stripped))
        # Every paragraph except the final one is followed by a break marker.
        if index != last_index:
            result.append('\n')

    return result
def estimate_token_count(text: str) -> int:
    """Estimate how many model tokens *text* will occupy.

    Prefers the loaded SentencePiece processor for an exact count; before the
    model is loaded (sp is None) falls back to a words*1.3 heuristic.
    """
    if sp is not None:
        return len(sp.encode(text, out_type=str))
    # Heuristic fallback: roughly 1.3 subword tokens per whitespace word.
    return int(len(text.split()) * 1.3)
def merge_short_sentences(sentences: List[str], max_tokens: int = MAX_TOKENS_PER_CHUNK) -> List[str]:
    """Pack consecutive sentences into chunks of at most *max_tokens* tokens.

    Larger chunks give the model more context and better translations.
    Paragraph markers ('\n') are passed through unchanged, and a single
    sentence longer than the limit becomes its own chunk.
    """
    if not sentences:
        return []

    merged: List[str] = []
    buffer: List[str] = []
    buffered_tokens = 0

    def flush() -> None:
        # Emit the pending buffer (if any) as one space-joined chunk.
        nonlocal buffered_tokens
        if buffer:
            merged.append(' '.join(buffer))
            buffer.clear()
            buffered_tokens = 0

    for sent in sentences:
        if sent == '\n':
            # Paragraph break: close the current chunk and keep the marker.
            flush()
            merged.append('\n')
            continue

        n_tokens = estimate_token_count(sent)

        if n_tokens > max_tokens:
            # Oversized sentence: emit it alone rather than splitting it.
            flush()
            merged.append(sent)
        elif buffer and buffered_tokens + n_tokens > max_tokens:
            # Adding this sentence would overflow — start a fresh chunk.
            flush()
            buffer.append(sent)
            buffered_tokens = n_tokens
        else:
            buffer.append(sent)
            buffered_tokens += n_tokens

    flush()  # emit whatever is still buffered
    return merged
| # --------------------------------------------------------- | |
| # CORE TRANSLATION LOGIC (RAM ONLY - NO DISK WRITES) | |
| # --------------------------------------------------------- | |
def _translate_single_chunk_sync(text: str, src: str, tgt: str) -> str:
    """Translate one chunk synchronously (no streaming).

    Used by the non-streaming document path. Returns the chunk unchanged if
    it is blank or a paragraph marker; on failure returns an inline error
    marker instead of raising, so one bad chunk cannot abort a document.
    """
    if not text.strip() or text == '\n':
        return text
    try:
        tokens = sp.encode(text, out_type=str)
        tokens.insert(0, src)
        tokens.append("</s>")
        results = translator.translate_batch(
            [tokens],
            target_prefix=[[tgt]],
            max_decoding_length=512,
            beam_size=3,  # Higher quality
            repetition_penalty=1.2
        )
        out_tokens = results[0].hypotheses[0]
        # Drop the leading target-language token injected via target_prefix.
        if out_tokens and out_tokens[0] == tgt:
            out_tokens = out_tokens[1:]
        decoded = sp.decode(out_tokens)
        decoded = decoded.replace("▁", " ")
        # FIX: this loop previously replaced a single space with a single
        # space — a no-op — so it never terminated once the output contained
        # any space. Collapse runs of double spaces instead.
        while "  " in decoded:
            decoded = decoded.replace("  ", " ")
        return decoded.strip()
    except Exception as e:
        log.error(f"Chunk translation error: {e}")
        return f"[Translation Error: {str(e)[:50]}]"
def _run_translation_sync(text: str, src: str, tgt: str, token_queue: queue.Queue):
    """Blocking translation worker intended to run in a thread executor.

    Entirely in-memory: tokenizes *text*, runs greedy CT2 decoding, and pushes
    each generated token onto *token_queue*. Errors are reported in-band as a
    string beginning with "ERROR:". A ``None`` sentinel always terminates the
    stream, success or failure.
    """
    try:
        # Tokenize in memory and frame with source-language tag + EOS.
        encoded = [src] + sp.encode(text, out_type=str) + ["</s>"]

        def on_step(step_result):
            # Forward each decoding step to the async consumer.
            token_queue.put(step_result.token)
            return False  # False = keep generating

        # Run inference; the callback streams tokens as they are produced.
        translator.translate_batch(
            [encoded],
            target_prefix=[[tgt]],
            max_decoding_length=512,
            beam_size=1,  # Greedy for speed
            repetition_penalty=1.2,
            callback=on_step
        )
    except Exception as e:
        token_queue.put(f"ERROR: {e}")
    finally:
        token_queue.put(None)  # sentinel: generation finished
async def translate_stream_chunk(text: str, src: str, tgt: str) -> AsyncGenerator[str, None]:
    """Async generator yielding translated text pieces for one chunk.

    Starts the blocking CT2 worker in the default thread executor and drains
    its token queue from the event loop. SentencePiece marks word starts with
    U+2581 (▁); that marker drives the spacing logic below. Errors from the
    worker arrive in-band as a token starting with "ERROR:" and are yielded
    to the caller before the stream ends.

    FIX: the original final branch contained an if/else whose two arms were
    byte-identical (dead conditional); it is collapsed here with no behavior
    change.
    """
    token_queue: queue.Queue = queue.Queue()
    loop = asyncio.get_running_loop()
    # Fire-and-forget: worker failures are reported through the queue, and
    # the None sentinel in its finally-block guarantees loop termination.
    loop.run_in_executor(
        None,
        partial(_run_translation_sync, text, src, tgt, token_queue)
    )

    prev_token_ended_with_space = False
    while True:
        # queue.get blocks, so it also runs in the executor.
        token = await loop.run_in_executor(None, token_queue.get)
        if token is None:
            break
        if token.startswith("ERROR:"):
            yield token
            break
        # Skip language tags and control tokens.
        if token in (src, tgt, "</s>", "<s>", "<pad>"):
            continue
        if token == "▁":
            # Standalone space marker.
            yield " "
            prev_token_ended_with_space = True
            continue
        if token.startswith("▁"):
            # Word-initial token: strip the marker, re-insert a space unless
            # one was just emitted.
            clean_token = token[1:]
            if clean_token:
                decoded_word = sp.decode([clean_token])
                yield decoded_word if prev_token_ended_with_space else " " + decoded_word
                prev_token_ended_with_space = False
            else:
                yield " "
                prev_token_ended_with_space = True
        else:
            # Continuation token (no boundary marker): never inserts a space.
            yield sp.decode([token])
            prev_token_ended_with_space = False
async def translate_long_document_stream(text: str, src: str, tgt: str) -> AsyncGenerator[str, None]:
    """Stream a translation of an arbitrarily long document.

    The text is split into sentence-based chunks below the model's token
    limit, and each chunk is streamed token-by-token so the UI stays
    responsive. Paragraph markers pass through unchanged.
    """
    chunks = merge_short_sentences(split_into_sentences(text), MAX_TOKENS_PER_CHUNK)
    log.info(f"Long document: {len(text)} chars -> {len(chunks)} chunks")

    for chunk in chunks:
        if chunk == '\n':
            yield '\n'
        elif chunk.strip():
            async for piece in translate_stream_chunk(chunk, src, tgt):
                yield piece
            # Separate adjacent chunks with a space (paragraph breaks above
            # never reach this line).
            yield " "
async def translate_long_document_sync(text: str, src: str, tgt: str) -> str:
    """Translate a long document and return the complete result at once.

    Chunks the text like the streaming path, translates each chunk in the
    thread executor so the event loop stays free, then reassembles the
    pieces with single spaces while preserving newlines.
    """
    loop = asyncio.get_running_loop()
    chunks = merge_short_sentences(split_into_sentences(text), MAX_TOKENS_PER_CHUNK)
    log.info(f"Long document sync: {len(text)} chars -> {len(chunks)} chunks")

    translated_parts: List[str] = []
    for chunk in chunks:
        if chunk == '\n':
            translated_parts.append('\n')
        elif chunk.strip():
            # Off-load the blocking CT2 call to a worker thread.
            part = await loop.run_in_executor(
                None,
                partial(_translate_single_chunk_sync, chunk, src, tgt)
            )
            translated_parts.append(part)

    # Reassemble: a space between text parts, but never around newlines.
    assembled: List[str] = []
    for part in translated_parts:
        if part == '\n':
            assembled.append('\n')
        else:
            if assembled and assembled[-1] != '\n':
                assembled.append(' ')
            assembled.append(part)
    return ''.join(assembled).strip()
| # --------------------------------------------------------- | |
| # STARTUP & SHUTDOWN (Modern Lifespan Pattern) | |
| # --------------------------------------------------------- | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown lifecycle for the app.

    FIX: `asynccontextmanager` was imported but never applied; Starlette
    deprecates passing a bare async generator function as the lifespan, so
    the decorator is now in place.

    Startup: ensures the model artifacts exist on disk, loads the CT2
    translator and SentencePiece processor, creates the concurrency
    semaphore, and runs one warm-up inference so the first real request
    isn't slow.
    """
    global translator, sp, translation_semaphore
    try:
        await ensure_model()
        cpu_count = os.cpu_count() or 4
        # inter_threads: batch-level parallelism (multiple requests at once);
        # intra_threads: threads per single operation.
        # Rule of thumb: inter_threads * intra_threads <= cpu_count.
        inter_threads = max(1, cpu_count // 2)
        intra_threads = 2
        log.info(f"CPU cores: {cpu_count}, inter_threads: {inter_threads}, intra_threads: {intra_threads}")

        translator = ct2.Translator(
            str(MODEL_CACHE_DIR),
            device="cpu",
            compute_type="float32",
            inter_threads=inter_threads,
            intra_threads=intra_threads
        )
        sp = spm.SentencePieceProcessor(model_file=str(SP_MODEL_FILE))

        translation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TRANSLATIONS)
        log.info(f"Concurrency limit: {MAX_CONCURRENT_TRANSLATIONS} simultaneous translations")

        # Warm-up: a tiny silent inference pages the model weights into RAM.
        log.info("Warming up model...")
        warmup_tokens = sp.encode("Hello world", out_type=str)
        warmup_tokens.insert(0, "eng_Latn")
        warmup_tokens.append("</s>")
        translator.translate_batch(
            [warmup_tokens],
            target_prefix=[["hin_Deva"]],
            max_decoding_length=10,
            beam_size=1
        )
        log.info("Model Warmed Up & Ready!")
    except Exception as e:
        log.error(f"Fatal startup: {e}")
        raise

    yield  # Application serves requests here

    log.info("Shutting down NLLB translator...")
| app = FastAPI( | |
| title="NLLB Realtime Translator", | |
| version="2.2", # Version bump for new features | |
| docs_url="/docs", | |
| lifespan=lifespan | |
| ) | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # --------------------------------------------------------- | |
| # ENDPOINTS | |
| # --------------------------------------------------------- | |
@app.post("/translate")  # FIX: route was never registered; path assumed — confirm with clients
async def translate_stream_http(req: TranslateRequest):
    """
    Streaming translation endpoint (Server-Sent Events).

    Handles long documents via sentence chunking, bounded by the global
    semaphore and a per-request timeout. Each SSE frame carries a JSON
    object with either a 'token' or an 'error' key; the stream always ends
    with a literal "[DONE]" frame.

    FIX: the function had no route decorator, so the endpoint was
    unreachable. (json import moved to the top of the file.)
    """
    async def event_generator():
        try:
            async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
                async with translation_semaphore:
                    # Use the long-document handler for all requests.
                    async for token in translate_long_document_stream(req.text, req.src_lang, req.tgt_lang):
                        yield f"data: {json.dumps({'token': token})}\n\n"
        except asyncio.TimeoutError:
            yield f"data: {json.dumps({'error': f'Translation timed out after {TRANSLATION_TIMEOUT_SECONDS} seconds'})}\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.websocket("/ws")  # FIX: route was never registered; path assumed — confirm with clients
async def ws_translate(websocket: WebSocket):
    """
    WebSocket streaming translation.

    Protocol: the client sends one JSON message matching TranslateRequest;
    translated text pieces are sent back as individual text frames. Timeouts
    and unexpected errors are reported in-band as text frames.
    """
    await websocket.accept()
    try:
        data = await websocket.receive_json()
        req = TranslateRequest(**data)  # ValueError on invalid payload -> handled below
        async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
            async with translation_semaphore:
                async for token in translate_long_document_stream(req.text, req.src_lang, req.tgt_lang):
                    await websocket.send_text(token)
    except asyncio.TimeoutError:
        await websocket.send_text(f"\n[ERROR: Translation timed out after {TRANSLATION_TIMEOUT_SECONDS} seconds]")
    except (WebSocketDisconnect, ValueError) as e:
        # Client went away or sent an invalid request — nothing to send back.
        log.info(f"WS Disconnect/Error: {e}")
    except Exception as e:
        await websocket.send_text(f"error:{str(e)}")
@app.post("/translate_sync")  # FIX: route was never registered; path taken from translate_long_document_sync docstring
async def translate_sync(req: TranslateRequest):
    """
    Synchronous translation endpoint (returns the complete result).

    Handles long documents via sentence chunking, bounded by the global
    semaphore and a per-request timeout.

    Raises:
        HTTPException 408: translation exceeded TRANSLATION_TIMEOUT_SECONDS.
        HTTPException 500: any other failure (cause chained for logging).
    """
    try:
        async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
            async with translation_semaphore:
                translated = await translate_long_document_sync(req.text, req.src_lang, req.tgt_lang)
                return JSONResponse({"translation": translated})
    except asyncio.TimeoutError as e:
        raise HTTPException(408, f"Translation timed out after {TRANSLATION_TIMEOUT_SECONDS} seconds") from e
    except HTTPException:
        raise  # don't re-wrap framework exceptions as 500s
    except Exception as e:
        raise HTTPException(500, str(e)) from e
@app.get("/health")  # FIX: route was never registered; conventional path assumed
async def health():
    """Health check endpoint with detailed status.

    Reports whether the translator and tokenizer globals are loaded, plus
    the service's concurrency/timeout configuration.
    """
    healthy = translator is not None and sp is not None
    return {
        "status": "healthy" if healthy else "model unavailable",
        "type": "float32",
        "max_concurrent": MAX_CONCURRENT_TRANSLATIONS,
        "timeout_seconds": TRANSLATION_TIMEOUT_SECONDS,
        "version": "2.2"
    }
| if __name__ == "__main__": | |
| import uvicorn | |
| port = int(os.getenv("PORT", 7860)) | |
| uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False) |