File size: 4,348 Bytes
6c6ac72 474bcaa 8beae03 6c6ac72 a8d9c94 8beae03 6428414 8beae03 8986719 6428414 8986719 6428414 8986719 6428414 8986719 6428414 8986719 6428414 8986719 6428414 8986719 6428414 8986719 6428414 6c6ac72 166de0f 438829f 6c6ac72 8beae03 6c6ac72 8beae03 6c6ac72 8beae03 ecdeea4 8beae03 a8d9c94 0ec2781 8beae03 a8d9c94 8beae03 ecdeea4 8beae03 ecdeea4 52d0414 8beae03 474bcaa 8beae03 239922e 8beae03 7f7f74a 474bcaa eac340a 7f7f74a 8beae03 a10c1f1 8beae03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import requests
import os
import ast
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px
from wordcloud import WordCloud
from collections import Counter
import torch
from transformers import AutoTokenizer
import joblib
from model import MultiLabelDeberta
from huggingface_hub import hf_hub_download
from datasets import load_dataset
# Set the page title and select the wide layout.
st.set_page_config(page_title="Tag Predictor", layout="wide")

# ========== Style ==========
# Font-size overrides for the text area, headings, slider, buttons, alerts
# and markdown paragraphs, injected into the page as a raw <style> element.
_PAGE_CSS = """
<style>
textarea {
font-size: 18px !important;
}
.markdown-text-container h1 {
font-size: 34px !important;
}
.markdown-text-container h2 {
font-size: 28px !important;
}
.markdown-text-container h3 {
font-size: 24px !important;
}
.stSlider .css-1y4p8pa, .stSlider .css-1cpxqw2 {
font-size: 18px !important;
}
.stButton > button {
font-size: 18px !important;
}
.stAlert {
font-size: 18px !important;
}
.stMarkdown p {
font-size: 18px !important;
}
</style>
"""
# unsafe_allow_html is needed so the <style> tag is rendered as HTML.
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
# ========== Loading model and data ==========
REPO_ID = "Framby/deberta_multilabel"


@st.cache_data
def _download_artifacts():
    """Fetch the label binarizer and the model weights from the Hub repo.

    The result is cached so that script reruns reuse the resolved paths
    instead of calling ``hf_hub_download`` again each time.

    Returns:
        Tuple ``(mlb_path, deberta_path)`` of local file paths as returned
        by ``hf_hub_download``.
    """
    binarizer_file = hf_hub_download(repo_id=REPO_ID, filename="mlb.joblib")
    weights_file = hf_hub_download(
        repo_id=REPO_ID, filename="deberta_multilabel.pt")
    return binarizer_file, weights_file


mlb_path, deberta_path = _download_artifacts()


@st.cache_resource
def load_model_and_tokenizer():
    """Build the classifier, its tokenizer and the label binarizer.

    Decorated with ``st.cache_resource`` so the objects are created once and
    shared across reruns.

    Returns:
        Tuple ``(model, tokenizer, mlb)``: the ``MultiLabelDeberta`` instance
        in eval mode with its state dict loaded on CPU, the
        ``microsoft/deberta-v3-base`` tokenizer (slow implementation), and
        the object loaded from ``mlb.joblib``.
    """
    label_binarizer = joblib.load(mlb_path)
    # One output per class known to the label binarizer.
    network = MultiLabelDeberta(num_labels=len(label_binarizer.classes_))
    # SECURITY NOTE(review): weights_only=False allows torch.load to unpickle
    # arbitrary objects from the checkpoint. Kept as in the original; switch
    # to weights_only=True if the file is confirmed to be a plain state dict.
    state_dict = torch.load(
        deberta_path, map_location="cpu", weights_only=False)
    network.load_state_dict(state_dict)
    network.eval()
    deberta_tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base", use_fast=False)
    return network, deberta_tokenizer, label_binarizer


model, tokenizer, mlb = load_model_and_tokenizer()
# ========== data loading ==========
@st.cache_data
def load_data():
    """Load the question texts and their tags from the Hub dataset.

    Returns:
        Tuple ``(X, Y)`` of pandas Series built from the ``text_clean`` and
        ``Tags`` columns of the ``train`` split of ``Framby/SOF_full``.
    """
    train_split = load_dataset("Framby/SOF_full")["train"]
    questions, tags = (
        pd.Series(train_split[column]) for column in ("text_clean", "Tags")
    )
    return questions, tags


