Framby committed on
Commit
8beae03
·
1 Parent(s): 3ec8e71

First commit

Browse files
Files changed (5) hide show
  1. .gitignore +59 -0
  2. Pipfile +11 -0
  3. app.py +113 -0
  4. model.py +16 -0
  5. requirements.txt +31 -0
.gitignore ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # === Python bytecode ===
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # === Jupyter Notebooks checkpoints ===
7
+ .ipynb_checkpoints
8
+
9
+ # === Virtual environment ===
10
+ .venv/
11
+ venv/
12
+ env/
13
+ ENV/
14
+
15
+ # === OS files ===
16
+ .DS_Store
17
+ Thumbs.db
18
+
19
+ # === Streamlit cache ===
20
+ .streamlit/cache/
21
+ .streamlit/config.toml
22
+
23
+ # === PyTorch checkpoints and model files ===
24
+ *.pt
25
+ *.pth
26
+ *.bin
27
+
28
+ # === Tokenizer and transformers cache ===
29
+ .cache/
30
+ transformers_cache/
31
+ huggingface/
32
+
33
+ # === Dataset or outputs ===
34
+ *.csv
35
+ *.tsv
36
+ *.json
37
+ *.xlsx
38
+ *.log
39
+ *.npy
40
+ *.npz
41
+
42
+ # === Model artifacts ===
43
+ *.joblib
44
+ *.pkl
45
+
46
+ # === Environment files ===
47
+ .env
48
+ .env.*
49
+
50
+ # === VSCode / IDE ===
51
+ .vscode/
52
+ .idea/
53
+
54
+ # === Misc ===
55
+ *.zip
56
+ *.tar.gz
57
+ *.egg-info/
58
+ build/
59
+ dist/
Pipfile ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [[source]]
2
+ url = "https://pypi.org/simple"
3
+ verify_ssl = true
4
+ name = "pypi"
5
+
6
+ [packages]
7
+
8
+ [dev-packages]
9
+
10
+ [requires]
11
+ python_version = "3.12"
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+ import plotly.express as px
8
+ from wordcloud import WordCloud
9
+ from collections import Counter
10
+ import torch
11
+ from transformers import AutoTokenizer
12
+ import joblib
13
+ from model import MultiLabelDeberta
14
+
15
# ========== Model and data loading ==========
# Configure the Streamlit page; placed before any other UI calls in this script.
st.set_page_config(page_title="Tag Predictor", layout="wide")
17
+
18
+
19
@st.cache_resource
def load_model_and_tokenizer():
    """Load the label binarizer, the classifier and the tokenizer.

    Cached by ``st.cache_resource`` so the heavy artifacts are loaded only
    once per server process.

    Returns:
        tuple: ``(model, tokenizer, mlb)`` where ``model`` is a
        ``MultiLabelDeberta`` in eval mode on CPU, ``tokenizer`` is the slow
        DeBERTa-v3 tokenizer, and ``mlb`` is the fitted multi-label
        binarizer whose classes define the output size.
    """
    # NOTE(security): joblib.load unpickles arbitrary objects — mlb.pkl must
    # come from a trusted source.
    mlb = joblib.load("mlb.pkl")
    model = MultiLabelDeberta(num_labels=len(mlb.classes_))
    # weights_only=True restricts torch.load to tensors/containers instead of
    # full pickle, preventing code execution from a tampered checkpoint.
    # Safe here because the file is a plain state dict; supported since
    # torch 1.13 and requirements pin torch>=2.0.
    model.load_state_dict(torch.load(
        "deberta_multilabel.pt", map_location="cpu", weights_only=True))
    model.eval()  # inference mode: disables dropout
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base", use_fast=False)
    return model, tokenizer, mlb
29
+
30
+
31
# Loaded once per process (cached by st.cache_resource on the loader).
model, tokenizer, mlb = load_model_and_tokenizer()

# ========== Data loading ==========
34
+
35
+
36
@st.cache_data
def load_data():
    """Load the question texts and their tag lists from local CSV files.

    Returns:
        tuple: ``(X, Y)`` where ``X`` is a pandas Series of cleaned question
        texts (as str) and ``Y`` is a Series of per-question tag collections
        parsed from their string representation in the CSV.
    """
    import ast  # stdlib; parses the stringified tag lists safely

    X = pd.read_csv('X_text.csv')['text_clean'].astype(str)
    # SECURITY FIX: the original used `eval`, which executes arbitrary code
    # embedded in the CSV. ast.literal_eval only parses Python literals
    # (e.g. "['python', 'pandas']") and cannot run code.
    Y = pd.read_csv('Y_tags.csv', converters={'Tags': ast.literal_eval})['Tags']
    return X, Y
41
+
42
+
43
# Loaded once and cached via st.cache_data on the loader.
X, Y = load_data()

