Commit: "Update app.py" — file app.py changed
|
@@ -4,13 +4,41 @@ import torch
|
|
| 4 |
|
| 5 |
app = FastAPI()
|
| 6 |
|
|
|
|
| 7 |
MODEL_MAP = {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"fr": "Helsinki-NLP/opus-mt-en-fr",
|
| 9 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
}
|
| 11 |
|
|
|
|
| 12 |
MODEL_CACHE = {}
|
| 13 |
|
|
|
|
| 14 |
def load_model(model_id):
|
| 15 |
if model_id not in MODEL_CACHE:
|
| 16 |
tokenizer = MarianTokenizer.from_pretrained(model_id)
|
|
@@ -18,6 +46,7 @@ def load_model(model_id):
|
|
| 18 |
MODEL_CACHE[model_id] = (tokenizer, model)
|
| 19 |
return MODEL_CACHE[model_id]
|
| 20 |
|
|
|
|
| 21 |
@app.post("/translate")
|
| 22 |
async def translate(request: Request):
|
| 23 |
data = await request.json()
|
|
@@ -25,18 +54,34 @@ async def translate(request: Request):
|
|
| 25 |
target_lang = data.get("target_lang")
|
| 26 |
|
| 27 |
if not text or not target_lang:
|
| 28 |
-
return {"error": "Missing text or target_lang"}
|
| 29 |
|
| 30 |
model_id = MODEL_MAP.get(target_lang)
|
| 31 |
if not model_id:
|
| 32 |
-
return {"error": f"No model for '{target_lang}'"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
return {"
|
| 38 |
|
| 39 |
-
#
|
| 40 |
import uvicorn
|
| 41 |
if __name__ == "__main__":
|
| 42 |
-
uvicorn.run("app:app", host="0.0.0.0", port=7860)
|
|
|
|
app = FastAPI()

# Map target-language ISO codes to Hugging Face model IDs.
# NOTE: the "facebook/*" (mBART / NLLB) entries are currently short-circuited
# by the startswith("facebook/") check in /translate and return a
# "not supported" message instead of a translation.
MODEL_MAP = {
    "bg": "Helsinki-NLP/opus-mt-tc-big-en-bg",
    "cs": "Helsinki-NLP/opus-mt-en-cs",
    "da": "Helsinki-NLP/opus-mt-en-da",
    "de": "Helsinki-NLP/opus-mt-en-de",
    "el": "Helsinki-NLP/opus-mt-tc-big-en-el",
    "es": "facebook/nllb-200-distilled-600M",
    "et": "Helsinki-NLP/opus-mt-tc-big-en-et",
    "fi": "Helsinki-NLP/opus-mt-tc-big-en-fi",
    "fr": "Helsinki-NLP/opus-mt-en-fr",
    "hr": "facebook/mbart-large-50-many-to-many-mmt",
    "hu": "Helsinki-NLP/opus-mt-tc-big-en-hu",
    "is": "facebook/nllb-200-distilled-600M",
    "it": "facebook/nllb-200-distilled-600M",
    "lt": "Helsinki-NLP/opus-mt-tc-big-en-lt",
    "lv": "facebook/mbart-large-50-many-to-many-mmt",
    "mk": "facebook/nllb-200-distilled-600M",
    "nb": "facebook/mbart-large-50-many-to-many-mmt",  # TODO: placeholder — no dedicated en->nb model wired up yet
    "nl": "facebook/mbart-large-50-many-to-many-mmt",
    "no": "facebook/mbart-large-50-many-to-many-mmt",  # TODO: placeholder — no dedicated en->no model wired up yet
    "pl": "facebook/nllb-200-distilled-600M",
    "pt": "facebook/mbart-large-50-many-to-many-mmt",
    "ro": "facebook/mbart-large-50-many-to-many-mmt",
    "sk": "Helsinki-NLP/opus-mt-en-sk",
    "sl": "alirezamsh/small100",
    "sq": "alirezamsh/small100",
    "sv": "Helsinki-NLP/opus-mt-en-sv",
    "tr": "facebook/nllb-200-distilled-600M"
}
|
| 37 |
|
# Process-wide cache: model_id -> (tokenizer, model), filled lazily.
MODEL_CACHE = {}


def load_model(model_id):
    """Return the (tokenizer, model) pair for *model_id*, loading on first use.

    Results are memoized in MODEL_CACHE so each checkpoint is downloaded
    and instantiated only once per process.
    """
    if model_id not in MODEL_CACHE:
        # NOTE(review): MarianTokenizer/MarianMTModel match the
        # Helsinki-NLP/opus-mt-* entries in MODEL_MAP, but
        # "alirezamsh/small100" is not a Marian checkpoint — confirm it
        # loads here, or switch to AutoTokenizer/AutoModelForSeq2SeqLM.
        tokenizer = MarianTokenizer.from_pretrained(model_id)
        # (model line elided in the diff view; reconstructed from the
        # (tokenizer, model) tuple cached below — verify against the file)
        model = MarianMTModel.from_pretrained(model_id)
        MODEL_CACHE[model_id] = (tokenizer, model)
    return MODEL_CACHE[model_id]
|
| 48 |
|
| 49 |
# POST /translate — translate English text into the requested target language.
@app.post("/translate")
async def translate(request: Request):
    """Translate the JSON body {"text": ..., "target_lang": ...}.

    Returns {"translation": ...} on success, or {"error": ...} for bad
    input, an unknown language, or a model failure (always HTTP 200).
    """
    # Guard the parse: a malformed body would otherwise raise out of
    # request.json() and surface as an unhandled 500.
    try:
        data = await request.json()
    except Exception:
        return {"error": "Request body must be valid JSON"}

    text = data.get("text")
    target_lang = data.get("target_lang")

    if not text or not target_lang:
        return {"error": "Missing 'text' or 'target_lang'"}

    model_id = MODEL_MAP.get(target_lang)
    if not model_id:
        return {"error": f"No model found for target language '{target_lang}'"}

    # mBART/NLLB checkpoints need a different loading path than MarianMT,
    # so they are short-circuited with an explanatory message for now.
    if model_id.startswith("facebook/"):
        return {"translation": f"[{target_lang}] uses model '{model_id}', which is not supported in this Space yet."}

    try:
        tokenizer, model = load_model(model_id)
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(model.device)
        # Inference only: no_grad skips autograd bookkeeping, cutting memory.
        with torch.no_grad():
            outputs = model.generate(**inputs, num_beams=5, length_penalty=1.2, early_stopping=True)
        return {"translation": tokenizer.decode(outputs[0], skip_special_tokens=True)}
    except Exception as e:
        return {"error": f"Translation failed: {str(e)}"}
|
| 73 |
+
|
| 74 |
# GET /languages — enumerate every target-language code this API accepts.
@app.get("/languages")
def list_languages():
    """Return the supported target-language codes, in MODEL_MAP order."""
    codes = [lang for lang in MODEL_MAP]
    return {"supported_languages": codes}
|
| 78 |
|
| 79 |
# GET /health — trivial liveness probe.
@app.get("/health")
def health():
    """Report that the service is up."""
    return dict(status="ok")
|
| 83 |
|
# Uvicorn startup (Hugging Face Spaces runs the script directly and
# expects the server on port 7860, bound to all interfaces).
import uvicorn
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860)
|