R-TA commited on
Commit
17063ae
·
verified ·
1 Parent(s): a232b1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -28
app.py CHANGED
@@ -1,18 +1,26 @@
1
  from flask import Flask, request, jsonify
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- import torch
4
  import os
5
 
6
  app = Flask(__name__)
7
- model = None
8
- tokenizer = None
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- def load_model():
12
- global model, tokenizer
13
- model_name = "facebook/nllb-200-distilled-600M"
14
- tokenizer = AutoTokenizer.from_pretrained(model_name)
15
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  @app.route("/")
18
  def home():
@@ -21,24 +29,12 @@ def home():
21
  @app.route("/translate", methods=["POST"])
22
  def translate():
23
  data = request.get_json()
24
- text = data.get("text", "")
25
- source_lang = data.get("source_lang", "eng_Latn")
26
- target_lang = data.get("target_lang", "urd_Arab")
27
-
28
- inputs = tokenizer(text, return_tensors="pt").to(device)
29
- outputs = model.generate(
30
- **inputs,
31
- forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
32
- max_length=512
33
  )
34
- translation = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
35
-
36
- return jsonify({
37
- "translation": translation,
38
- "source_lang": source_lang,
39
- "target_lang": target_lang
40
- })
41
 
42
  if __name__ == "__main__":
43
- load_model()
44
- app.run(host="0.0.0.0", port=7860, debug=True)
 
1
  from flask import Flask, request, jsonify
2
+ import requests
 
3
  import os
4
 
5
  app = Flask(__name__)
6
+ HF_SPACE_URL = "https://R-TA-NLLB-200-AI.hf.space"
 
 
7
 
8
+ def translate_text(text, source_lang="eng_Latn", target_lang="urd_Arab"):
9
+ try:
10
+ response = requests.post(
11
+ f"{HF_SPACE_URL}/translate",
12
+ json={
13
+ "inputs": {
14
+ "text": text,
15
+ "source_lang": source_lang,
16
+ "target_lang": target_lang
17
+ }
18
+ },
19
+ timeout=30
20
+ )
21
+ return response.json().get("translation", "")
22
+ except Exception as e:
23
+ return str(e)
24
 
25
  @app.route("/")
26
  def home():
 
29
  @app.route("/translate", methods=["POST"])
30
  def translate():
31
  data = request.get_json()
32
+ translation = translate_text(
33
+ data.get("text", ""),
34
+ data.get("source_lang", "eng_Latn"),
35
+ data.get("target_lang", "urd_Arab")
 
 
 
 
 
36
  )
37
+ return jsonify({"translation": translation})
 
 
 
 
 
 
38
 
39
  if __name__ == "__main__":
40
+ app.run(host="0.0.0.0", port=8080)