# indic-en / app.py
# Hugging Face Space: pavan10504/indic-en ("Update app.py", commit d4a6660, verified)
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from IndicTransToolkit.processor import IndicProcessor
# -------------------------------
# Runtime configuration
# -------------------------------
DEVICE = "cpu"  # HF free tier: no GPU available

# -------------------------------
# Model checkpoints
# -------------------------------
INDIC_EN_MODEL = "ai4bharat/indictrans2-indic-en-1B"
EN_INDIC_MODEL = "ai4bharat/indictrans2-en-indic-1B"


def _load_translation_pair(checkpoint: str):
    """Load the tokenizer and seq2seq model for *checkpoint*.

    The model is moved to DEVICE and switched to eval mode, matching
    how both directions are served.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        checkpoint, trust_remote_code=True
    ).to(DEVICE)
    model.eval()
    return tokenizer, model


# -------------------------------
# Load Indic -> English, then English -> Indic (same order as before)
# -------------------------------
indic_en_tokenizer, indic_en_model = _load_translation_pair(INDIC_EN_MODEL)
en_indic_tokenizer, en_indic_model = _load_translation_pair(EN_INDIC_MODEL)

# -------------------------------
# Pre/post-processor shared by both translation directions
# -------------------------------
ip = IndicProcessor(inference=True)

# -------------------------------
# FastAPI app
# -------------------------------
app = FastAPI(
    title="Indic ↔ English Translation API",
    docs_url="/docs",
)
class IndicToEnRequest(BaseModel):
    """Request body for POST /translate/indic-to-en."""
    # Source text written in an Indic script.
    text: str
    # IndicTrans2 source-language tag, e.g. kan_Knda, hin_Deva
    src_lang: str
class EnToIndicRequest(BaseModel):
    """Request body for POST /translate/en-to-indic."""
    # English source text.
    text: str
    # IndicTrans2 target-language tag, e.g. kan_Knda, hin_Deva
    tgt_lang: str
# -------------------------------
# Root
# -------------------------------
@app.get("/")
def root():
    """Health check: report service status and the available endpoints."""
    endpoints = {
        "indic_to_en": "/translate/indic-to-en",
        "en_to_indic": "/translate/en-to-indic",
    }
    return {"status": "ok", "endpoints": endpoints}
# -------------------------------
# Indic β†’ English
# -------------------------------
@app.post("/translate/indic-to-en")
def indic_to_en(req: IndicToEnRequest):
    """Translate one Indic-language text to English.

    Args:
        req: text plus its IndicTrans2 source tag (e.g. kan_Knda, hin_Deva).

    Returns:
        {"translation": <English text>}
    """
    # Tag and normalize the input the way the IndicTrans2 checkpoint expects.
    batch = ip.preprocess_batch(
        [req.text],
        src_lang=req.src_lang,
        tgt_lang="eng_Latn"
    )
    # FIX: move the tokenized tensors to DEVICE. The model lives on DEVICE;
    # previously the inputs stayed wherever the tokenizer put them, which only
    # worked because DEVICE happened to be "cpu". truncation=True clamps
    # oversize inputs to the model's max length instead of risking a failure.
    inputs = indic_en_tokenizer(
        batch, return_tensors="pt", padding=True, truncation=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = indic_en_model.generate(
            **inputs,
            max_length=128,
            num_beams=3,
            use_cache=False  # presumably to cap memory on the free tier — TODO confirm
        )
    translation = indic_en_tokenizer.batch_decode(
        outputs, skip_special_tokens=True
    )[0]
    # Undo the IndicProcessor normalization on the decoded text.
    translation = ip.postprocess_batch(
        [translation], "eng_Latn"
    )[0]
    return {"translation": translation}
# -------------------------------
# English β†’ Indic
# -------------------------------
@app.post("/translate/en-to-indic")
def en_to_indic(req: EnToIndicRequest):
    """Translate one English text to the requested Indic language.

    Args:
        req: English text plus the IndicTrans2 target tag (e.g. kan_Knda, hin_Deva).

    Returns:
        {"translation": <Indic-language text>}
    """
    # Tag and normalize the input the way the IndicTrans2 checkpoint expects.
    batch = ip.preprocess_batch(
        [req.text],
        src_lang="eng_Latn",
        tgt_lang=req.tgt_lang
    )
    # FIX: move the tokenized tensors to DEVICE. The model lives on DEVICE;
    # previously the inputs stayed wherever the tokenizer put them, which only
    # worked because DEVICE happened to be "cpu". truncation=True clamps
    # oversize inputs to the model's max length instead of risking a failure.
    inputs = en_indic_tokenizer(
        batch, return_tensors="pt", padding=True, truncation=True
    ).to(DEVICE)
    with torch.no_grad():
        outputs = en_indic_model.generate(
            **inputs,
            max_length=128,
            num_beams=3,
            use_cache=False  # presumably to cap memory on the free tier — TODO confirm
        )
    translation = en_indic_tokenizer.batch_decode(
        outputs, skip_special_tokens=True
    )[0]
    # Undo the IndicProcessor normalization on the decoded text.
    translation = ip.postprocess_batch(
        [translation], req.tgt_lang
    )[0]
    return {"translation": translation}