# app.py — NLLB realtime translation service (Hugging Face Space)
# Source revision: 29d2876 ("Update app.py"). The hosting page's chrome
# ("Translate / app.py", author line) was residue from the web listing,
# not code, and has been folded into this comment header.
import asyncio
import logging
import os
import shutil
import subprocess
import queue
from contextlib import asynccontextmanager
from functools import partial
from pathlib import Path
from typing import List, AsyncGenerator, Optional
import nltk
from nltk.tokenize import sent_tokenize
import warnings
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, field_validator
from huggingface_hub import snapshot_download, hf_hub_download
import ctranslate2 as ct2
import sentencepiece as spm
# Suppress HF resume_download warning
warnings.filterwarnings("ignore", category=FutureWarning)
# ---------------------------------------------------------
# CONFIGURATION
# ---------------------------------------------------------
# Maximum simultaneous translation requests (enforced via asyncio.Semaphore).
MAX_CONCURRENT_TRANSLATIONS = int(os.getenv("MAX_CONCURRENT", "4"))
TRANSLATION_TIMEOUT_SECONDS = int(os.getenv("TRANSLATION_TIMEOUT", "300")) # 5 minutes default
MAX_TOKENS_PER_CHUNK = 400 # Conservative limit below NLLB's 512 token max
# Globals
# Ensure NLTK sentence-tokenizer data is available (downloaded once, at import).
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt', quiet=True)
try:
    nltk.data.find('tokenizers/punkt_tab')
except LookupError:
    nltk.download('punkt_tab', quiet=True)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
log = logging.getLogger(__name__)
# Use Float32 as requested
MODEL_CACHE_DIR = Path(os.getenv("MODEL_CACHE_DIR", "./models/nllb-600m-f32"))
SP_MODEL_FILE = Path("./models/sentencepiece.bpe.model")
MODEL_HF_ID = "facebook/nllb-200-distilled-600M"
# Global Objects
# Populated during the app lifespan startup; None until the model is loaded.
translator: Optional[ct2.Translator] = None
sp: Optional[spm.SentencePieceProcessor] = None
# Concurrency control - semaphore limits concurrent translations
translation_semaphore: Optional[asyncio.Semaphore] = None
class TranslateRequest(BaseModel):
    """Request payload shared by all translation endpoints.

    Language codes follow the NLLB/FLORES-200 convention: three lowercase
    letters, an underscore, then a capitalized four-letter script name
    (e.g. ``eng_Latn``, ``hin_Deva``).
    """
    text: str = Field(..., description="Text to translate")
    # BUGFIX: anchor the pattern with '$' — the original pattern only matched
    # a prefix, so trailing garbage (e.g. "eng_Latnxyz") was accepted.
    src_lang: str = Field(default="eng_Latn", pattern=r'^[a-z]{3}_[A-Z][a-z]{3}$')
    tgt_lang: str = Field(default="hin_Deva", pattern=r'^[a-z]{3}_[A-Z][a-z]{3}$')

    @field_validator('text')
    @classmethod
    def check_word_count(cls, v):
        # Hard cap to protect the service from pathological inputs.
        if len(v.split()) > 10000:
            raise ValueError('Text exceeds 10k words')
        return v
