|
|
""" |
|
|
A FastAPI application for serving the translation model, inspired by interactive_translate.py. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from transformers import M2M100ForConditionalGeneration, NllbTokenizer |
|
|
from fastapi import FastAPI, HTTPException, UploadFile, File |
|
|
from fastapi.staticfiles import StaticFiles |
|
|
from fastapi.responses import FileResponse |
|
|
from pydantic import BaseModel |
|
|
import logging |
|
|
from typing import List |
|
|
import fitz |
|
|
import shutil |
|
|
import os |
|
|
|
|
|
|
|
|
# Module-level logger for this service; its records propagate to the root
# logger, which the basicConfig call below configures.
logger = logging.getLogger(__name__)
# Emit INFO-level and higher records through the root logger's default handler.
logging.basicConfig(level=logging.INFO)
|
|
|
|
|
# The ASGI application. The metadata below is shown in the generated OpenAPI docs.
app = FastAPI(
    version="2.0",
    title="Saksi Translation API",
    description="A simple API for translating text and PDFs to English.",
)

# Serve the browser UI's static files under /frontend. The "frontend"
# directory is resolved against the process's current working directory, so
# start the server from the directory that contains it.
app.mount(path="/frontend", app=StaticFiles(directory="frontend"), name="frontend")
|
|
|
|
|
|
|
|
|
|
|
# Run inference on the GPU when torch can see one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Language names accepted by the API, mapped to the NLLB language codes that
# are assigned to the tokenizer's `src_lang`.
SUPPORTED_LANGUAGES = {
    "nepali": "nep_Npan",
    "sinhala": "sin_Sinh",
}

# Location of the fine-tuned checkpoint passed to load_model_and_tokenizer().
# Set the SAKSI_MODEL_PATH environment variable to serve a different checkpoint
# without editing code; the default is unchanged.
# NOTE(review): the default checkpoint name mentions only Nepali although
# Sinhala is also listed above — confirm it was fine-tuned for both.
MODEL_PATH = os.environ.get("SAKSI_MODEL_PATH", "models/nllb-finetuned-nepali-en")

# Populated by load_model_and_tokenizer() during application startup; both
# stay None until that has run.
model = None
tokenizer = None
|
|
|
|
|
|
|
|
class TranslationRequest(BaseModel):
    """Request body for ``POST /translate``."""

    # Text to translate, written in `source_language`.
    text: str

    # Language name; translation only succeeds for keys of SUPPORTED_LANGUAGES
    # (e.g. "nepali"), otherwise the endpoint answers 400.
    source_language: str
|
|
|
|
|
class TranslationResponse(BaseModel):
    """Response body of ``POST /translate``."""

    # The request's `text`, echoed back unchanged.
    original_text: str

    # English translation produced by the model.
    translated_text: str

    # The request's `source_language`, echoed back unchanged.
    source_language: str
|
|
|
|
|
class BatchTranslationRequest(BaseModel):
    """Request body for ``POST /batch-translate``."""

    # Texts to translate; all are assumed to be in `source_language`.
    texts: List[str]

    # Language name shared by every entry of `texts`; must be a key of
    # SUPPORTED_LANGUAGES (e.g. "nepali").
    source_language: str
|
|
|
|
|
class BatchTranslationResponse(BaseModel):
    """Response body of ``POST /batch-translate``."""

    # The request's `texts`, echoed back unchanged.
    original_texts: List[str]

    # English translations, in the same order as `original_texts`.
    translated_texts: List[str]

    # The request's `source_language`, echoed back unchanged.
    source_language: str
|
|
|
|
|
class PdfTranslationResponse(BaseModel):
    """Response body of ``POST /translate-pdf``."""

    # Name of the uploaded file as reported by the client.
    filename: str

    # Translations of the PDF's non-blank text lines, joined with newlines.
    translated_text: str

    # The `source_language` query parameter, echoed back unchanged.
    source_language: str
|
|
|
|
|
|
|
|
|
|
|
def load_model_and_tokenizer(model_path: str):
    """Load the translation model and its tokenizer into the module globals.

    The model is moved to ``DEVICE``. The globals ``model`` and ``tokenizer``
    are only assigned once BOTH have loaded. A failed tokenizer load therefore
    never leaves a model without its tokenizer.

    Args:
        model_path: Checkpoint location passed to ``from_pretrained``.

    Raises:
        Exception: Whatever ``from_pretrained`` / ``.to`` raises. It is logged
            with its traceback and re-raised so application startup fails.
    """
    global model, tokenizer

    logger.info(f"Loading model on {DEVICE.upper()}...")
    try:
        loaded_model = M2M100ForConditionalGeneration.from_pretrained(model_path).to(DEVICE)
        loaded_tokenizer = NllbTokenizer.from_pretrained(model_path)
    except Exception as e:
        # logger.exception records the traceback, not just the message.
        logger.exception("Error loading model: %s", e)
        raise

    # Publish both objects together (see docstring).
    model, tokenizer = loaded_model, loaded_tokenizer
    logger.info("Model and tokenizer loaded successfully!")
|
|
|
|
|
def translate_text(text: str, src_lang: str, *, max_length: int = 128) -> str:
    """Translate a single string into English.

    Args:
        text: Source text written in ``src_lang``.
        src_lang: Key of ``SUPPORTED_LANGUAGES`` (e.g. ``"nepali"``).
        max_length: Upper bound on the generated sequence length, passed to
            ``model.generate``. Defaults to 128, the value previously
            hard-coded; raise it for long paragraphs.

    Returns:
        The English translation. Returns ``""`` when ``text`` is empty or only
        whitespace; the model is not invoked in that case.

    Raises:
        ValueError: If ``src_lang`` is not in ``SUPPORTED_LANGUAGES``.
    """
    if src_lang not in SUPPORTED_LANGUAGES:
        raise ValueError(f"Language '{src_lang}' not supported.")

    # Nothing to translate: skip the model instead of asking it to generate
    # from an input that carries no source content.
    if not text.strip():
        return ""

    # NOTE: src_lang is stored on the shared module-level tokenizer, so
    # concurrent calls with different languages could race on it.
    tokenizer.src_lang = SUPPORTED_LANGUAGES[src_lang]
    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)

    generated_tokens = model.generate(
        **inputs,
        # Force the English language token as the first generated token so
        # the multilingual model decodes into eng_Latn.
        forced_bos_token_id=tokenizer.convert_tokens_to_ids("eng_Latn"),
        max_length=max_length,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
|
|
|
|
|
def batch_translate_text(texts: List[str], src_lang: str, *, batch_size: int = 32) -> List[str]:
    """Translate several strings into English.

    Args:
        texts: Source strings, all written in ``src_lang``.
        src_lang: Key of ``SUPPORTED_LANGUAGES`` (e.g. ``"nepali"``).
        batch_size: Maximum number of strings sent through ``model.generate``
            per call. This bounds peak memory for large inputs such as the
            line list of a PDF.

    Returns:
        One translation per input, in input order. Empty or whitespace-only
        entries map to ``""`` without invoking the model. An empty ``texts``
        list yields ``[]``.

    Raises:
        ValueError: If ``src_lang`` is unsupported or ``batch_size`` < 1.
    """
    if src_lang not in SUPPORTED_LANGUAGES:
        raise ValueError(f"Language '{src_lang}' not supported.")
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}.")

    results = [""] * len(texts)
    # Indices of entries with real content. Filtering here also guarantees
    # the model is never called with an empty batch (e.g. texts == []).
    pending = [i for i, text in enumerate(texts) if text.strip()]
    if not pending:
        return results

    # NOTE: src_lang is stored on the shared module-level tokenizer, so
    # concurrent calls with different languages could race on it.
    tokenizer.src_lang = SUPPORTED_LANGUAGES[src_lang]
    # Loop-invariant: the English language token forced as the first output.
    eng_token_id = tokenizer.convert_tokens_to_ids("eng_Latn")

    for start in range(0, len(pending), batch_size):
        chunk = pending[start:start + batch_size]
        inputs = tokenizer(
            [texts[i] for i in chunk],
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        ).to(DEVICE)
        generated_tokens = model.generate(
            **inputs,
            forced_bos_token_id=eng_token_id,
            max_length=512,
        )
        decoded = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        # Write each translation back to the position of its source text.
        for i, translation in zip(chunk, decoded):
            results[i] = translation

    return results
|
|
|
|
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Populate the module-level model and tokenizer when the app starts."""
    # NOTE: `on_event` is deprecated in recent FastAPI releases in favour of
    # lifespan handlers; kept here to avoid touching the app constructor.
    load_model_and_tokenizer(model_path=MODEL_PATH)
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Serve the frontend's entry page."""
    # Relative path: resolved against the process's current working directory.
    index_page = "frontend/index.html"
    return FileResponse(path=index_page)
