Correct issue with small100
Browse files
app.py
CHANGED
|
@@ -3,7 +3,9 @@ from transformers import (
|
|
| 3 |
MarianMTModel,
|
| 4 |
MarianTokenizer,
|
| 5 |
MBartForConditionalGeneration,
|
| 6 |
-
MBart50TokenizerFast
|
|
|
|
|
|
|
| 7 |
)
|
| 8 |
import torch
|
| 9 |
|
|
@@ -25,16 +27,16 @@ MODEL_MAP = {
|
|
| 25 |
"fr": "Helsinki-NLP/opus-mt-en-fr",
|
| 26 |
"hr": "facebook/mbart-large-50-many-to-many-mmt",
|
| 27 |
"hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
|
| 28 |
-
"is": "mkorada/opus-mt-en-is-finetuned-v4",
|
| 29 |
"it": "Helsinki-NLP/opus-mt-tc-big-en-it",
|
| 30 |
-
"lb": "alirezamsh/small100",
|
| 31 |
"lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
|
| 32 |
"lv": "facebook/mbart-large-50-many-to-many-mmt",
|
| 33 |
-
"me": "Helsinki-NLP/opus-mt-tc-base-en-sh"
|
| 34 |
"mk": "Helsinki-NLP/opus-mt-en-mk",
|
| 35 |
-
"nb": "facebook/mbart-large-50-many-to-many-mmt",
|
| 36 |
"nl": "facebook/mbart-large-50-many-to-many-mmt",
|
| 37 |
-
"no": "Confused404/eng-gmq-finetuned_v2-no", #Alex's fine-tuned model
|
| 38 |
"pl": "Helsinki-NLP/opus-mt-en-sla",
|
| 39 |
"pt": "facebook/mbart-large-50-many-to-many-mmt",
|
| 40 |
"ro": "facebook/mbart-large-50-many-to-many-mmt",
|
|
@@ -45,29 +47,32 @@ MODEL_MAP = {
|
|
| 45 |
"tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
|
| 46 |
}
|
| 47 |
|
| 48 |
-
|
| 49 |
MODEL_CACHE = {}
|
| 50 |
|
| 51 |
-
# ✅ Load Hugging Face model (Helsinki or Small100)
|
| 52 |
def load_model(model_id: str):
|
| 53 |
"""
|
| 54 |
-
Load & cache
|
| 55 |
-
-
|
| 56 |
-
-
|
|
|
|
| 57 |
"""
|
| 58 |
if model_id not in MODEL_CACHE:
|
| 59 |
if model_id.startswith("facebook/mbart"):
|
| 60 |
tokenizer = MBart50TokenizerFast.from_pretrained(model_id)
|
| 61 |
model = MBartForConditionalGeneration.from_pretrained(model_id)
|
|
|
|
|
|
|
|
|
|
| 62 |
else:
|
| 63 |
tokenizer = MarianTokenizer.from_pretrained(model_id)
|
| 64 |
model = MarianMTModel.from_pretrained(model_id)
|
|
|
|
| 65 |
model.to("cpu")
|
| 66 |
MODEL_CACHE[model_id] = (tokenizer, model)
|
| 67 |
-
return MODEL_CACHE[model_id]
|
| 68 |
|
|
|
|
| 69 |
|
| 70 |
-
# ✅ POST /translate
|
| 71 |
@app.post("/translate")
|
| 72 |
async def translate(request: Request):
|
| 73 |
payload = await request.json()
|
|
@@ -100,18 +105,15 @@ async def translate(request: Request):
|
|
| 100 |
except Exception as e:
|
| 101 |
return {"error": f"Translation failed: {e}"}
|
| 102 |
|
| 103 |
-
|
| 104 |
-
# ✅ GET /languages
|
| 105 |
@app.get("/languages")
|
| 106 |
def list_languages():
|
| 107 |
return {"supported_languages": list(MODEL_MAP.keys())}
|
| 108 |
|
| 109 |
-
# ✅ GET /health
|
| 110 |
@app.get("/health")
|
| 111 |
def health():
|
| 112 |
return {"status": "ok"}
|
| 113 |
|
| 114 |
-
#
|
| 115 |
-
import uvicorn
|
| 116 |
if __name__ == "__main__":
|
| 117 |
-
uvicorn
|
|
|
|
|
|
| 3 |
MarianMTModel,
|
| 4 |
MarianTokenizer,
|
| 5 |
MBartForConditionalGeneration,
|
| 6 |
+
MBart50TokenizerFast,
|
| 7 |
+
AutoTokenizer,
|
| 8 |
+
AutoModelForSeq2SeqLM
|
| 9 |
)
|
| 10 |
import torch
|
| 11 |
|
|
|
|
| 27 |
"fr": "Helsinki-NLP/opus-mt-en-fr",
|
| 28 |
"hr": "facebook/mbart-large-50-many-to-many-mmt",
|
| 29 |
"hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
|
| 30 |
+
"is": "mkorada/opus-mt-en-is-finetuned-v4", # Manas's fine-tuned model
|
| 31 |
"it": "Helsinki-NLP/opus-mt-tc-big-en-it",
|
| 32 |
+
"lb": "alirezamsh/small100", # small100
|
| 33 |
"lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
|
| 34 |
"lv": "facebook/mbart-large-50-many-to-many-mmt",
|
| 35 |
+
"me": "Helsinki-NLP/opus-mt-tc-base-en-sh",
|
| 36 |
"mk": "Helsinki-NLP/opus-mt-en-mk",
|
| 37 |
+
"nb": "facebook/mbart-large-50-many-to-many-mmt",
|
| 38 |
"nl": "facebook/mbart-large-50-many-to-many-mmt",
|
| 39 |
+
"no": "Confused404/eng-gmq-finetuned_v2-no", # Alex's fine-tuned model
|
| 40 |
"pl": "Helsinki-NLP/opus-mt-en-sla",
|
| 41 |
"pt": "facebook/mbart-large-50-many-to-many-mmt",
|
| 42 |
"ro": "facebook/mbart-large-50-many-to-many-mmt",
|
|
|
|
| 47 |
"tr": "Helsinki-NLP/opus-mt-tc-big-en-tr"
|
| 48 |
}
|
| 49 |
|
| 50 |
+
# Cache loaded models/tokenizers
|
| 51 |
MODEL_CACHE = {}
|
| 52 |
|
|
|
|
| 53 |
def load_model(model_id: str):
    """Return the cached (tokenizer, model) pair for *model_id*, loading it on first use.

    Loader selection:
    - facebook/mbart-*    -> MBart50TokenizerFast / MBartForConditionalGeneration
    - alirezamsh/small100 -> AutoTokenizer / AutoModelForSeq2SeqLM
    - everything else     -> MarianTokenizer / MarianMTModel
    """
    cached = MODEL_CACHE.get(model_id)
    if cached is None:
        # Pick the tokenizer/model classes that match this checkpoint family.
        if model_id.startswith("facebook/mbart"):
            tok_cls, mdl_cls = MBart50TokenizerFast, MBartForConditionalGeneration
        elif model_id == "alirezamsh/small100":
            tok_cls, mdl_cls = AutoTokenizer, AutoModelForSeq2SeqLM
        else:
            tok_cls, mdl_cls = MarianTokenizer, MarianMTModel

        tokenizer = tok_cls.from_pretrained(model_id)
        model = mdl_cls.from_pretrained(model_id)

        model.to("cpu")  # inference runs on CPU
        cached = (tokenizer, model)
        MODEL_CACHE[model_id] = cached

    return cached
|
| 75 |
|
|
|
|
| 76 |
@app.post("/translate")
|
| 77 |
async def translate(request: Request):
|
| 78 |
payload = await request.json()
|
|
|
|
| 105 |
except Exception as e:
|
| 106 |
return {"error": f"Translation failed: {e}"}
|
| 107 |
|
|
|
|
|
|
|
| 108 |
@app.get("/languages")
|
| 109 |
def list_languages():
|
| 110 |
return {"supported_languages": list(MODEL_MAP.keys())}
|
| 111 |
|
|
|
|
| 112 |
@app.get("/health")
|
| 113 |
def health():
|
| 114 |
return {"status": "ok"}
|
| 115 |
|
| 116 |
+
# Run a local development server when this module is executed directly.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860)
|