async def ensure_model() -> None:
    """Download the HF model and convert it to CTranslate2 float32 format.

    Skips all work when both the converted model and the SentencePiece
    model are already cached on disk.

    Raises:
        RuntimeError: if the download or the CT2 conversion fails.
    """
    model_bin = MODEL_CACHE_DIR / "model.bin"
    sp_file = SP_MODEL_FILE
    if model_bin.exists() and sp_file.exists():
        log.info("CT2 model (float32) & SP cached.")
        return
    MODEL_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    SP_MODEL_FILE.parent.mkdir(parents=True, exist_ok=True)
    log.info(f"Downloading/converting {MODEL_HF_ID}...")
    temp_dir = Path("./temp_hf")
    # Hoisted out of the try block so the except handlers can always
    # reference the command safely.
    ct2_cmd = [
        "ct2-transformers-converter",
        "--model", str(temp_dir),
        "--output_dir", str(MODEL_CACHE_DIR),
        "--quantization", "float32",
        "--force"
    ]
    try:
        snapshot_download(repo_id=MODEL_HF_ID, local_dir=temp_dir)
        # The SP tokenizer model ships with the HF repo; copy it into our cache.
        sp_path = hf_hub_download(repo_id=MODEL_HF_ID, filename="sentencepiece.bpe.model")
        shutil.copy(sp_path, sp_file)
        subprocess.check_call(ct2_cmd)
        log.info("Model ready: float32 CT2.")
    except subprocess.CalledProcessError as e:
        log.error(f"CT2 converter failed (cmd: {ct2_cmd}). {e}")
        # Chain the cause so the traceback keeps the root error.
        raise RuntimeError("Conversion failed; verify ctranslate2 & disk space.") from e
    except Exception as e:
        log.error(f"Setup failed: {e}")
        raise RuntimeError("Model prep failed; check HF access/disk.") from e
    finally:
        # BUGFIX: the original only removed the temp snapshot on success,
        # leaking it (and its disk space) on any failure.
        shutil.rmtree(temp_dir, ignore_errors=True)
# ---------------------------------------------------------
# SENTENCE CHUNKING FOR LONG DOCUMENTS
# ---------------------------------------------------------
def split_into_sentences(text: str) -> List[str]:
    """Break text into sentences, inserting newline markers at paragraph breaks.

    Uses NLTK's sent_tokenize for sentence boundary detection; blank lines
    are preserved as standalone newline entries so paragraph structure
    survives chunking.
    """
    if not text.strip():
        return []
    out: List[str] = []
    paragraphs = text.split('\n')
    last_index = len(paragraphs) - 1
    for idx, paragraph in enumerate(paragraphs):
        stripped = paragraph.strip()
        if not stripped:
            # Blank line: keep an explicit paragraph-separator marker.
            out.append('\n')
            continue
        out.extend(sent_tokenize(stripped))
        if idx != last_index:
            # Boundary marker between non-final paragraphs.
            out.append('\n')
    return out
def estimate_token_count(text: str) -> int:
    """Approximate the SentencePiece token count of *text*.

    Uses the loaded tokenizer when available; otherwise falls back to a
    words * 1.3 heuristic (rough average subword-per-word ratio).
    """
    if sp is not None:
        return len(sp.encode(text, out_type=str))
    # Tokenizer not loaded yet: cheap word-based estimate.
    return int(len(text.split()) * 1.3)
def merge_short_sentences(sentences: List[str], max_tokens: int = MAX_TOKENS_PER_CHUNK) -> List[str]:
    """Greedily pack consecutive sentences into chunks of at most max_tokens.

    Newline markers pass through unchanged as paragraph breaks, and a single
    sentence longer than the limit becomes its own chunk. Merging short
    sentences gives the model more context per translation call.
    """
    if not sentences:
        return []

    chunks: List[str] = []
    buffer: List[str] = []
    buffer_tokens = 0

    def flush() -> None:
        # Emit the open chunk (if any) and reset the accumulator.
        nonlocal buffer, buffer_tokens
        if buffer:
            chunks.append(' '.join(buffer))
            buffer = []
            buffer_tokens = 0

    for sent in sentences:
        if sent == '\n':
            # Paragraph break: close the open chunk, keep the marker.
            flush()
            chunks.append('\n')
            continue
        n_tokens = estimate_token_count(sent)
        if n_tokens > max_tokens:
            # Oversized sentence: emit alone rather than splitting mid-sentence.
            flush()
            chunks.append(sent)
            continue
        if buffer and buffer_tokens + n_tokens > max_tokens:
            # Adding this sentence would overflow: start a fresh chunk.
            flush()
        buffer.append(sent)
        buffer_tokens += n_tokens

    flush()
    return chunks