|
|
|
|
|
@app.get("/languages")
def get_supported_languages():
    """List the language names accepted as ``source_language``."""
    # Iterating a dict yields its keys, in insertion order.
    language_names = list(SUPPORTED_LANGUAGES)
    return {"supported_languages": language_names}
|
|
|
|
|
@app.post("/translate", response_model=TranslationResponse)
async def translate(request: TranslationRequest):
    """Translate a single text from ``request.source_language`` into English.

    Raises:
        HTTPException: 400 for an unsupported language (the ``ValueError``
            from ``translate_text``); 500 for any other failure, which is also
            logged with its traceback.
    """
    # NOTE(review): translate_text runs blocking model inference directly on
    # the event loop, so other requests wait until it finishes. Offloading it
    # to a worker thread would also need a lock around the shared
    # tokenizer.src_lang.
    try:
        translated_text = translate_text(request.text, request.source_language)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Without this the failure would only reach the client, never the server logs.
        logger.exception("Unexpected error while translating text")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    return TranslationResponse(
        original_text=request.text,
        translated_text=translated_text,
        source_language=request.source_language,
    )
|
|
|
|
|
@app.post("/batch-translate", response_model=BatchTranslationResponse)
async def batch_translate(request: BatchTranslationRequest):
    """Translate a list of texts from ``request.source_language`` into English.

    Raises:
        HTTPException: 400 for an unsupported language (the ``ValueError``
            from ``batch_translate_text``); 500 for any other failure, which
            is also logged with its traceback.
    """
    # NOTE(review): inference blocks the event loop for the whole batch; see
    # the same concern in the /translate handler.
    try:
        translated_texts = batch_translate_text(request.texts, request.source_language)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Without this the failure would only reach the client, never the server logs.
        logger.exception("Unexpected error while batch-translating texts")
        raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {e}") from e

    return BatchTranslationResponse(
        original_texts=request.texts,
        translated_texts=translated_texts,
        source_language=request.source_language,
    )
