R-TA committed on
Commit
58fb822
·
verified ·
1 Parent(s): 2e22be2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import os

app = Flask(__name__)

# Disable debug mode when deployed (FLASK_ENV=production).
if os.getenv('FLASK_ENV') == 'production':
    app.config['DEBUG'] = False

# NLLB-200 distilled 600M checkpoint: a multilingual seq2seq translation model.
model_name = "facebook/nllb-200-distilled-600M"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prefer GPU when present; the weights are moved to the device once at startup.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
21
@app.route("/")
def home():
    """Health-check endpoint: confirms the translator service is up."""
    payload = {"status": "✅ NLLB Translator is running!"}
    return jsonify(payload)
24
+
25
@app.route("/translate", methods=["POST"])
def translate():
    """Translate JSON-supplied text between two FLORES-200 language codes.

    Expects a JSON body: {"text": ..., "source_lang": ..., "target_lang": ...}
    (e.g. "eng_Latn", "fra_Latn" — NLLB language codes).

    Returns:
        200 with the translation on success,
        400 when parameters are missing or the target language is unknown,
        500 for unexpected failures.
    """
    try:
        # silent=True yields None (instead of raising 400 HTML) on a missing
        # or non-JSON body; fall back to {} so .get() is always safe.
        data = request.get_json(silent=True) or {}
        text = data.get("text")
        source_lang = data.get("source_lang")
        target_lang = data.get("target_lang")

        if not text or not source_lang or not target_lang:
            return jsonify({
                "error": "Please provide 'text', 'source_lang', and 'target_lang'."
            }), 400

        # `tokenizer.lang_code_to_id` was removed from NLLB tokenizers in
        # newer transformers releases; convert_tokens_to_ids works across
        # versions and lets us validate the code instead of raising KeyError.
        target_token_id = tokenizer.convert_tokens_to_ids(target_lang)
        if target_token_id == tokenizer.unk_token_id:
            # An unknown language code is a client error, not a server fault.
            return jsonify({
                "error": f"Unknown target_lang '{target_lang}'."
            }), 400

        tokenizer.src_lang = source_lang
        encoded = tokenizer(text, return_tensors="pt").to(device)

        # inference_mode: skip autograd bookkeeping during generation
        # (less memory, faster; gradients are never needed when serving).
        with torch.inference_mode():
            generated = model.generate(
                **encoded,
                forced_bos_token_id=target_token_id,
                max_length=512
            )
        translation = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]

        return jsonify({
            "input": text,
            "source_lang": source_lang,
            "target_lang": target_lang,
            "translation": translation
        })
    except Exception as e:
        # Top-level boundary handler: surface unexpected failures as JSON 500.
        return jsonify({"error": str(e)}), 500
55
+
56
if __name__ == "__main__":
    # Honor the PORT environment variable (set by most hosting platforms);
    # the default preserves the original hard-coded 8080. Binding to
    # 0.0.0.0 makes the container's published port reachable externally.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 8080)))