# ---------------------------------------------------------
# CORE TRANSLATION LOGIC (RAM ONLY - NO DISK WRITES)
# ---------------------------------------------------------
def _translate_single_chunk_sync(text: str, src: str, tgt: str) -> str:
    """
    Synchronous translation of a single chunk (used by /translate_sync).

    Returns the translated text. On model errors, returns a short error
    placeholder so callers receive partial output instead of a failure.
    """
    if not text.strip() or text == '\n':
        return text
    try:
        # NLLB input format: <src_lang> tokens... </s>, with the target
        # language forced as the first decoder token via target_prefix.
        tokens = sp.encode(text, out_type=str)
        tokens.insert(0, src)
        tokens.append("</s>")
        results = translator.translate_batch(
            [tokens],
            target_prefix=[[tgt]],
            max_decoding_length=512,
            beam_size=3, # Higher quality
            repetition_penalty=1.2
        )
        out_tokens = results[0].hypotheses[0]
        # Drop the forced target-language token from the hypothesis.
        if out_tokens and out_tokens[0] == tgt:
            out_tokens = out_tokens[1:]
        decoded = sp.decode(out_tokens)
        decoded = decoded.replace("▁", " ")
        # BUGFIX: the original looped `while " " in decoded` replacing a
        # single space with a single space — an identity operation that spun
        # forever whenever the output contained any space. The intent was to
        # collapse runs of spaces, so replace double spaces until none remain.
        while "  " in decoded:
            decoded = decoded.replace("  ", " ")
        return decoded.strip()
    except Exception as e:
        log.error(f"Chunk translation error: {e}")
        return f"[Translation Error: {str(e)[:50]}]"
def _run_translation_sync(text: str, src: str, tgt: str, token_queue: queue.Queue):
    """
    Blocking translation worker, intended to run in an executor thread.

    Streams each generated token into token_queue; a None sentinel is
    always pushed last so the async consumer knows the stream has ended.
    Operates purely in memory via CTranslate2 (no disk writes).
    """
    def on_token(step_result):
        # Forward each generated token to the async consumer.
        token_queue.put(step_result.token)
        return False  # False => keep generating

    try:
        # Build the NLLB source sequence: <src_lang> tokens... </s>
        encoded = sp.encode(text, out_type=str)
        source_tokens = [src] + encoded + ["</s>"]
        translator.translate_batch(
            [source_tokens],
            target_prefix=[[tgt]],
            max_decoding_length=512,
            beam_size=1,  # Greedy decoding for lowest latency
            repetition_penalty=1.2,
            callback=on_token
        )
    except Exception as e:
        token_queue.put(f"ERROR: {e}")
    finally:
        token_queue.put(None)  # Sentinel: end of stream
async def translate_stream_chunk(text: str, src: str, tgt: str) -> AsyncGenerator[str, None]:
    """Async generator yielding translated text pieces for one chunk.

    Runs the blocking CT2 translation in an executor thread and consumes
    its token queue, converting SentencePiece tokens (U+2581 space markers)
    into plain text with word boundaries restored.
    """
    token_queue: queue.Queue = queue.Queue()
    loop = asyncio.get_running_loop()
    # Fire-and-forget: the worker always pushes a None sentinel when done.
    loop.run_in_executor(
        None,
        partial(_run_translation_sync, text, src, tgt, token_queue)
    )
    # Tracks whether the previously yielded piece already ended in a space.
    prev_token_ended_with_space = False
    while True:
        token = await loop.run_in_executor(None, token_queue.get)
        if token is None:
            break
        if token.startswith("ERROR:"):
            yield token
            break
        # Skip control/special tokens emitted by the decoder.
        if token in [src, tgt, "</s>", "<s>", "<pad>"]:
            continue
        # SentencePiece marks word starts with U+2581 (▁).
        if token == "▁":
            # Standalone space marker.
            yield " "
            prev_token_ended_with_space = True
            continue
        if token.startswith("▁"):
            # Word-initial token: strip the marker before decoding.
            clean_token = token[1:] if len(token) > 1 else ""
            if clean_token:
                decoded_word = sp.decode([clean_token])
                # Insert the word boundary unless one was just emitted.
                if not prev_token_ended_with_space:
                    yield " " + decoded_word
                else:
                    yield decoded_word
                prev_token_ended_with_space = False
            else:
                # Token was only the marker itself.
                yield " "
                prev_token_ended_with_space = True
        else:
            # Continuation token: attaches directly to the previous word.
            # (The original had three branches here that all yielded the
            # same value regardless of the space flag; collapsed to one.)
            yield sp.decode([token])
            prev_token_ended_with_space = False
