|
|
import os |
|
|
|
|
|
# Disable HF tokenizers' thread parallelism (silences fork-related warnings);
# set before `transformers` is imported below so the library sees it at load time.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
from fastapi import FastAPI |
|
|
from pydantic import BaseModel |
|
|
from typing import List |
|
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
|
|
|
|
# FastAPI application instance; route handlers below register themselves via decorators.
app = FastAPI()
|
|
|
|
|
|
|
|
# NLLB-200 (distilled, 600M params): multilingual many-to-many translation model.
MODEL_NAME = "facebook/nllb-200-distilled-600M"


# Loading happens at import time, so server startup blocks until the model is ready
# (first run downloads the weights from the Hugging Face hub).
print("Loading NLLB Model...")


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)


# device=-1 forces CPU inference; max_length caps generated tokens per input text.
translator = pipeline("translation", model=model, tokenizer=tokenizer, max_length=512, device=-1)


print("Model Ready!")
|
|
|
|
|
# Human-readable language names -> FLORES-200 codes, as used by NLLB
# for the pipeline's src_lang / tgt_lang arguments.
LANG_MAP = dict(
    Arabic="arb_Arab",
    English="eng_Latn",
    French="fra_Latn",
)
|
|
|
|
|
class BatchRequest(BaseModel):
    """Request body for POST /translate_batch."""

    # Texts to translate; the response carries translations in the same order.
    texts: List[str]

    # Human-readable language names looked up in LANG_MAP (e.g. "English");
    # unknown names are not rejected here — the handler falls back to defaults.
    source_lang: str

    target_lang: str
|
|
|
|
|
@app.get("/")
def home():
    """Liveness probe: report that the service is up."""
    payload = dict(status="Running")
    return payload
|
|
|
|
|
@app.post("/translate_batch")
def translate_batch(req: BatchRequest):
    """Translate req.texts from req.source_lang to req.target_lang.

    Language names are resolved through LANG_MAP. If the translation
    pipeline raises, the original texts are returned untranslated as a
    best-effort response (the error is printed, not propagated).
    """
    # Empty batch: nothing to translate.
    if not req.texts:
        return {"translations": []}

    # NOTE(review): unrecognized language names silently fall back to
    # English source / Arabic target — a typo'd request still returns 200
    # with possibly-wrong languages. Confirm this is intended.
    src = LANG_MAP.get(req.source_lang, "eng_Latn")
    tgt = LANG_MAP.get(req.target_lang, "arb_Arab")

    try:
        outputs = translator(req.texts, src_lang=src, tgt_lang=tgt, batch_size=16)
        translated = [entry['translation_text'] for entry in outputs]
    except Exception as e:
        print(f"Error: {e}")
        # Best-effort fallback: echo the inputs untranslated.
        return {"translations": req.texts}

    return {"translations": translated}
|
|
|