|
|
import requests |
|
|
import os |
|
|
import ast |
|
|
import streamlit as st |
|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import plotly.express as px |
|
|
from wordcloud import WordCloud |
|
|
from collections import Counter |
|
|
import torch |
|
|
from transformers import AutoTokenizer |
|
|
import joblib |
|
|
from model import MultiLabelDeberta |
|
|
from huggingface_hub import hf_hub_download |
|
|
from datasets import load_dataset |
|
|
|
|
|
|
|
|
# Must run before any other st.* call; sets the browser-tab title and a
# full-width page layout.
st.set_page_config(page_title="Tag Predictor", layout="wide")
|
|
|
|
|
|
|
|
# Global CSS overrides injected into the page: enlarge the fonts of text
# areas, markdown headings, sliders, buttons, alerts and body text for
# readability. NOTE(review): the .css-* slider selectors are tied to a
# specific Streamlit build and may silently stop matching after an upgrade.
st.markdown("""
<style>
textarea {
    font-size: 18px !important;
}

.markdown-text-container h1 {
    font-size: 34px !important;
}

.markdown-text-container h2 {
    font-size: 28px !important;
}

.markdown-text-container h3 {
    font-size: 24px !important;
}

.stSlider .css-1y4p8pa, .stSlider .css-1cpxqw2 {
    font-size: 18px !important;
}

.stButton > button {
    font-size: 18px !important;
}

.stAlert {
    font-size: 18px !important;
}

.stMarkdown p {
    font-size: 18px !important;
}
</style>
""", unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repository holding the fine-tuned DeBERTa weights and
# the fitted MultiLabelBinarizer.
REPO_ID = "Framby/deberta_multilabel"

# Download (or reuse the local cache of) both artifacts. Deserialization is
# deferred to the cached load_model_and_tokenizer() below — previously the
# binarizer was joblib.load-ed both here and inside the loader, doing the
# same work twice at start-up.
mlb_path = hf_hub_download(repo_id=REPO_ID, filename="mlb.joblib")
deberta_path = hf_hub_download(
    repo_id=REPO_ID, filename="deberta_multilabel.pt")
|
|
|
|
|
|
|
|
@st.cache_resource
def load_model_and_tokenizer():
    """Build the multi-label DeBERTa classifier, its tokenizer and the
    fitted label binarizer; cached for the lifetime of the app process.

    Returns:
        (model, tokenizer, mlb) — model is in eval mode on CPU.
    """
    # Slow (SentencePiece) tokenizer, matching the training setup.
    tokenizer = AutoTokenizer.from_pretrained(
        "microsoft/deberta-v3-base", use_fast=False)

    # The binarizer defines how many output labels the classifier needs.
    label_binarizer = joblib.load(mlb_path)
    classifier = MultiLabelDeberta(num_labels=len(label_binarizer.classes_))

    # NOTE(review): weights_only=False unpickles arbitrary objects from the
    # checkpoint — acceptable only because the Hub repo is trusted.
    state = torch.load(deberta_path, map_location="cpu", weights_only=False)
    classifier.load_state_dict(state)
    classifier.eval()

    return classifier, tokenizer, label_binarizer
|
|
|
|
|
|
|
|
# Materialize the cached model artifacts once at app start-up; these
# module-level names are read by predict_tags() below.
model, tokenizer, mlb = load_model_and_tokenizer()
|
|
|
|
|
|
|
|
|
|
|
@st.cache_data
def load_data():
    """Fetch the StackOverflow dataset and split it into questions and tags.

    Returns:
        (X, Y) — pandas Series of cleaned question texts and of their
        associated tags (stored as stringified Python lists).
    """
    train_split = load_dataset("Framby/SOF_full")['train']
    questions = pd.Series(train_split['text_clean'])
    tags = pd.Series(train_split['Tags'])
    return questions, tags
|
|
|
|
|
|
|
|
# X: cleaned question texts; Y: stringified tag lists (parsed with
# ast.literal_eval further down).
X, Y = load_data()
|
|
|
|
|
|
|
|
|
|
|
def predict_tags(text, threshold=0.5):
    """Predict StackOverflow tags for a single question text.

    Args:
        text: Raw question text.
        threshold: Minimum sigmoid probability for a tag to be kept.

    Returns:
        Tuple of predicted tag names (possibly empty) from the binarizer.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=512,
        padding='max_length',
    )
    # DeBERTa-v3 does not consume token_type_ids; drop them if present.
    encoded.pop('token_type_ids', None)

    with torch.no_grad():
        logits = model(**encoded)

    scores = torch.sigmoid(logits).squeeze().cpu().numpy()
    kept = (scores >= threshold).astype(int)
    # inverse_transform expects a 2-D (n_samples, n_labels) indicator array.
    return mlb.inverse_transform(kept[np.newaxis, :])[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# UI — page title and section 1: exploratory analysis of the dataset.
# ---------------------------------------------------------------------------
st.title("Prédicteur de Tags StackOverflow")

st.markdown("## 1. Analyse des données textuelles")

# Two side-by-side panels: question-length histogram / tag frequencies.
col1, col2 = st.columns(2)
|
|
|
|
|
with col1:
    st.markdown("### Questions")
    # Question length measured in whitespace-separated words.
    word_counts = X.apply(lambda q: len(q.split()))
    length_df = pd.DataFrame({'length': word_counts})
    hist = px.histogram(
        length_df, x='length', nbins=30,
        title="Distribution de la longueur des questions")
    st.plotly_chart(hist, use_container_width=True)
|
|
|
|
|
with col2:
    st.markdown("### Tags")
    # Tags are stored as stringified Python lists; literal_eval parses them
    # without executing arbitrary code.
    tag_lists = Y.apply(ast.literal_eval)
    flat_tags = [tag for tags in tag_lists for tag in tags]
    counts = Counter(flat_tags)
    top_tags = pd.DataFrame(counts.most_common(20),
                            columns=['Tag', 'Nombre'])
    bar = px.bar(top_tags, x='Tag', y='Nombre',
                 title="20 tags les plus fréquents")
    st.plotly_chart(bar, use_container_width=True)
|
|
|
|
|
st.markdown("### Nuage de mots")
# One big corpus string out of every question; WordCloud tokenizes it itself.
corpus = " ".join(X)
cloud = WordCloud(width=800, height=300,
                  background_color='white').generate(corpus)
figure, axes = plt.subplots(figsize=(10, 4))
axes.imshow(cloud, interpolation='bilinear')
axes.axis("off")
st.pyplot(figure)
|
|
|
|
|
# ---------------------------------------------------------------------------
# Section 2: interactive tag prediction.
# ---------------------------------------------------------------------------
st.markdown("---")

st.markdown("## 2. Prédiction des tags")

# Free-text question plus a probability cut-off (range 0.1–0.9, default 0.5,
# step 0.05) fed to predict_tags() below.
input_text = st.text_area("Entrez une question StackOverflow", height=150)

threshold = st.slider("Seuil de probabilité", 0.1, 0.9, 0.5, 0.05)
|
|
|
|
|
if st.button("Prédire les tags"):
    # Guard: reject blank/whitespace-only input before running the model.
    if not input_text.strip():
        st.warning("Veuillez entrer une question.")
    else:
        predicted = predict_tags(input_text, threshold)
        if predicted:
            st.success("Tags prédits :")
            st.write(", ".join(predicted))
        else:
            st.warning("Aucun tag trouvé pour le seuil sélectionné.")
|
|
|