File size: 2,138 Bytes
3f9181a
 
 
05e7f60
 
3f9181a
05e7f60
3f9181a
05e7f60
3f9181a
 
05e7f60
3f9181a
 
 
 
 
 
 
 
05e7f60
3f9181a
 
 
05e7f60
3f9181a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05e7f60
3f9181a
 
 
 
05e7f60
3f9181a
 
05e7f60
3f9181a
 
05e7f60
3f9181a
 
 
 
 
 
05e7f60
3f9181a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# app.py
# Flask microservice exposing an English→Arabic translation endpoint
# backed by a Hugging Face seq2seq model.
from flask import Flask, request, jsonify
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import os

app = Flask(__name__)

# Hugging Face model id; override with the MODEL_ID environment variable.
MODEL_ID = os.environ.get("MODEL_ID", "Helsinki-NLP/opus-mt-en-ar")

# Lazy load on first request (avoid heavy imports on cold boot if you prefer)
translator = None  # module-level cache; populated once by get_translator()

def get_translator():
    """Return the cached translation pipeline, building it on first call.

    The tokenizer and model are loaded explicitly (rather than letting
    ``pipeline`` do it) so device placement or extra kwargs can be
    controlled here if needed. The constructed pipeline is cached in the
    module-level ``translator`` so the model loads only once per process.

    Returns:
        transformers.TranslationPipeline: ready-to-call translator.
    """
    global translator
    if translator is None:
        # Load tokenizer + model explicitly to control device/kwargs if needed
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
        # FIX: the recognized kwargs are src_lang/tgt_lang — the previous
        # src=/tgt= names are not pipeline parameters and end up forwarded
        # to generation, failing at call time. (For a fixed-pair Marian
        # model such as opus-mt-en-ar they are optional either way.)
        translator = pipeline(
            "translation",
            model=model,
            tokenizer=tokenizer,
            src_lang="en",
            tgt_lang="ar",
        )
    return translator

@app.route("/health", methods=["GET"])
def health():
    """Liveness probe: report that the service is up."""
    body = {"status": "ok"}
    return jsonify(body), 200

@app.route("/translate", methods=["POST"])
def translate():
    """
    Accepts JSON:
    {
      "texts": ["Hello", "How are you?"],   # or a single string as "text"
      "max_length": 256,                    # optional
      "batch_size": 8                       # optional
    }
    Returns:
    {
      "translations": ["مرحبا", "كيف حالك؟"]
    }
    """
    # FIX: with force=True alone, malformed JSON raises (generic 400 HTML
    # page) instead of returning None, so the explicit error branch below
    # was dead code. silent=True makes get_json return None on a bad body.
    payload = request.get_json(force=True, silent=True)
    if payload is None:
        return jsonify({"error": "invalid json"}), 400

    # Allow a single string under "text" or a list under "texts".
    # FIX: use explicit None checks — the old `get("texts") or get("text")`
    # silently discarded a falsy-but-present value such as [] or "".
    texts = payload.get("texts")
    if texts is None:
        texts = payload.get("text")
    if texts is None:
        return jsonify({"error": "provide 'text' or 'texts' in JSON"}), 400

    if isinstance(texts, str):
        texts = [texts]

    # Validate shape before handing off to the pipeline so the client gets
    # a clear 400 instead of an opaque 500 from the model code.
    if not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
        return jsonify({"error": "'texts' must be a string or a list of strings"}), 400
    if not texts:
        # Nothing to translate — answer without touching the model.
        return jsonify({"translations": []}), 200

    max_length = payload.get("max_length", 256)
    batch_size = payload.get("batch_size", 8)

    pipe = get_translator()
    # pipeline supports batched translation
    translated = pipe(texts, max_length=max_length, batch_size=batch_size)
    # pipeline returns list of dicts like {"translation_text": "..."}
    out = [t["translation_text"] for t in translated]
    return jsonify({"translations": out}), 200

if __name__ == "__main__":
    # Local debugging entry point only; in production (e.g. Spaces) the
    # app is served by gunicorn as configured in the Dockerfile.
    port = int(os.environ.get("PORT", 8080))
    app.run(host="0.0.0.0", port=port)