# ========== Prediction function ==========
46
+
47
+
48
def predict_tags(text, threshold=0.5):
    """Predict StackOverflow tags for a single question.

    Args:
        text: Raw question text.
        threshold: Minimum sigmoid probability for a tag to be kept.

    Returns:
        tuple: Predicted tag names (possibly empty).
    """
    inputs = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=512,
        padding='max_length'
    )
    # MultiLabelDeberta.forward accepts only input_ids/attention_mask.
    inputs.pop('token_type_ids', None)
    with torch.no_grad():
        logits = model(**inputs)
        # FIX: keep the (1, num_labels) shape. The original's bare
        # .squeeze() would also drop the label axis when num_labels == 1,
        # breaking mlb.inverse_transform, and then required re-adding the
        # batch axis with np.expand_dims.
        probs = torch.sigmoid(logits).cpu().numpy()
    binary_preds = (probs >= threshold).astype(int)
    # inverse_transform expects a 2-D indicator matrix and returns one
    # tuple of tag names per row; return the single row.
    return mlb.inverse_transform(binary_preds)[0]
64
+
65
+
66
# ========== Interface ==========
st.title("Prédicteur de Tags StackOverflow")

st.markdown("## 1. Analyse des données textuelles")

# Two side-by-side exploratory charts over the training texts.
col1, col2 = st.columns(2)

with col1:
    st.markdown("### Distribution de la longueur des questions")
    # Question length measured in whitespace-separated tokens.
    text_lengths = X.apply(lambda x: len(x.split()))
    fig = px.histogram(text_lengths, nbins=30,
                       title="Distribution de la longueur des questions")
    st.plotly_chart(fig, use_container_width=True)

with col2:
    st.markdown("### Mots les plus fréquents")
    # Raw frequency count over all texts (no stop-word filtering here).
    all_words = " ".join(X).split()
    word_freq = Counter(all_words)
    most_common_words = pd.DataFrame(
        word_freq.most_common(20), columns=['Mot', 'Nombre'])
    fig2 = px.bar(most_common_words, x='Mot', y='Nombre',
                  title="20 mots les plus fréquents")
    st.plotly_chart(fig2, use_container_width=True)

st.markdown("### Nuage de mots")
# Word cloud rendered via matplotlib since WordCloud produces an image array.
wc = WordCloud(width=800, height=300,
               background_color='white').generate(" ".join(X))
fig_wc, ax = plt.subplots(figsize=(10, 4))
ax.imshow(wc, interpolation='bilinear')
ax.axis("off")
st.pyplot(fig_wc)

st.markdown("---")
st.markdown("## 2. Prédiction des tags")

input_text = st.text_area("Entrez une question StackOverflow", height=150)
# Slider: min 0.1, max 0.9, default 0.5, step 0.05.
threshold = st.slider("Seuil de probabilité", 0.1, 0.9, 0.5, 0.05)

if st.button("Prédire les tags"):
    if input_text.strip():
        tags = predict_tags(input_text, threshold)
        if tags:
            st.success("Tags prédits :")
            st.write(", ".join(tags))
        else:
            # No tag probability reached the chosen threshold.
            st.warning("Aucun tag trouvé pour le seuil sélectionné.")
    else:
        # Empty or whitespace-only input.
        st.warning("Veuillez entrer une question.")
model.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from transformers import AutoTokenizer, AutoModel
3
+
4
class MultiLabelDeberta(nn.Module):
    """DeBERTa-v3 encoder topped with a linear multi-label classification head."""

    def __init__(self, num_labels):
        """Build the model.

        Args:
            num_labels: Width of the output layer — one logit per tag.
        """
        super().__init__()
        self.backbone = AutoModel.from_pretrained('microsoft/deberta-v3-base')
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.backbone.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits with shape (batch, num_labels)."""
        encoded = self.backbone(input_ids=input_ids, attention_mask=attention_mask)
        # Summarize the sequence with the first token's ([CLS]) hidden state.
        cls_state = self.dropout(encoded.last_hidden_state[:, 0])
        return self.classifier(cls_state)
requirements.txt ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # === Core data libraries ===
2
+ pandas>=1.3.0
3
+ numpy>=1.21.0
4
+
5
+ # === Visualization ===
6
+ matplotlib>=3.5.0
7
+ plotly>=5.3.1
8
+ wordcloud>=1.8.1
9
+ pillow>=9.0.0
10
+
11
+ # === Web app interface ===
12
+ streamlit>=1.20.0
13
+ watchdog>=2.1.6 # improves file change detection in Streamlit
14
+
15
+ # === NLP & Transformers ===
16
+ torch>=2.0.0
17
+ transformers>=4.31.0
18
+ tokenizers>=0.13.3
19
+ joblib>=1.2.0
20
+
21
+ # === Text preprocessing ===
22
+ beautifulsoup4>=4.12.0
23
+ nltk>=3.8.1
24
+ regex>=2023.12.25
25
+
26
+ # === Progress bar (optional but common in model inference) ===
27
+ tqdm>=4.64.0
28
+
29
+ # === ML Utilities ===
30
+ scikit-learn>=1.3.0
31
+ sentencepiece>=0.1.99