Framby committed on
Commit
6c6ac72
·
1 Parent(s): 8e55491
Files changed (2) hide show
  1. app.py +10 -29
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
@@ -9,17 +11,23 @@ import torch
9
  from transformers import AutoTokenizer
10
  import joblib
11
  from model import MultiLabelDeberta
 
12
 
13
  # ========== Loading model and data ==========
14
  st.set_page_config(page_title="Tag Predictor", layout="wide")
15
 
 
 
 
 
 
16
 
17
  @st.cache_resource
18
  def load_model_and_tokenizer():
19
- mlb = joblib.load("mlb.pkl")
20
  model = MultiLabelDeberta(num_labels=len(mlb.classes_))
21
  model.load_state_dict(torch.load(
22
- "deberta_multilabel.pt", map_location="cpu"))
23
  model.eval()
24
  tokenizer = AutoTokenizer.from_pretrained(
25
  "microsoft/deberta-v3-base", use_fast=False)
@@ -28,33 +36,6 @@ def load_model_and_tokenizer():
28
 
29
  model, tokenizer, mlb = load_model_and_tokenizer()
30
 
31
- import os
32
- import requests
33
-
34
- def download_from_gdrive(file_id, dest_path):
35
- URL = "https://drive.google.com/uc?export=download"
36
- session = requests.Session()
37
- response = session.get(URL, params={'id': file_id}, stream=True)
38
- token = None
39
- for key, value in response.cookies.items():
40
- if key.startswith('download_warning'):
41
- token = value
42
- if token:
43
- params = {'id': file_id, 'confirm': token}
44
- response = session.get(URL, params=params, stream=True)
45
- with open(dest_path, "wb") as f:
46
- for chunk in response.iter_content(32768):
47
- if chunk:
48
- f.write(chunk)
49
-
50
-
51
- if not os.path.exists("deberta_multilabel.pt"):
52
- download_from_gdrive("1XE_nJwFJwdZj2-I4gH6kAfGuOBczlRzf", "deberta_multilabel.pt")
53
-
54
-
55
- if not os.path.exists("mlb.pkl"):
56
- download_from_gdrive("1M2_AVSu9VxAR9NJg75x3UHxiw-2laNCh", "mlb.pkl")
57
-
58
  # ========== data loading ==========
59
 
60
 
 
1
+ import requests
2
+ import os
3
  import streamlit as st
4
  import pandas as pd
5
  import numpy as np
 
11
  from transformers import AutoTokenizer
12
  import joblib
13
  from model import MultiLabelDeberta
14
+ from huggingface_hub import hf_hub_download
15
 
16
  # ========== Loading model and data ==========
17
  st.set_page_config(page_title="Tag Predictor", layout="wide")
18
 
19
+ REPO_ID = "Framby/deberta_multilabel"
20
+ deberta_path = hf_hub_download(
21
+ repo_id=REPO_ID, filename="deberta_multilabel.pt")
22
+ mlb_path = hf_hub_download(repo_id=REPO_ID, filename="mlb.pkl")
23
+
24
 
25
  @st.cache_resource
26
  def load_model_and_tokenizer():
27
+ mlb = joblib.load(mlb_path)
28
  model = MultiLabelDeberta(num_labels=len(mlb.classes_))
29
  model.load_state_dict(torch.load(
30
+ deberta_path, map_location="cpu", weights_only=False))
31
  model.eval()
32
  tokenizer = AutoTokenizer.from_pretrained(
33
  "microsoft/deberta-v3-base", use_fast=False)
 
36
 
37
  model, tokenizer, mlb = load_model_and_tokenizer()
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  # ========== data loading ==========
40
 
41
 
requirements.txt CHANGED
@@ -2,6 +2,7 @@
2
  pandas>=1.3.0
3
  numpy>=1.21.0
4
  requests>=2.31.0
 
5
 
6
  # === Visualization ===
7
  matplotlib>=3.5.0
 
2
  pandas>=1.3.0
3
  numpy>=1.21.0
4
  requests>=2.31.0
5
+ huggingface_hub
6
 
7
  # === Visualization ===
8
  matplotlib>=3.5.0