R-TA committed on
Commit
966930e
·
verified ·
1 Parent(s): 284bda6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -41
app.py CHANGED
@@ -4,19 +4,15 @@ import torch
4
  import os
5
 
6
  app = Flask(__name__)
 
 
 
7
 
8
- # Production configuration
9
- if os.environ.get('FLASK_ENV') == 'production':
10
- app.config['DEBUG'] = False
11
-
12
- # Load the NLLB-200 model
13
- model_name = "facebook/nllb-200-distilled-600M"
14
- tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
16
-
17
- # Set device
18
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
- model = model.to(device)
20
 
21
  @app.route("/")
22
  def home():
@@ -24,34 +20,25 @@ def home():
24
 
25
  @app.route("/translate", methods=["POST"])
26
  def translate():
27
- try:
28
- data = request.get_json()
29
- text = data.get("text")
30
- source_lang = data.get("source_lang")
31
- target_lang = data.get("target_lang")
32
-
33
- if not text or not source_lang or not target_lang:
34
- return jsonify({
35
- "error": "Please provide 'text', 'source_lang', and 'target_lang'."
36
- }), 400
37
-
38
- tokenizer.src_lang = source_lang
39
- encoded = tokenizer(text, return_tensors="pt").to(device)
40
- generated = model.generate(
41
- **encoded,
42
- forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
43
- max_length=512
44
- )
45
- translation = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
46
-
47
- return jsonify({
48
- "input": text,
49
- "source_lang": source_lang,
50
- "target_lang": target_lang,
51
- "translation": translation
52
- })
53
- except Exception as e:
54
- return jsonify({"error": str(e)}), 500
55
 
56
  if __name__ == "__main__":
57
- app.run(host="0.0.0.0", port=8080)
 
 
4
  import os
5
 
6
  app = Flask(__name__)
7
# Model/tokenizer are loaded lazily by load_model() so the module can be
# imported (e.g. by a WSGI server) without paying the download cost up front.
model = None
tokenizer = None
# Prefer GPU when available; tokenized inputs are moved to this device too.
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_model():
    """Load the NLLB-200 distilled 600M checkpoint into the module globals.

    Idempotent: repeated calls reuse the already-loaded model/tokenizer
    instead of re-downloading and re-instantiating them.
    """
    global model, tokenizer
    if model is not None and tokenizer is not None:
        return  # already loaded — avoid a second expensive load
    model_name = "facebook/nllb-200-distilled-600M"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
 
 
 
 
 
 
 
16
 
17
  @app.route("/")
18
  def home():
 
20
 
21
  @app.route("/translate", methods=["POST"])
22
@app.route("/translate", methods=["POST"])
def translate():
    """Translate a JSON payload with NLLB-200.

    Expects JSON: {"text": str, "source_lang": str, "target_lang": str}.
    Defaults: source_lang="eng_Latn", target_lang="urd_Arab".
    Returns JSON with "translation", "source_lang", "target_lang";
    400 on missing/empty text, 503 if the model is not loaded,
    500 (as JSON) on unexpected failure.
    """
    # silent=True: an unparsable/absent JSON body yields None instead of a 400
    # HTML error page, so we can answer with a JSON error ourselves.
    data = request.get_json(silent=True) or {}
    text = data.get("text", "")
    source_lang = data.get("source_lang", "eng_Latn")
    target_lang = data.get("target_lang", "urd_Arab")

    if not text:
        return jsonify({"error": "Please provide non-empty 'text'."}), 400
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded yet."}), 503

    try:
        # NLLB tokenizers need the source language set explicitly; without
        # this, source_lang would be accepted but silently ignored.
        tokenizer.src_lang = source_lang
        inputs = tokenizer(text, return_tensors="pt").to(device)
        outputs = model.generate(
            **inputs,
            # convert_tokens_to_ids replaces the deprecated lang_code_to_id
            # mapping removed from recent transformers releases.
            forced_bos_token_id=tokenizer.convert_tokens_to_ids(target_lang),
            max_length=512,
        )
        translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    except Exception as e:  # surface model errors as JSON, not an HTML 500
        return jsonify({"error": str(e)}), 500

    return jsonify({
        "translation": translation,
        "source_lang": source_lang,
        "target_lang": target_lang,
    })
 
 
 
 
 
 
 
 
 
 
41
 
42
if __name__ == "__main__":
    # Load weights once at startup so the first request isn't blocked.
    load_model()
    # debug=True on 0.0.0.0 exposes the Werkzeug interactive debugger
    # (remote code execution). Only enable it when explicitly requested.
    debug = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=7860, debug=debug)