async def translate_long_document_stream(text: str, src: str, tgt: str) -> AsyncGenerator[str, None]:
    """Stream the translation of an arbitrarily long document.

    The text is split into sentence-based chunks that fit the model's
    context window; each chunk is translated via the streaming path and
    yielded in order, with newline markers passed through.
    """
    chunks = merge_short_sentences(split_into_sentences(text), MAX_TOKENS_PER_CHUNK)
    log.info(f"Long document: {len(text)} chars -> {len(chunks)} chunks")
    for chunk in chunks:
        if chunk == '\n':
            # Paragraph marker: emit verbatim, no trailing separator.
            yield '\n'
            continue
        if not chunk.strip():
            continue
        async for piece in translate_stream_chunk(chunk, src, tgt):
            yield piece
        # Separator between consecutive translated chunks.
        yield " "
async def translate_long_document_sync(text: str, src: str, tgt: str) -> str:
    """Translate a long document and return the complete result at once.

    Chunks are translated one at a time in an executor thread so the
    event loop stays responsive; newline markers are preserved when the
    parts are reassembled.
    """
    loop = asyncio.get_running_loop()
    chunks = merge_short_sentences(split_into_sentences(text), MAX_TOKENS_PER_CHUNK)
    log.info(f"Long document sync: {len(text)} chars -> {len(chunks)} chunks")

    translated_parts: List[str] = []
    for chunk in chunks:
        if chunk == '\n':
            translated_parts.append('\n')
        elif chunk.strip():
            # Off-load the blocking model call so the event loop keeps serving.
            part = await loop.run_in_executor(
                None,
                partial(_translate_single_chunk_sync, chunk, src, tgt)
            )
            translated_parts.append(part)

    # Reassemble: single spaces between parts, but none adjacent to newlines.
    pieces: List[str] = []
    for part in translated_parts:
        if part == '\n':
            pieces.append('\n')
        else:
            if pieces and pieces[-1] != '\n':
                pieces.append(' ')
            pieces.append(part)
    return ''.join(pieces).strip()
# ---------------------------------------------------------
# STARTUP & SHUTDOWN (Modern Lifespan Pattern)
# ---------------------------------------------------------
@asynccontextmanager
async def lifespan(app: FastAPI):
    """App lifecycle: download/convert, load, and warm up the model.

    Populates the module globals ``translator``, ``sp`` and
    ``translation_semaphore`` before the app begins serving; any failure
    here is fatal and re-raised so the server does not start half-ready.
    """
    global translator, sp, translation_semaphore
    try:
        await ensure_model()
        cpu_count = os.cpu_count() or 4
        # Optimal threading configuration for concurrent requests:
        # - inter_threads: enables batch-level parallelism (process multiple requests)
        # - intra_threads: threads per single operation
        # Rule: inter_threads * intra_threads <= cpu_count
        inter_threads = max(1, cpu_count // 2) # Parallel batch processing
        intra_threads = 2 # Per-operation threads
        log.info(f"CPU cores: {cpu_count}, inter_threads: {inter_threads}, intra_threads: {intra_threads}")
        # Load Model with optimized threading
        translator = ct2.Translator(
            str(MODEL_CACHE_DIR),
            device="cpu",
            compute_type="float32",
            inter_threads=inter_threads,
            intra_threads=intra_threads
        )
        sp = spm.SentencePieceProcessor(model_file=str(SP_MODEL_FILE))
        # Initialize concurrency semaphore
        translation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_TRANSLATIONS)
        log.info(f"Concurrency limit: {MAX_CONCURRENT_TRANSLATIONS} simultaneous translations")
        # --- WARMUP START ---
        # One tiny eng_Latn -> hin_Deva inference so the first real request
        # doesn't pay the lazy-initialization cost.
        log.info("Warming up model...")
        warmup_src = "Hello world"
        warmup_tokens = sp.encode(warmup_src, out_type=str)
        warmup_tokens.insert(0, "eng_Latn")
        warmup_tokens.append("</s>")
        # Run a silent inference to load weights into RAM
        translator.translate_batch(
            [warmup_tokens],
            target_prefix=[["hin_Deva"]],
            max_decoding_length=10,
            beam_size=1
        )
        log.info("Model Warmed Up & Ready!")
        # --- WARMUP END ---
    except Exception as e:
        log.error(f"Fatal startup: {e}")
        raise
    yield # Application runs here
    # Shutdown cleanup (if needed)
    log.info("Shutting down NLLB translator...")
