# ramt-labse-api / app.py
# Last commit: "Update app.py" by mytranslatenisa (71cfd7a, verified)
import torch
from fastapi import FastAPI, Query
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer
# FastAPI application instance served by the Space.
app = FastAPI()
# Model repo id on the Hugging Face Hub.
MODEL_NAME: str = "mytranslatenisa/ramt-labse"
# Load tokenizer and model once at import time so every request reuses them.
tokenizer = M2M100Tokenizer.from_pretrained(MODEL_NAME)
model = M2M100ForConditionalGeneration.from_pretrained(MODEL_NAME)
# Custom API language codes -> M2M100 language codes.
# Presumably "nen" = English and "nms" = Malay — verify against callers.
LANG_CODES: dict[str, str] = {
"nen": "en",
"nms": "ms"
}
class TranslateRequest(BaseModel):
    """JSON body schema for POST /translate."""

    # Custom source-language code; must be a key of LANG_CODES.
    srcl: str
    # Custom target-language code; must be a key of LANG_CODES.
    tgtl: str
    # Text to translate.
    text: str
def translate_text(srcl, tgtl, text):
    """Translate *text* between the two custom language codes.

    Parameters
    ----------
    srcl, tgtl : str
        Custom source/target codes; must be keys of LANG_CODES.
    text : str
        Source text to translate.

    Returns
    -------
    dict
        ``{"translation": {"src_tokenized": [text], "text": <translation>,
        "code": "0"}}`` on success. For unknown language codes, ``"text"``
        is empty and ``"code"`` is ``"1"``.
    """
    # Guard clause: reject unknown codes without touching the model.
    if srcl not in LANG_CODES or tgtl not in LANG_CODES:
        return {
            "translation": {
                "src_tokenized": [text],
                "text": "",
                "code": "1"
            }
        }
    source_lang = LANG_CODES[srcl]
    target_lang = LANG_CODES[tgtl]
    # NOTE(review): mutating the shared tokenizer's src_lang is not safe
    # under concurrent requests — confirm the Space runs a single worker.
    tokenizer.src_lang = source_lang
    encoded = tokenizer(text, return_tensors="pt")
    # inference_mode disables autograd tracking: no gradient state is built
    # during generation, saving memory on every request.
    with torch.inference_mode():
        output = model.generate(
            **encoded,
            forced_bos_token_id=tokenizer.get_lang_id(target_lang)
        )
    translation = tokenizer.batch_decode(output, skip_special_tokens=True)[0]
    return {
        "translation": {
            "src_tokenized": [text],
            "text": translation,
            "code": "0"
        }
    }
# -----------------------------
# POST METHOD
# -----------------------------
@app.post("/translate")
def translate_post(req: TranslateRequest):
    """Translate the text supplied in the JSON request body."""
    return translate_text(srcl=req.srcl, tgtl=req.tgtl, text=req.text)
# -----------------------------
# GET METHOD
# -----------------------------
@app.get("/translate")
def translate_get(srcl: str = Query(...), tgtl: str = Query(...), text: str = Query(...)):
    """Translate text supplied as query parameters."""
    return translate_text(srcl=srcl, tgtl=tgtl, text=text)