X, Y = load_data()
# ========== prediction function ==========
def predict_tags(text, threshold=0.5):
    """Predict tags for one question using the module-level model.

    Args:
        text: Question text passed to the tokenizer.
        threshold: Minimum sigmoid score for a label to be selected.

    Returns:
        The first element of ``mlb.inverse_transform`` applied to the
        thresholded scores.
    """
    encoded = tokenizer(
        text,
        padding='max_length',
        max_length=512,
        truncation=True,
        return_tensors='pt',
    )
    # The model is called without token type ids; drop them when present.
    encoded.pop('token_type_ids', None)

    with torch.no_grad():
        raw_outputs = model(**encoded)

    scores = torch.sigmoid(raw_outputs).squeeze().cpu().numpy()
    selected = (scores >= threshold).astype(int)
    # Re-add a leading axis so inverse_transform receives a single-row batch.
    decoded = mlb.inverse_transform(selected[np.newaxis, ...])
    return decoded[0]
# ========== interface ==========
# The three helpers below are cached: the script is re-executed on widget
# interaction, and recomputing these statistics over the full dataset on
# every rerun is wasted work.
@st.cache_data
def _question_lengths(texts):
    """Return a one-column DataFrame ('length') of per-question word counts."""
    return pd.DataFrame({'length': texts.apply(lambda x: len(x.split()))})


@st.cache_data
def _top_tags(raw_tags, limit=20):
    """Return the ``limit`` most frequent tags as a 'Tag'/'Nombre' DataFrame.

    Each entry of ``raw_tags`` is parsed with ``ast.literal_eval`` and is
    expected to evaluate to an iterable of tags.
    """
    parsed_tags = raw_tags.apply(ast.literal_eval)
    tag_freq = Counter(tag for sublist in parsed_tags for tag in sublist)
    return pd.DataFrame(tag_freq.most_common(limit), columns=['Tag', 'Nombre'])


@st.cache_data
def _wordcloud_image(texts):
    """Generate the word cloud of all questions and return it as an array."""
    wc = WordCloud(width=800, height=300, background_color='white').generate(
        " ".join(texts))
    # Return the rendered image rather than the WordCloud object so the
    # cached value is a plain NumPy array.
    return wc.to_array()


st.title("Prédicteur de Tags StackOverflow")
st.markdown("## 1. Analyse des données textuelles")

col1, col2 = st.columns(2)
with col1:
    st.markdown("### Questions")
    df_lengths = _question_lengths(X)
    fig = px.histogram(
        df_lengths, x='length', nbins=30,
        title="Distribution de la longueur des questions")
    st.plotly_chart(fig, use_container_width=True)
with col2:
    st.markdown("### Tags")
    most_common_tags = _top_tags(Y)
    fig2 = px.bar(
        most_common_tags, x='Tag', y='Nombre',
        title="20 tags les plus fréquents")
    st.plotly_chart(fig2, use_container_width=True)

# NOTE(review): the word cloud is assumed to sit below the two columns
# (full width) — confirm against the intended layout.
st.markdown("### Nuage de mots")
fig_wc, ax = plt.subplots(figsize=(10, 4))
ax.imshow(_wordcloud_image(X), interpolation='bilinear')
ax.axis("off")
st.pyplot(fig_wc)
# Close the figure so pyplot does not keep one more open figure per rerun.
plt.close(fig_wc)

st.markdown("---")
st.markdown("## 2. Prédiction des tags")
input_text = st.text_area("Entrez une question StackOverflow", height=150)
threshold = st.slider("Seuil de probabilité", 0.1, 0.9, 0.5, 0.05)

if st.button("Prédire les tags"):
    if input_text.strip():
        tags = predict_tags(input_text, threshold)
        if tags:
            st.success("Tags prédits :")
            st.write(", ".join(tags))
        else:
            st.warning("Aucun tag trouvé pour le seuil sélectionné.")
    else:
        # Blank or whitespace-only input: ask for a question instead of predicting.
        st.warning("Veuillez entrer une question.")
|