nmt-translator / app.py
Soha368's picture
NMT Translator - English to Spanish
a888ccc
#!/usr/bin/env python3
"""
NMT Translator — Hugging Face Spaces edition.
Self-contained Flask app that loads JoeyNMT models and serves translations.
"""
import os
import sys
import re
import subprocess
import tempfile
from pathlib import Path
import yaml
from flask import Flask, render_template, request, jsonify
# Flask application object; the route decorators further down register on it.
app = Flask(__name__)
# Directory containing this file. models/ and joeynmt_wrapper.py are resolved
# relative to it, so the app does not depend on the current working directory.
BASE_DIR = Path(__file__).resolve().parent
# Human-readable names for the language codes found in model configs.
# Codes missing from this table are displayed as the raw code.
LANGUAGE_NAMES = {
    "en": "English",
    "es": "Spanish",
    "ru": "Russian",
    "bn": "Bangla",
    "zh": "Chinese",
}

# Flag emoji per language code, spelled as pairs of regional-indicator
# code points. Codes missing from this table get an empty string.
LANGUAGE_FLAGS = {
    "en": "\U0001f1ec\U0001f1e7",  # GB
    "es": "\U0001f1ea\U0001f1f8",  # ES
    "ru": "\U0001f1f7\U0001f1fa",  # RU
    "bn": "\U0001f1e7\U0001f1e9",  # BD
    "zh": "\U0001f1e8\U0001f1f3",  # CN
}

# Filled in by _discover_models(): maps "<src>-<trg>" to a dict of model info.
MODEL_REGISTRY = {}
def _config_lang(entry):
    """Extract a language code from a ``data.src`` / ``data.trg`` config entry.

    Accepts the nested form (``{"lang": "en", ...}``) and, defensively, a bare
    string (``"en"``). Anything else, including a missing or null ``lang``,
    yields the placeholder ``"?"``.
    """
    if isinstance(entry, dict):
        entry = entry.get("lang")
    if isinstance(entry, str) and entry:
        return entry
    return "?"


def _discover_models():
    """Scan ``models/`` and register every usable model in MODEL_REGISTRY.

    A subdirectory is usable when it contains all of best.ckpt, config.yaml,
    src_vocab.txt, trg_vocab.txt and bpe.codes. The language pair is read from
    ``data.src`` / ``data.trg`` in config.yaml and the model is registered
    under the key ``"<src>-<trg>"``.

    This runs at import time, so a single unreadable or malformed config must
    not abort start-up: such directories are skipped with a message instead.
    Progress is reported on stdout; nothing is returned.
    """
    models_root = BASE_DIR / "models"
    if not models_root.is_dir():
        print("WARNING: models/ directory not found")
        return

    required = ("best.ckpt", "config.yaml", "src_vocab.txt",
                "trg_vocab.txt", "bpe.codes")

    # Sorted so that registration order (and any overwrite) is deterministic.
    for subdir in sorted(models_root.iterdir()):
        if not subdir.is_dir():
            continue

        missing = [name for name in required if not (subdir / name).exists()]
        if missing:
            print(f" Skipping {subdir.name}: missing {missing}")
            continue

        config = subdir / "config.yaml"
        try:
            with open(config, "r", encoding="utf-8") as f:
                cfg = yaml.safe_load(f)
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
            print(f" Skipping {subdir.name}: cannot read config.yaml ({exc})")
            continue

        # safe_load returns None for an empty file and may return a list or
        # scalar for unexpected content; only a mapping is usable here.
        if not isinstance(cfg, dict):
            print(f" Skipping {subdir.name}: config.yaml is not a mapping")
            continue
        data_cfg = cfg.get("data")
        if not isinstance(data_cfg, dict):
            data_cfg = {}

        src_lang = _config_lang(data_cfg.get("src"))
        trg_lang = _config_lang(data_cfg.get("trg"))
        pair_key = f"{src_lang}-{trg_lang}"

        if pair_key in MODEL_REGISTRY:
            # Later directories win, as before; just make the overwrite visible.
            print(f" WARNING: {subdir.name} replaces an earlier [{pair_key}] model")

        src_name = LANGUAGE_NAMES.get(src_lang, src_lang)
        trg_name = LANGUAGE_NAMES.get(trg_lang, trg_lang)
        MODEL_REGISTRY[pair_key] = {
            "src_lang": src_lang,
            "trg_lang": trg_lang,
            "src_name": src_name,
            "trg_name": trg_name,
            "src_flag": LANGUAGE_FLAGS.get(src_lang, ""),
            "trg_flag": LANGUAGE_FLAGS.get(trg_lang, ""),
            "model_dir": str(subdir),
            "config_path": str(config),
        }
        print(f" [{pair_key}] {src_name} -> "
              f"{trg_name}")
# Line prefixes treated as console chatter rather than translations.
# NOTE(review): "The " and ">" also match genuine output (e.g. any English
# translation starting with "The "); such lines are silently dropped. The
# wrapper's output format is not visible from here, so the list is kept as-is;
# consider having the wrapper log to stderr only.
_NOISE_PREFIXES = (
    ">", "JoeyNMT", "Loading", "The ", "Use cuda",
    "WARNING", "INFO", "DEBUG",
)
_SPACE_BEFORE_CLOSER = re.compile(r'\s+([.!?;:,"\'\)\]])')
_SPACE_AFTER_OPENER = re.compile(r'([¿¡\(\["])\s+')


