Framby committed on
Commit
ecdeea4
·
1 Parent(s): 90a6700

Adding links for loading models

Browse files
Files changed (2) hide show
  1. app.py +31 -6
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,5 +1,3 @@
1
- # app.py
2
-
3
  import streamlit as st
4
  import pandas as pd
5
  import numpy as np
@@ -12,7 +10,7 @@ from transformers import AutoTokenizer
12
  import joblib
13
  from model import MultiLabelDeberta
14
 
15
- # ========== Загрузка модели и данных ==========
16
  st.set_page_config(page_title="Tag Predictor", layout="wide")
17
 
18
 
@@ -30,7 +28,34 @@ def load_model_and_tokenizer():
30
 
31
  model, tokenizer, mlb = load_model_and_tokenizer()
32
 
33
- # ========== Загрузка данных ==========
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
 
36
  @st.cache_data
@@ -42,7 +67,7 @@ def load_data():
42
 
43
  X, Y = load_data()
44
 
45
- # ========== Функция предсказания ==========
46
 
47
 
48
  def predict_tags(text, threshold=0.5):
@@ -63,7 +88,7 @@ def predict_tags(text, threshold=0.5):
63
  return predicted_tags[0]
64
 
65
 
66
- # ========== Интерфейс ==========
67
  st.title("Prédicteur de Tags StackOverflow")
68
 
69
  st.markdown("## 1. Analyse des données textuelles")
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
 
10
  import joblib
11
  from model import MultiLabelDeberta
12
 
13
+ # ========== Loading model and data ==========
14
  st.set_page_config(page_title="Tag Predictor", layout="wide")
15
 
16
 
 
28
 
29
  model, tokenizer, mlb = load_model_and_tokenizer()
30
 
31
import os
import requests

def download_from_gdrive(file_id, dest_path, timeout=60):
    """Download a file from Google Drive, handling the large-file confirm flow.

    Args:
        file_id: Google Drive file identifier.
        dest_path: Local filesystem path the downloaded bytes are written to.
        timeout: Per-request timeout in seconds (default 60) so a stalled
            connection cannot hang the app indefinitely.

    Raises:
        requests.HTTPError: if Drive answers with an error status — previously
            the HTML error page would be silently saved to dest_path.
    """
    URL = "https://drive.google.com/uc?export=download"
    # Context manager guarantees the session (and its connections) is closed.
    with requests.Session() as session:
        response = session.get(
            URL, params={"id": file_id}, stream=True, timeout=timeout
        )
        # Large files trigger a virus-scan warning page; the confirmation
        # token arrives in a cookie and must be echoed back to get the data.
        token = next(
            (
                value
                for key, value in response.cookies.items()
                if key.startswith("download_warning")
            ),
            None,
        )
        if token:
            response = session.get(
                URL,
                params={"id": file_id, "confirm": token},
                stream=True,
                timeout=timeout,
            )
        response.raise_for_status()  # fail loudly instead of saving an error page
        with open(dest_path, "wb") as f:
            for chunk in response.iter_content(32768):
                if chunk:  # filter out keep-alive chunks
                    f.write(chunk)
49
+
50
+
51
# Fetch each model artifact from Google Drive only if it is missing locally.
_GDRIVE_ARTIFACTS = {
    "deberta_multilabel.pt": "1XE_nJwFJwdZj2-I4gH6kAfGuOBczlRzf",
    "mlb.pkl": "1M2_AVSu9VxAR9NJg75x3UHxiw-2laNCh",
}
for _artifact_path, _gdrive_id in _GDRIVE_ARTIFACTS.items():
    if not os.path.exists(_artifact_path):
        download_from_gdrive(_gdrive_id, _artifact_path)
57
+
58
+ # ========== data loading ==========
59
 
60
 
61
  @st.cache_data
 
67
 
68
  X, Y = load_data()
69
 
70
+ # ========== prediction function ==========
71
 
72
 
73
  def predict_tags(text, threshold=0.5):
 
88
  return predicted_tags[0]
89
 
90
 
91
+ # ========== interface ==========
92
  st.title("Prédicteur de Tags StackOverflow")
93
 
94
  st.markdown("## 1. Analyse des données textuelles")
requirements.txt CHANGED
@@ -1,6 +1,7 @@
1
  # === Core data libraries ===
2
  pandas>=1.3.0
3
  numpy>=1.21.0
 
4
 
5
  # === Visualization ===
6
  matplotlib>=3.5.0
 
1
  # === Core data libraries ===
2
  pandas>=1.3.0
3
  numpy>=1.21.0
4
+ requests>=2.31.0
5
 
6
  # === Visualization ===
7
  matplotlib>=3.5.0