| from fastapi import FastAPI, Form |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
| import torch |
| import os |
|
|
app = FastAPI()

# Pretrained NLLB checkpoint and the fixed translation direction:
# Marathi in Devanagari script -> English in Latin script (NLLB language codes).
MODEL_NAME = "facebook/nllb-200-distilled-600M"
SRC_LANG = "mar_Deva"
TGT_LANG = "eng_Latn"

# Lazily-initialized singletons; populated by load_model() on first request
# so the server can start fast and only pay the model download/load cost once.
tokenizer = None
model = None
|
|
def load_model():
    """Lazily initialize the NLLB tokenizer and model (module-level singletons).

    The first call downloads/loads the checkpoint and pins the tokenizer's
    source language; every later call is a cheap no-op.
    """
    global tokenizer, model
    # Guard clause: both singletons already initialized — nothing to do.
    if tokenizer is not None and model is not None:
        return
    print("Loading NLLB model... This may take a minute.")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    # NLLB tokenizers need the source language set before encoding input.
    tokenizer.src_lang = SRC_LANG
    print("Model loaded successfully!")
|
|
@app.post("/translate")
async def translate(text: str = Form(...)):
    """Translate Marathi (Devanagari) form text to English.

    Args:
        text: source text, submitted as form data.

    Returns:
        ``{"translation": <english text>}``. Empty or whitespace-only input
        short-circuits to an empty translation instead of feeding the model.
    """
    # NOTE(review): load_model() and model.generate() are blocking CPU work
    # inside an async endpoint, so they stall the event loop while running.
    # Consider fastapi.concurrency.run_in_threadpool (with a load lock) — verify.
    load_model()
    if not text.strip():
        return {"translation": ""}
    # Truncate overly long inputs so they fit the model's context window
    # instead of raising at generation time.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():  # inference only — no gradient bookkeeping needed
        generated_tokens = model.generate(
            **inputs,
            # Force the decoder to start with the target-language token.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(TGT_LANG),
            max_length=512,
        )
    translated_text = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return {"translation": translated_text}
|
|
@app.get("/")
def root():
    """Liveness endpoint — reports the API is up; the model loads lazily elsewhere."""
    return {"status": "NLLB API is running (model loads on first request)!"}
|
|
if __name__ == "__main__":
    import uvicorn

    # Honor a platform-assigned port (e.g. PaaS hosts set PORT); default to 8000.
    listen_port = int(os.environ.get("PORT", 8000))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
|