# NOTE: the lines below were non-code residue scraped from the Hugging Face
# Spaces page header ("Spaces: Sleeping") — kept here as a comment only.
from fastapi import FastAPI, Query
from pydantic import BaseModel
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer

# FastAPI application serving the translation endpoints defined below.
app = FastAPI()

# Hugging Face Hub identifier of the translation checkpoint.
MODEL_NAME = "mytranslatenisa/ramt-labse"

# Load the tokenizer and model once at startup so every request reuses them.
tokenizer = M2M100Tokenizer.from_pretrained(MODEL_NAME)
model = M2M100ForConditionalGeneration.from_pretrained(MODEL_NAME)

# Project-specific language codes mapped to the M2M100 codes the model expects.
LANG_CODES = {
    "nen": "en",
    "nms": "ms",
}
class TranslateRequest(BaseModel):
    """Body of the POST translation request.

    Fields are custom language codes (see ``LANG_CODES``) plus the text
    to translate.
    """

    # srcl/tgtl: custom source/target language codes; text: input string.
    srcl: str
    tgtl: str
    text: str
def translate_text(srcl, tgtl, text):
    """Translate ``text`` from ``srcl`` to ``tgtl`` with the shared model.

    Returns a dict ``{"translation": {"src_tokenized": [...], "text": ...,
    "code": ...}}`` where ``code`` is ``"0"`` on success and ``"1"`` when
    either language code is unknown (in which case ``text`` is empty).
    """
    # Guard clause: reject unknown custom language codes up front.
    if srcl not in LANG_CODES or tgtl not in LANG_CODES:
        return {
            "translation": {
                "src_tokenized": [text],
                "text": "",
                "code": "1",
            }
        }

    src = LANG_CODES[srcl]
    tgt = LANG_CODES[tgtl]

    # Tell the tokenizer which language the input is in, then encode it.
    tokenizer.src_lang = src
    batch = tokenizer(text, return_tensors="pt")

    # Force the decoder to begin generation in the target language.
    generated = model.generate(
        **batch, forced_bos_token_id=tokenizer.get_lang_id(tgt)
    )
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

    return {
        "translation": {
            "src_tokenized": [text],
            "text": decoded,
            "code": "0",
        }
    }
# -----------------------------
# POST METHOD
# -----------------------------
@app.post("/translate")
def translate_post(req: TranslateRequest):
    """POST endpoint: translate ``req.text`` from ``req.srcl`` to ``req.tgtl``.

    Without a route decorator this handler was never registered with the
    FastAPI app, so the POST endpoint was unreachable; ``@app.post`` fixes
    that while leaving the function's signature and behavior unchanged.
    """
    return translate_text(req.srcl, req.tgtl, req.text)
# -----------------------------
# GET METHOD
# -----------------------------
@app.get("/translate")
def translate_get(
    srcl: str = Query(...),
    tgtl: str = Query(...),
    text: str = Query(...),
):
    """GET endpoint: translate via required query parameters.

    ``Query(...)`` only has effect on a registered route; the missing
    ``@app.get`` decorator meant this handler was never exposed, so it is
    added here with signature and behavior otherwise unchanged.
    """
    return translate_text(srcl, tgtl, text)