Ptul2x5 commited on
Commit
1ec82b8
·
verified ·
1 Parent(s): 663dc0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -9
app.py CHANGED
@@ -15,21 +15,20 @@ print("🔄 Đang tải tokenizer và model từ Hugging Face...")
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, use_fast=False)
16
 
17
  # 🔹 Tải trọng số model (.bin) trực tiếp từ Hugging Face
18
- state_dict = torch.hub.load_state_dict_from_url(
19
- f"https://huggingface.co/{MODEL_REPO}/resolve/main/multitask_model.bin",
20
- map_location=device
21
- )
22
 
23
  model = PhoBERTMultiTask(num_sentiment=3, num_topic=4)
24
- model.load_state_dict(state_dict)
25
  model.to(device)
26
  model.eval()
 
27
  print("✅ Model đã sẵn sàng!")
28
 
29
  # ====== ROUTES ======
30
  @app.route("/", methods=["GET"])
31
  def home():
32
- return render_template('index.html')
33
 
34
  @app.route("/api/health", methods=["GET"])
35
  def health():
@@ -49,7 +48,9 @@ def predict():
49
  return jsonify({"error": "Text quá dài. Vui lòng nhập tối đa 1000 ký tự."}), 400
50
 
51
  # Tokenize
52
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128).to(device)
 
 
53
 
54
  # Inference
55
  with torch.no_grad():
@@ -74,5 +75,5 @@ def predict():
74
 
75
 
76
  if __name__ == "__main__":
77
- port = int(os.environ.get("PORT", 10000))
78
- app.run(host="0.0.0.0", port=port)
 
15
  tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO, use_fast=False)
16
 
17
  # 🔹 Tải trọng số model (.bin) trực tiếp từ Hugging Face
18
+ MODEL_URL = f"https://huggingface.co/{MODEL_REPO}/resolve/main/multitask_model.bin"
19
+ state_dict = torch.hub.load_state_dict_from_url(MODEL_URL, map_location=device)
 
 
20
 
21
  model = PhoBERTMultiTask(num_sentiment=3, num_topic=4)
22
+ model.load_state_dict(state_dict, strict=False)
23
  model.to(device)
24
  model.eval()
25
+
26
  print("✅ Model đã sẵn sàng!")
27
 
28
  # ====== ROUTES ======
29
  @app.route("/", methods=["GET"])
30
  def home():
31
+ return render_template("index.html")
32
 
33
  @app.route("/api/health", methods=["GET"])
34
  def health():
 
48
  return jsonify({"error": "Text quá dài. Vui lòng nhập tối đa 1000 ký tự."}), 400
49
 
50
  # Tokenize
51
+ inputs = tokenizer(
52
+ text, return_tensors="pt", truncation=True, padding=True, max_length=128
53
+ ).to(device)
54
 
55
  # Inference
56
  with torch.no_grad():
 
75
 
76
 
77
  if __name__ == "__main__":
78
+ # Hugging Face luôn yêu cầu port = 7860
79
+ app.run(host="0.0.0.0", port=7860)