|
|
|
|
|
@app.post("/translate-pdf", response_model=PdfTranslationResponse)
async def translate_pdf(source_language: str, file: UploadFile = File(...)):
    """Extract the text of an uploaded PDF and translate it into English.

    The extracted text is split into non-blank lines. Each line is translated
    as its own sequence, and the translations are re-joined with newlines.

    Args:
        source_language: Key of ``SUPPORTED_LANGUAGES`` (query parameter).
        file: The uploaded PDF; its content type must be ``application/pdf``.

    Raises:
        HTTPException: 400 for a non-PDF upload, an unsupported language, or a
            PDF without extractable text; 500 (logged with traceback) for any
            other processing failure.
    """
    if file.content_type != "application/pdf":
        raise HTTPException(status_code=400, detail="Invalid file type. Please upload a PDF.")
    # Reject bad languages before doing any PDF work; same message that
    # translate_text / batch_translate_text use.
    if source_language not in SUPPORTED_LANGUAGES:
        raise HTTPException(status_code=400, detail=f"Language '{source_language}' not supported.")

    # NOTE(review): extraction and inference below are blocking and run on the
    # event loop; a large PDF stalls all other requests until it finishes.
    try:
        # Parse the upload from memory. The previous temp file
        # f"temp_{file.filename}" used a client-controlled name (path
        # traversal via "../") and collided across concurrent uploads.
        pdf_bytes = await file.read()
        # The context manager closes the document even if extraction fails.
        with fitz.open(stream=pdf_bytes, filetype="pdf") as doc:
            extracted_text = "".join(page.get_text() for page in doc)

        if not extracted_text.strip():
            raise HTTPException(status_code=400, detail="Could not extract any text from the PDF.")

        # One translation unit per non-blank line of extracted text.
        text_chunks = [line.strip() for line in extracted_text.split("\n") if line.strip()]
        translated_chunks = batch_translate_text(text_chunks, source_language)
        final_translation = "\n".join(translated_chunks)
    except HTTPException:
        # Deliberate client errors raised above keep their status code instead
        # of being rewrapped as a 500 by the generic handler below.
        raise
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Error processing PDF: %s", e)
        raise HTTPException(status_code=500, detail=f"An error occurred while processing the PDF: {e}") from e

    return PdfTranslationResponse(
        # UploadFile.filename may be None; the response model requires a str.
        filename=file.filename or "",
        translated_text=final_translation,
        source_language=source_language,
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 8000.
    # uvicorn is imported here because it is only needed when run directly.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)
|
|
|