def _tidy_punctuation(line):
    """Remove tokenizer spacing around punctuation within one output line."""
    line = _SPACE_BEFORE_CLOSER.sub(r'\1', line)
    return _SPACE_AFTER_OPENER.sub(r'\1', line)


def _translate_text(pair_key, text):
    """Translate *text* with the model registered under *pair_key*.

    Every non-blank input line is written to the stdin of
    ``joeynmt_wrapper.py translate <config>``, which is run as a subprocess in
    the model's directory with a 180-second timeout. Lines of its stdout that
    do not look like log chatter are taken as the translations.

    Args:
        pair_key: Key into MODEL_REGISTRY, e.g. ``"en-es"``.
        text: Source text; may contain several newline-separated lines.

    Returns:
        ``(translation, None)`` on success or ``(None, error_message)`` on any
        failure. This function does not raise.
    """
    if pair_key not in MODEL_REGISTRY:
        return None, f"No model for {pair_key}"
    info = MODEL_REGISTRY[pair_key]
    model_dir = info["model_dir"]
    wrapper = BASE_DIR / "joeynmt_wrapper.py"

    lines = [l.strip() for l in text.strip().split("\n") if l.strip()]
    if not lines:
        return None, "No input text"

    try:
        env = os.environ.copy()
        env["PYTHONIOENCODING"] = "utf-8"
        env["PYTHONUNBUFFERED"] = "1"
        result = subprocess.run(
            [sys.executable, str(wrapper), "translate", info["config_path"]],
            input="\n".join(lines) + "\n",
            capture_output=True, text=True,
            # The child is forced to UTF-8 above; decode/encode with the same
            # codec here instead of whatever the parent's locale happens to be.
            encoding="utf-8",
            env=env, cwd=model_dir, timeout=180,
        )
        if result.returncode != 0:
            return None, f"JoeyNMT error: {result.stderr.strip()[-500:]}"

        translations = []
        for line in result.stdout.strip().split("\n"):
            line = line.strip()
            if not line or line.startswith(_NOISE_PREFIXES):
                continue
            # Tidy each line separately: the patterns use \s, which would
            # otherwise swallow the newline between two translations whenever
            # one ends with an opener/quote or the next starts with punctuation.
            translations.append(_tidy_punctuation(line))

        if not translations:
            # The tail of stderr carries the most recent (usually decisive) message.
            return None, f"No output. stderr: {result.stderr.strip()[-500:]}"
        return "\n".join(translations).strip(), None
    except subprocess.TimeoutExpired:
        return None, "Translation timed out"
    except Exception as e:
        # Boundary handler: callers expect an (value, error) tuple, never a raise.
        return None, str(e)
@app.route("/")
def index():
    """Render index.html with one entry per registered model, sorted by pair key."""
    shared_fields = (
        "src_lang", "trg_lang",
        "src_name", "trg_name",
        "src_flag", "trg_flag",
    )

    def describe(key, info):
        # Template-facing view of a registry entry: the key, the display
        # fields copied verbatim, and a ready-made "Source → Target" label.
        entry = {"key": key}
        for field in shared_fields:
            entry[field] = info[field]
        entry["label"] = f"{info['src_name']} \u2192 {info['trg_name']}"
        return entry

    pairs = [describe(key, MODEL_REGISTRY[key]) for key in sorted(MODEL_REGISTRY)]
    return render_template("index.html", pairs=pairs)
@app.route("/translate", methods=["POST"])
def translate_endpoint():
    """Translate the posted text and return the result as JSON.

    Expects a JSON object body with ``pair`` (a MODEL_REGISTRY key) and
    ``text`` (the source text). The body is parsed as JSON regardless of the
    request's Content-Type.

    Returns:
        200 with ``translation``, ``pair``, ``src_name`` and ``trg_name``;
        400 with ``error`` when the body is not a JSON object, the text is
        missing/blank/not a string, or the pair is unknown;
        500 with ``error`` when the translation itself fails.
    """
    # silent=True: malformed JSON yields None so we can answer with a JSON
    # error like every other failure path instead of an HTML error page.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400

    pair_key = data.get("pair", "")
    text = data.get("text", "")
    # Non-string values (null, numbers, lists) are client errors, not crashes.
    if not isinstance(text, str) or not text.strip():
        return jsonify({"error": "Please enter some text."}), 400
    text = text.strip()

    # The isinstance check also keeps unhashable values away from the dict lookup.
    if not isinstance(pair_key, str) or pair_key not in MODEL_REGISTRY:
        return jsonify({"error": f"Unknown pair: {pair_key}"}), 400

    translation, error = _translate_text(pair_key, text)
    if error:
        return jsonify({"error": error}), 500

    model = MODEL_REGISTRY[pair_key]
    return jsonify({
        "translation": translation,
        "pair": pair_key,
        "src_name": model["src_name"],
        "trg_name": model["trg_name"],
    })
# Model discovery sits outside the __main__ guard, so the registry is filled
# whenever this module is loaded, not only when it is run as a script.
print(
    "\n" + "=" * 50,
    " NMT Translator — discovering models...",
    "=" * 50,
    sep="\n",
)
_discover_models()
print(f"\n {len(MODEL_REGISTRY)} model(s) ready.\n")

if __name__ == "__main__":
    # Run directly: listen on all interfaces; PORT overrides the default 7860.
    port = int(os.getenv("PORT", 7860))
    app.run(host="0.0.0.0", port=port)