# FastAPI application instance; model loading happens in the lifespan hook.
app = FastAPI(
    title="NLLB Realtime Translator",
    version="2.2", # Version bump for new features
    docs_url="/docs",
    lifespan=lifespan
)
# Allow browser clients from any origin (public demo deployment).
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers under the CORS spec — confirm whether credentials
# are actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ---------------------------------------------------------
# ENDPOINTS
# ---------------------------------------------------------
@app.post("/translate")
async def translate_stream_http(req: TranslateRequest):
    """
    Streaming translation endpoint using Server-Sent Events (SSE).

    Long documents are handled via sentence chunking; concurrency is
    bounded by the global semaphore and each request is capped by the
    configured timeout. SSE framing bypasses proxy buffering.
    """
    import json

    async def event_generator():
        try:
            async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
                async with translation_semaphore:
                    # All requests go through the long-document handler.
                    stream = translate_long_document_stream(req.text, req.src_lang, req.tgt_lang)
                    async for token in stream:
                        yield f"data: {json.dumps({'token': token})}\n\n"
        except asyncio.TimeoutError:
            yield f"data: {json.dumps({'error': f'Translation timed out after {TRANSLATION_TIMEOUT_SECONDS} seconds'})}\n\n"
        except Exception as e:
            yield f"data: {json.dumps({'error': str(e)})}\n\n"
        finally:
            # Always terminate the SSE stream explicitly.
            yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@app.websocket("/ws/translate")
async def ws_translate(websocket: WebSocket):
    """
    WebSocket streaming translation.

    Protocol: the client sends one JSON message matching TranslateRequest;
    the server streams back raw translated text fragments.
    Handles long documents via sentence chunking.
    """
    await websocket.accept()
    try:
        data = await websocket.receive_json()
        req = TranslateRequest(**data)
        async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
            async with translation_semaphore:
                async for token in translate_long_document_stream(req.text, req.src_lang, req.tgt_lang):
                    await websocket.send_text(token)
    except asyncio.TimeoutError:
        await websocket.send_text(f"\n[ERROR: Translation timed out after {TRANSLATION_TIMEOUT_SECONDS} seconds]")
    except (WebSocketDisconnect, ValueError) as e:
        # ValueError also covers validation failures from TranslateRequest(**data).
        log.info(f"WS Disconnect/Error: {e}")
    except Exception as e:
        # NOTE(review): send_text here may itself raise if the socket has
        # already closed — confirm whether this needs a guard.
        await websocket.send_text(f"error:{str(e)}")
@app.post("/translate_sync")
async def translate_sync(req: TranslateRequest):
    """
    Non-streaming translation endpoint: returns the full result as JSON.

    Long documents are chunked by sentence; the request is bounded by the
    global concurrency semaphore and the configured timeout.
    """
    try:
        async with asyncio.timeout(TRANSLATION_TIMEOUT_SECONDS):
            async with translation_semaphore:
                result = await translate_long_document_sync(req.text, req.src_lang, req.tgt_lang)
        return JSONResponse({"translation": result})
    except asyncio.TimeoutError:
        raise HTTPException(408, f"Translation timed out after {TRANSLATION_TIMEOUT_SECONDS} seconds")
    except Exception as e:
        raise HTTPException(500, str(e))
@app.get("/health")
async def health():
    """Health check reporting model readiness and service limits."""
    # The model objects are populated by the lifespan hook; both must be
    # present for the service to be usable.
    model_loaded = translator is not None and sp is not None
    status = "healthy" if model_loaded else "model unavailable"
    return {
        "status": status,
        "type": "float32",
        "max_concurrent": MAX_CONCURRENT_TRANSLATIONS,
        "timeout_seconds": TRANSLATION_TIMEOUT_SECONDS,
        "version": "2.2"
    }
if __name__ == "__main__":
    import uvicorn
    # PORT is set by the hosting environment; 7860 is the HF Spaces default.
    port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port, reload=False)