# Provenance (scraped page header, not code): uploaded by MedhaCodes — "Create app.py", commit 943cbe3 (verified)
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os
# Model & cache setup
# BanglaT5 English→Bangla NMT checkpoint from the CSE BUET NLP group.
model_name = "csebuetnlp/banglat5_nmt_en_bn"
# Explicit Hugging Face cache directory (matches the library's default location).
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
# use_fast=False forces the slow (SentencePiece-based) tokenizer —
# NOTE(review): presumably this checkpoint lacks a fast tokenizer; confirm.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False, cache_dir=CACHE_DIR)
# Downloads/loads the seq2seq weights at import time — first startup needs network access.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir=CACHE_DIR)
# Put model in eval mode (disables dropout; inference-only service).
model.eval()
# FastAPI application instance; endpoints below attach to this object.
app = FastAPI(title="English to Bangla Translation API")
class TranslationRequest(BaseModel):
    """Request body for POST /translate."""
    # English source text to be translated.
    text: str
@app.post("/translate")
def translate(req: TranslationRequest):
    """Translate the English text in *req* to Bangla.

    Returns a JSON object of the form ``{"translation": <bangla text>}``.
    Generation is capped at 256 tokens.
    """
    encoded = tokenizer(req.text, return_tensors="pt", padding=True, truncation=True)
    # Inference only — no gradients needed.
    with torch.no_grad():
        generated = model.generate(**encoded, max_length=256)
    bangla = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"translation": bangla}
@app.get("/")
def root():
    """Health-check endpoint confirming the service is up."""
    status = "English → Bangla Translation API is running!"
    return {"message": status}