Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| NMT Translator — Hugging Face Spaces edition. | |
| Self-contained Flask app that loads JoeyNMT models and serves translations. | |
| """ | |
| import os | |
| import sys | |
| import re | |
| import subprocess | |
| import tempfile | |
| from pathlib import Path | |
| import yaml | |
| from flask import Flask, render_template, request, jsonify | |
# Directory holding this file; the models/ tree and the JoeyNMT wrapper script
# are located relative to it.
BASE_DIR = Path(__file__).resolve().parents[0]
app = Flask(__name__)
# Human-readable names for the language codes that model configs may declare.
LANGUAGE_NAMES = dict(
    en="English",
    es="Spanish",
    ru="Russian",
    bn="Bangla",
    zh="Chinese",
)


def _flag_emoji(country_code):
    """Spell a two-letter uppercase country code as a regional-indicator flag."""
    base = 0x1F1E6  # REGIONAL INDICATOR SYMBOL LETTER A
    return "".join(chr(base + ord(letter) - ord("A")) for letter in country_code)


# Flag shown beside each language. The country picked for a language is a
# display choice (GB for English, CN for Chinese, ...).
LANGUAGE_FLAGS = {
    code: _flag_emoji(country)
    for code, country in (
        ("en", "GB"),
        ("es", "ES"),
        ("ru", "RU"),
        ("bn", "BD"),
        ("zh", "CN"),
    )
}

# Filled by _discover_models(): "src-trg" pair key -> model metadata dict.
MODEL_REGISTRY = {}
# Files every model directory must contain, in the order they are reported
# when missing.
_REQUIRED_MODEL_FILES = (
    "best.ckpt", "config.yaml", "src_vocab.txt", "trg_vocab.txt", "bpe.codes",
)


def _config_lang(cfg, side):
    """Return ``cfg["data"][side]["lang"]``, or "?" when it is absent.

    Any level that is missing or is not a mapping (an empty YAML file loads
    as ``None``; ``data:`` with no value loads as ``None``) is treated as
    absent instead of raising ``AttributeError``.
    """
    data = cfg.get("data") if isinstance(cfg, dict) else None
    section = data.get(side) if isinstance(data, dict) else None
    lang = section.get("lang") if isinstance(section, dict) else None
    return "?" if lang is None else str(lang)


def _discover_models():
    """Populate MODEL_REGISTRY from the subdirectories of ``BASE_DIR/models``.

    A subdirectory is registered when it contains every file named in
    ``_REQUIRED_MODEL_FILES`` and its ``config.yaml`` can be read and parsed.
    The language pair comes from ``data.src.lang`` / ``data.trg.lang`` in that
    config ("?" when absent) and the entry is stored under ``"<src>-<trg>"``.

    Directories are visited in sorted name order; when two declare the same
    pair, the later one replaces the earlier and a warning is printed.
    Skips and registrations are reported with ``print``; nothing is returned.
    """
    models_root = BASE_DIR / "models"
    if not models_root.is_dir():
        print("WARNING: models/ directory not found")
        return
    for subdir in sorted(models_root.iterdir()):
        if not subdir.is_dir():
            continue
        missing = [name for name in _REQUIRED_MODEL_FILES
                   if not (subdir / name).exists()]
        if missing:
            print(f" Skipping {subdir.name}: missing {missing}")
            continue
        config = subdir / "config.yaml"
        try:
            with open(config, "r", encoding="utf-8") as f:
                cfg = yaml.safe_load(f)
        except (OSError, UnicodeDecodeError, yaml.YAMLError) as exc:
            # This runs at import time: one broken config must not take down
            # startup for every other model.
            print(f" Skipping {subdir.name}: unreadable config.yaml ({exc})")
            continue
        src_lang = _config_lang(cfg, "src")
        trg_lang = _config_lang(cfg, "trg")
        pair_key = f"{src_lang}-{trg_lang}"
        src_name = LANGUAGE_NAMES.get(src_lang, src_lang)
        trg_name = LANGUAGE_NAMES.get(trg_lang, trg_lang)
        if pair_key in MODEL_REGISTRY:
            print(f" WARNING: {subdir.name} replaces the earlier model for {pair_key}")
        MODEL_REGISTRY[pair_key] = {
            "src_lang": src_lang,
            "trg_lang": trg_lang,
            "src_name": src_name,
            "trg_name": trg_name,
            "src_flag": LANGUAGE_FLAGS.get(src_lang, ""),
            "trg_flag": LANGUAGE_FLAGS.get(trg_lang, ""),
            "model_dir": str(subdir),
            "config_path": str(config),
        }
        print(f" [{pair_key}] {src_name} -> {trg_name}")
# Prefixes of diagnostic lines that may show up on the translator's stdout
# next to the hypotheses.
_NOISE_PREFIXES = (
    ">", "JoeyNMT", "Loading", "The ", "Use cuda",
    "WARNING", "INFO", "DEBUG",
)
# Tokenized output has spaces before closing and after opening punctuation.
_SPACE_BEFORE_CLOSER = re.compile(r'\s+([.!?;:,"\'\)\]])')
_SPACE_AFTER_OPENER = re.compile(r'([¿¡\(\["])\s+')


def _extract_translations(stdout, expected):
    """Return the hypothesis lines found in the translator's ``stdout``.

    Blank lines are ignored. A line starting with one of ``_NOISE_PREFIXES``
    is discarded only while stdout still holds more lines than the
    ``expected`` number of hypotheses. Filtering unconditionally would also
    drop genuine translations such as "The cat sleeps." for English targets.
    """
    candidates = [ln.strip() for ln in stdout.splitlines() if ln.strip()]
    surplus = len(candidates) - expected
    kept = []
    for ln in candidates:
        if surplus > 0 and ln.startswith(_NOISE_PREFIXES):
            surplus -= 1
            continue
        kept.append(ln)
    return kept


def _tidy_punctuation(line):
    """Remove tokenizer spacing around punctuation within a single line."""
    line = _SPACE_BEFORE_CLOSER.sub(r'\1', line)
    return _SPACE_AFTER_OPENER.sub(r'\1', line)


def _translate_text(pair_key, text, *, timeout=180):
    """Translate ``text`` with the model registered under ``pair_key``.

    Non-blank input lines are stripped and piped to ``joeynmt_wrapper.py
    translate <config>`` in a child Python process whose working directory is
    the model directory. A fresh process is spawned per call (it presumably
    reloads the checkpoint each time).

    Args:
        pair_key: Key into MODEL_REGISTRY, e.g. the "src-trg" form.
        text: Source text; one sentence per line.
        timeout: Seconds to wait for the child process before giving up.

    Returns:
        ``(translation, None)`` on success, where ``translation`` has one
        line per hypothesis, or ``(None, error_message)`` on any failure.
        This function reports errors through the tuple instead of raising.
    """
    if pair_key not in MODEL_REGISTRY:
        return None, f"No model for {pair_key}"
    info = MODEL_REGISTRY[pair_key]
    wrapper = BASE_DIR / "joeynmt_wrapper.py"
    lines = [ln.strip() for ln in text.strip().split("\n") if ln.strip()]
    if not lines:
        return None, "No input text"
    env = {**os.environ, "PYTHONIOENCODING": "utf-8", "PYTHONUNBUFFERED": "1"}
    try:
        result = subprocess.run(
            [sys.executable, str(wrapper), "translate", info["config_path"]],
            input="\n".join(lines) + "\n",
            capture_output=True,
            # Encode/decode the pipes as UTF-8 to match PYTHONIOENCODING in
            # the child, regardless of this host's locale encoding.
            encoding="utf-8",
            errors="replace",
            env=env,
            cwd=info["model_dir"],
            timeout=timeout,
        )
    except subprocess.TimeoutExpired:
        return None, "Translation timed out"
    except Exception as e:  # boundary: report any launch failure as an error
        return None, str(e)
    if result.returncode != 0:
        return None, f"JoeyNMT error: {result.stderr.strip()[-500:]}"
    translations = _extract_translations(result.stdout, len(lines))
    if not translations:
        # Tail of stderr, like the error above: the cause is usually last.
        return None, f"No output. stderr: {result.stderr.strip()[-500:]}"
    # Tidy each line on its own: on the joined text, \s+ would also match the
    # newline between two translations and merge them.
    return "\n".join(_tidy_punctuation(t) for t in translations).strip(), None
def index():
    """Render ``index.html`` with one option per registered language pair.

    Each option carries the pair key, both language codes, display names and
    flags, plus an "A → B" label. Options are ordered by pair key.

    NOTE(review): no ``@app.route`` decorator is visible on this function in
    this view of the file; confirm it is registered as a Flask route.
    """
    copied_fields = ("src_lang", "trg_lang", "src_name", "trg_name",
                     "src_flag", "trg_flag")
    options = []
    for pair_key in sorted(MODEL_REGISTRY):
        meta = MODEL_REGISTRY[pair_key]
        option = {"key": pair_key}
        for field in copied_fields:
            option[field] = meta[field]
        option["label"] = f"{meta['src_name']} \u2192 {meta['trg_name']}"
        options.append(option)
    return render_template("index.html", pairs=options)
def translate_endpoint():
    """Translate the text in a JSON request and answer with JSON.

    Expects a JSON object with ``"pair"`` (a MODEL_REGISTRY key) and
    ``"text"``. On success returns ``{"translation", "pair", "src_name",
    "trg_name"}``. Bad input returns ``{"error": ...}`` with status 400, and
    a translation failure returns it with status 500.

    NOTE(review): no ``@app.route`` decorator is visible on this function in
    this view of the file; confirm it is registered as a Flask route.
    """
    # silent=True turns malformed JSON into None instead of Flask's HTML 400
    # page, so every client error below is reported as JSON.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object."}), 400
    pair_key = data.get("pair", "")
    text = data.get("text", "")
    # A non-string "text" (null, number, list) would crash on .strip();
    # treat it as empty input.
    text = text.strip() if isinstance(text, str) else ""
    if not text:
        return jsonify({"error": "Please enter some text."}), 400
    # A non-string "pair" (e.g. a list) is unhashable and would raise on
    # the membership test.
    if not isinstance(pair_key, str) or pair_key not in MODEL_REGISTRY:
        return jsonify({"error": f"Unknown pair: {pair_key}"}), 400
    translation, error = _translate_text(pair_key, text)
    if error:
        return jsonify({"error": error}), 500
    info = MODEL_REGISTRY[pair_key]
    return jsonify({
        "translation": translation,
        "pair": pair_key,
        "src_name": info["src_name"],
        "trg_name": info["trg_name"],
    })
# Runs at import time, so MODEL_REGISTRY is filled before any request is
# served, including when this module is imported rather than run directly.
_BANNER_RULE = "=" * 50
print("\n" + _BANNER_RULE, " NMT Translator — discovering models...",
      _BANNER_RULE, sep="\n")
_discover_models()
print(f"\n {len(MODEL_REGISTRY)} model(s) ready.\n")
if __name__ == "__main__":
    # Serve on all interfaces. The port comes from $PORT, falling back to
    # 7860 (presumably the Spaces default app port).
    app.run(host="0.0.0.0", port=int(os.getenv("PORT", "7860")))