# NOTE(review): the lines "Spaces: / Runtime error / Runtime error" were residue
# from a Hugging Face Spaces status banner, not source code; preserved here as a
# comment so the file remains valid Python.
| from fastapi import FastAPI, Request | |
| from transformers import ( | |
| MarianMTModel, | |
| MarianTokenizer, | |
| MBartForConditionalGeneration, | |
| MBart50TokenizerFast, | |
| AutoTokenizer, | |
| AutoModelForSeq2SeqLM | |
| ) | |
| import torch | |
| from tokenization_small100 import SMALL100Tokenizer | |
| # import your chunking helpers | |
| from chunking import get_max_word_length, chunk_text | |
app = FastAPI()

# Map target-language ISO 639-1 codes to Hugging Face model IDs.
# Three model families are used (see load_model for the matching loaders):
#   - Helsinki-NLP Marian models (one per language pair),
#   - facebook/mbart-large-50-many-to-many-mmt (multilingual),
#   - alirezamsh/small100 (multilingual, tokenizer configured per target language).
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",          # bulgarian
    "cs": "Helsinki-NLP/opus-mt-en-cs",                 # czech
    "da": "Helsinki-NLP/opus-mt-en-da",                 # danish
    "de": "Helsinki-NLP/opus-mt-en-de",                 # german
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",          # greek
    "es": "Helsinki-NLP/opus-mt-tc-big-en-es",          # spanish
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",          # estonian
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",          # finnish
    "fr": "Helsinki-NLP/opus-mt-en-fr",                 # french
    "hr": "facebook/mbart-large-50-many-to-many-mmt",   # croatian
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",          # hungarian
    "is": "mkorada/opus-mt-en-is-finetuned-v4",         # icelandic # Manas's fine-tuned model
    "it": "Helsinki-NLP/opus-mt-tc-big-en-it",          # italian
    "lb": "alirezamsh/small100",                        # luxembourgish # small100
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",          # lithuanian
    "lv": "facebook/mbart-large-50-many-to-many-mmt",   # latvian
    "me": "Helsinki-NLP/opus-mt-tc-base-en-sh",         # montenegrin
    "mk": "Helsinki-NLP/opus-mt-en-mk",                 # macedonian
    # "nb": "facebook/mbart-large-50-many-to-many-mmt", # norwegian
    "nl": "facebook/mbart-large-50-many-to-many-mmt",   # dutch
    "no": "Confused404/eng-gmq-finetuned_v2-no",        # norwegian # Alex's fine-tuned model
    "pl": "Helsinki-NLP/opus-mt-en-sla",                # polish
    "pt": "facebook/mbart-large-50-many-to-many-mmt",   # portuguese
    "ro": "facebook/mbart-large-50-many-to-many-mmt",   # romanian
    "sk": "Helsinki-NLP/opus-mt-en-sk",                 # slovak
    "sl": "alirezamsh/small100",                        # slovene
    "sq": "alirezamsh/small100",                        # albanian
    "sv": "Helsinki-NLP/opus-mt-en-sv",                 # swedish
    "tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"           # turkish
}

# Cache of loaded (tokenizer, model) pairs, populated lazily by load_model.
MODEL_CACHE = {}
def load_model(model_id: str, target_lang: str):
    """
    Load and cache the (tokenizer, model) pair for *model_id*.

    Loader selection:
      - facebook/mbart-*    -> MBart50TokenizerFast + MBartForConditionalGeneration
      - alirezamsh/small100 -> SMALL100Tokenizer + AutoModelForSeq2SeqLM
      - all others (Marian) -> MarianTokenizer + MarianMTModel

    The small100 tokenizer is constructed per target language
    (tgt_lang=target_lang), so its cache entry must be keyed by
    (model_id, target_lang): keying on model_id alone made the languages
    sharing this model ("lb", "sl", "sq") reuse whichever tokenizer was
    loaded first, translating into the wrong language afterwards.

    Returns:
        (tokenizer, model) — model is moved to CPU.
    """
    # Only small100 needs a per-language cache entry; other model IDs map to
    # language-independent tokenizers.
    if model_id == "alirezamsh/small100":
        cache_key = (model_id, target_lang)
    else:
        cache_key = model_id

    if cache_key not in MODEL_CACHE:
        if model_id.startswith("facebook/mbart"):
            tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
            model = MBartForConditionalGeneration.from_pretrained(model_id)
        elif model_id == "alirezamsh/small100":
            tokenizer = SMALL100Tokenizer.from_pretrained(model_id, tgt_lang=target_lang)
            model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
        else:
            tokenizer = MarianTokenizer.from_pretrained(model_id)
            model = MarianMTModel.from_pretrained(model_id)
        model.to("cpu")  # inference is CPU-only in this service
        MODEL_CACHE[cache_key] = (tokenizer, model)
    return MODEL_CACHE[cache_key]
@app.post("/translate")
async def translate(request: Request):
    """
    Translate English text into a supported target language.

    Request JSON:  {"text": "...", "target_lang": "<ISO code from MODEL_MAP>"}
    Response JSON: {"translation": "..."} on success, {"error": "..."} otherwise.

    Note: the route decorator was missing, so this endpoint was never
    registered with the FastAPI app.
    """
    payload = await request.json()
    text = payload.get("text")
    target_lang = payload.get("target_lang")
    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}

    try:
        # Split long inputs into chunks the model can handle.
        safe_limit = get_max_word_length([target_lang])
        chunks = chunk_text(text, safe_limit)
        tokenizer, model = load_model(model_id, target_lang)

        generate_kwargs = {"num_beams": 5, "length_penalty": 1.2, "early_stopping": True}
        if model_id.startswith("facebook/mbart"):
            # mBART-50 many-to-many must be told both the source and target
            # language; without forced_bos_token_id it may decode into an
            # arbitrary language.
            mbart_codes = {"hr": "hr_HR", "lv": "lv_LV", "nl": "nl_XX",
                           "pt": "pt_XX", "ro": "ro_RO"}
            tokenizer.src_lang = "en_XX"
            code = mbart_codes.get(target_lang)
            if code is not None:
                generate_kwargs["forced_bos_token_id"] = tokenizer.lang_code_to_id[code]

        full_translation = []
        for chunk in chunks:
            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}
            with torch.no_grad():  # inference only — skip autograd bookkeeping
                outputs = model.generate(**inputs, **generate_kwargs)
            full_translation.append(tokenizer.decode(outputs[0], skip_special_tokens=True))
        return {"translation": " ".join(full_translation)}
    except Exception as e:
        # Report the failure to the client instead of surfacing a bare 500.
        return {"error": f"Translation failed: {e}"}
@app.get("/languages")
def list_languages():
    """Return the supported target-language codes (keys of MODEL_MAP).

    The route decorator was missing, so this endpoint was never registered.
    """
    return {"supported_languages": list(MODEL_MAP.keys())}
@app.get("/health")
def health():
    """Liveness probe: confirms the service process is up.

    The route decorator was missing, so this endpoint was never registered.
    """
    return {"status": "ok"}
# Uvicorn startup for local testing; in deployment the server is expected to
# be launched externally (e.g. `uvicorn app:app`).
if __name__ == "__main__":
    import uvicorn
    # "app:app" assumes this file is named app.py — TODO confirm the filename.
    uvicorn.run("app:app", host="0.0.0.0", port=7860)