Spaces:
Sleeping
Sleeping
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +185 -78
src/streamlit_app.py
CHANGED
|
@@ -9,13 +9,15 @@ import plotly.graph_objects as go
|
|
| 9 |
import umap
|
| 10 |
import shap
|
| 11 |
import logging
|
| 12 |
-
import os
|
| 13 |
import seaborn as sns
|
| 14 |
import matplotlib.pyplot as plt
|
| 15 |
from sklearn.decomposition import PCA
|
| 16 |
from sklearn.cluster import KMeans
|
| 17 |
from sklearn.ensemble import RandomForestClassifier
|
|
|
|
| 18 |
import warnings
|
|
|
|
|
|
|
| 19 |
|
| 20 |
warnings.filterwarnings('ignore')
|
| 21 |
|
|
@@ -23,7 +25,7 @@ warnings.filterwarnings('ignore')
|
|
| 23 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 24 |
|
| 25 |
# Répertoire de sortie
|
| 26 |
-
output_dir = '
|
| 27 |
os.makedirs(output_dir, exist_ok=True)
|
| 28 |
|
| 29 |
# Biomarqueurs IRC
|
|
@@ -88,81 +90,155 @@ class OmicsVAE(nn.Module):
|
|
| 88 |
class_logits = self.fc_classify(z)
|
| 89 |
return outputs, z, mu, log_var, class_logits
|
| 90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
# Configuration de Streamlit
|
| 92 |
-
st.set_page_config(page_title="Analyse Multi-Omique", layout="wide")
|
| 93 |
st.markdown("""
|
| 94 |
<style>
|
| 95 |
-
.main {background-color: #1e1e1e; color: #ffffff;}
|
| 96 |
-
.stButton>button {
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
</style>
|
| 100 |
""", unsafe_allow_html=True)
|
| 101 |
|
| 102 |
-
st.title("
|
| 103 |
-
st.markdown("
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
# Menu latéral
|
| 106 |
st.sidebar.header("Navigation")
|
| 107 |
-
page = st.sidebar.radio("
|
| 108 |
-
"
|
| 109 |
"Chargement des Données",
|
| 110 |
"Analyse Exploratoire",
|
| 111 |
-
"
|
| 112 |
"Scores de Risque",
|
| 113 |
"Analyse SHAP",
|
| 114 |
-
"
|
|
|
|
| 115 |
])
|
| 116 |
|
| 117 |
-
#
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
st.markdown("""
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
- **
|
| 124 |
-
- **
|
| 125 |
-
- **
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
Naviguez via le menu latéral pour commencer.
|
| 130 |
""")
|
| 131 |
|
| 132 |
# Chargement des Données
|
| 133 |
elif page == "Chargement des Données":
|
| 134 |
-
st.header("Chargement des Données")
|
| 135 |
-
st.markdown("
|
| 136 |
|
| 137 |
-
# Chargement des fichiers
|
| 138 |
uploaded_files = {}
|
| 139 |
omics_types = ['génomique', 'transcriptomique', 'protéomique', 'métabolomique']
|
| 140 |
-
file_paths = {
|
| 141 |
-
'génomique': 'genomic_data.csv',
|
| 142 |
-
'transcriptomique': 'transcriptomic_data.csv',
|
| 143 |
-
'protéomique': 'proteomic_data.csv',
|
| 144 |
-
'métabolomique': 'metabolomic_data.csv'
|
| 145 |
-
}
|
| 146 |
for omic in omics_types:
|
| 147 |
-
uploaded_file = st.file_uploader(f"
|
| 148 |
if uploaded_file:
|
| 149 |
-
|
| 150 |
-
f.write(uploaded_file.getvalue())
|
| 151 |
|
| 152 |
-
|
| 153 |
-
model_file = st.file_uploader("Charger le modèle VAE pré-entraîné (PTH)", type="pth")
|
| 154 |
if model_file:
|
| 155 |
with open(os.path.join(output_dir, 'omics_vae_best_hyperparams.pth'), 'wb') as f:
|
| 156 |
f.write(model_file.getvalue())
|
| 157 |
|
| 158 |
-
if st.button("
|
| 159 |
try:
|
| 160 |
-
# Chargement des données
|
| 161 |
data_dict = {}
|
| 162 |
-
for omic,
|
| 163 |
-
df = pd.read_csv(
|
| 164 |
data_dict[omic] = df.drop(columns=['Status'])
|
| 165 |
-
labels = pd.read_csv(list(
|
| 166 |
le = LabelEncoder()
|
| 167 |
encoded_labels = pd.Series(le.fit_transform(labels), index=labels.index, name='Status')
|
| 168 |
common_samples = data_dict['génomique'].index
|
|
@@ -174,7 +250,6 @@ elif page == "Chargement des Données":
|
|
| 174 |
st.session_state['label_encoder'] = le
|
| 175 |
st.session_state['common_samples'] = common_samples
|
| 176 |
|
| 177 |
-
# Chargement du modèle
|
| 178 |
input_dims = [data_dict[omic].shape[1] for omic in data_dict]
|
| 179 |
model = OmicsVAE(
|
| 180 |
input_dims=input_dims,
|
|
@@ -187,22 +262,21 @@ elif page == "Chargement des Données":
|
|
| 187 |
st.session_state['input_dims'] = input_dims
|
| 188 |
st.success("Données et modèle chargés avec succès !")
|
| 189 |
except Exception as e:
|
| 190 |
-
st.error(f"Erreur
|
| 191 |
|
| 192 |
# Analyse Exploratoire
|
| 193 |
elif page == "Analyse Exploratoire":
|
| 194 |
st.header("Analyse Exploratoire des Données")
|
| 195 |
if 'data_dict' not in st.session_state:
|
| 196 |
-
st.warning("
|
| 197 |
else:
|
| 198 |
data_dict = st.session_state['data_dict']
|
| 199 |
labels = st.session_state['labels']
|
| 200 |
-
omic = st.selectbox("
|
| 201 |
biomarkers = [col for col in data_dict[omic].columns if col in irc_biomarkers]
|
| 202 |
|
| 203 |
-
# Matrice de corrélation
|
| 204 |
if biomarkers:
|
| 205 |
-
st.subheader(f"Matrice de Corrélation
|
| 206 |
corr_matrix = data_dict[omic][biomarkers].corr()
|
| 207 |
fig = go.Figure(data=go.Heatmap(
|
| 208 |
z=corr_matrix.values,
|
|
@@ -217,7 +291,6 @@ elif page == "Analyse Exploratoire":
|
|
| 217 |
fig.update_layout(title=f'Matrice de Corrélation ({omic})', template='plotly_dark')
|
| 218 |
st.plotly_chart(fig, use_container_width=True)
|
| 219 |
|
| 220 |
-
# Projection PCA 3D
|
| 221 |
st.subheader(f"Projection PCA 3D ({omic})")
|
| 222 |
pca = PCA(n_components=3)
|
| 223 |
pca_result = pca.fit_transform(data_dict[omic])
|
|
@@ -239,17 +312,17 @@ elif page == "Analyse Exploratoire":
|
|
| 239 |
fig.update_traces(marker=dict(size=5))
|
| 240 |
st.plotly_chart(fig, use_container_width=True)
|
| 241 |
|
| 242 |
-
#
|
| 243 |
-
elif page == "
|
| 244 |
-
st.header("
|
| 245 |
if 'data_dict' not in st.session_state:
|
| 246 |
-
st.warning("
|
| 247 |
else:
|
| 248 |
data_dict = st.session_state['data_dict']
|
| 249 |
labels = st.session_state['labels']
|
| 250 |
label_encoder = st.session_state['label_encoder']
|
| 251 |
-
n_clusters = st.slider("Nombre de clusters", 2, 10, 5, help="
|
| 252 |
-
if st.button("
|
| 253 |
combined_data = pd.concat([data_dict[omic] for omic in data_dict], axis=1)
|
| 254 |
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
|
| 255 |
umap_reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42)
|
|
@@ -261,7 +334,6 @@ elif page == "Visualisation du Clustering":
|
|
| 261 |
st.session_state['kmeans'] = kmeans
|
| 262 |
st.session_state['umap_embedding'] = umap_embedding
|
| 263 |
|
| 264 |
-
# Visualisation
|
| 265 |
fig = px.scatter(
|
| 266 |
umap_df, x='UMAP1', y='UMAP2', color='Cluster', symbol='Status',
|
| 267 |
title='Projection UMAP avec Clusters KMeans',
|
|
@@ -274,9 +346,9 @@ elif page == "Visualisation du Clustering":
|
|
| 274 |
|
| 275 |
# Scores de Risque
|
| 276 |
elif page == "Scores de Risque":
|
| 277 |
-
st.header("
|
| 278 |
if 'umap_df' not in st.session_state or 'data_dict' not in st.session_state:
|
| 279 |
-
st.warning("
|
| 280 |
else:
|
| 281 |
umap_df = st.session_state['umap_df']
|
| 282 |
data_dict = st.session_state['data_dict']
|
|
@@ -284,8 +356,8 @@ elif page == "Scores de Risque":
|
|
| 284 |
label_encoder = st.session_state['label_encoder']
|
| 285 |
kmeans = st.session_state['kmeans']
|
| 286 |
umap_embedding = st.session_state['umap_embedding']
|
| 287 |
-
|
| 288 |
-
if st.button("Calculer les Scores
|
| 289 |
cluster_centers = kmeans.cluster_centers_
|
| 290 |
distances = np.zeros(len(umap_embedding))
|
| 291 |
for i, emb in enumerate(umap_embedding):
|
|
@@ -308,7 +380,6 @@ elif page == "Scores de Risque":
|
|
| 308 |
umap_df['Score de Risque (%)'] = final_risk
|
| 309 |
st.session_state['umap_df'] = umap_df
|
| 310 |
|
| 311 |
-
# Visualisation
|
| 312 |
fig = px.scatter(
|
| 313 |
umap_df, x='UMAP1', y='UMAP2', color='Score de Risque (%)', symbol='Status',
|
| 314 |
title='Projection UMAP avec Scores de Risque IRC (%)',
|
|
@@ -321,22 +392,22 @@ elif page == "Scores de Risque":
|
|
| 321 |
|
| 322 |
# Analyse SHAP
|
| 323 |
elif page == "Analyse SHAP":
|
| 324 |
-
st.header("Analyse SHAP")
|
| 325 |
if 'model' not in st.session_state or 'data_dict' not in st.session_state:
|
| 326 |
-
st.warning("
|
| 327 |
else:
|
| 328 |
model = st.session_state['model']
|
| 329 |
data_dict = st.session_state['data_dict']
|
| 330 |
input_dims = st.session_state['input_dims']
|
| 331 |
combined_data = pd.concat([data_dict[omic] for omic in data_dict], axis=1)
|
| 332 |
feature_names = sum([data_dict[omic].columns.tolist() for omic in data_dict], [])
|
| 333 |
-
|
| 334 |
-
if st.button("Lancer l
|
| 335 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 336 |
X_concat = combined_data.values
|
| 337 |
n_samples = min(100, X_concat.shape[0])
|
| 338 |
X_subset = X_concat[:n_samples]
|
| 339 |
-
|
| 340 |
class VAEWrapper:
|
| 341 |
def __init__(self, model, device):
|
| 342 |
self.model = model
|
|
@@ -350,37 +421,73 @@ elif page == "Analyse SHAP":
|
|
| 350 |
with torch.no_grad():
|
| 351 |
_, z, _, _, _ = self.model(X_tensors)
|
| 352 |
return torch.norm(z, dim=1).cpu().numpy()
|
| 353 |
-
|
| 354 |
explainer = shap.KernelExplainer(VAEWrapper(model, device).predict, X_subset)
|
| 355 |
shap_values = explainer.shap_values(X_subset, nsamples=100)
|
| 356 |
shap_importance = pd.DataFrame({
|
| 357 |
'Biomarqueur': feature_names[:len(np.mean(np.abs(shap_values[0]), axis=0))],
|
| 358 |
'Importance SHAP': np.mean(np.abs(shap_values[0]), axis=0)
|
| 359 |
}).sort_values('Importance SHAP', ascending=False)
|
| 360 |
-
|
| 361 |
-
#
|
| 362 |
fig, ax = plt.subplots(figsize=(12, 8))
|
| 363 |
sns.barplot(data=shap_importance.head(20), x='Importance SHAP', y='Biomarqueur', palette='Set2')
|
| 364 |
plt.title('Top 20 Biomarqueurs par Importance SHAP')
|
| 365 |
st.pyplot(fig)
|
| 366 |
plt.close()
|
| 367 |
-
|
| 368 |
-
st.subheader("Biomarqueurs
|
| 369 |
st.dataframe(shap_importance.head(20))
|
| 370 |
|
| 371 |
-
#
|
| 372 |
-
elif page == "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
st.header("Résumé des Résultats")
|
| 374 |
if 'umap_df' not in st.session_state:
|
| 375 |
-
st.warning("
|
| 376 |
else:
|
| 377 |
st.subheader("Scores de Risque")
|
| 378 |
st.dataframe(st.session_state['umap_df'][['Cluster', 'Status', 'Score de Risque (%)']])
|
| 379 |
-
st.subheader("Téléchargement")
|
| 380 |
csv = st.session_state['umap_df'].to_csv()
|
| 381 |
st.download_button(
|
| 382 |
label="Télécharger les Résultats (CSV)",
|
| 383 |
data=csv,
|
| 384 |
-
file_name="
|
| 385 |
mime="text/csv"
|
| 386 |
)
|
|
|
|
| 9 |
import umap
|
| 10 |
import shap
|
| 11 |
import logging
|
|
|
|
| 12 |
import seaborn as sns
|
| 13 |
import matplotlib.pyplot as plt
|
| 14 |
from sklearn.decomposition import PCA
|
| 15 |
from sklearn.cluster import KMeans
|
| 16 |
from sklearn.ensemble import RandomForestClassifier
|
| 17 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 18 |
import warnings
|
| 19 |
+
from datetime import datetime
|
| 20 |
+
import io
|
| 21 |
|
| 22 |
warnings.filterwarnings('ignore')
|
| 23 |
|
|
|
|
| 25 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 26 |
|
| 27 |
# Répertoire de sortie
|
| 28 |
+
output_dir = 'omics_analysis_output'
|
| 29 |
os.makedirs(output_dir, exist_ok=True)
|
| 30 |
|
| 31 |
# Biomarqueurs IRC
|
|
|
|
| 90 |
class_logits = self.fc_classify(z)
|
| 91 |
return outputs, z, mu, log_var, class_logits
|
| 92 |
|
| 93 |
+
# Generate a patient recommendation from a BioBERT sequence classification.
def generate_recommendation_with_biobert(patient_data, patient_id, biomarkers, tokenizer, model):
    """Return Markdown advice for one patient, chosen by the classifier's top class.

    The patient record is serialised into a short French text, tokenised,
    and passed through ``model``. The arg-max of the returned logits selects
    one of three advice templates: 0 = low risk, 1 = moderate risk, any
    other value = high risk.

    NOTE(review): the source comments describe this step as a *simulated*
    classification, and no fine-tuning of the classifier is visible in this
    file. The predicted tier may therefore disagree with the numeric
    ``risk_score`` printed next to it -- confirm before clinical use.

    Args:
        patient_data: Mapping providing ``sex``, ``age``, ``risk_score``,
            ``family_history_irc``, ``family_history_diabetes``,
            ``family_history_hypertension``, ``diabetes`` and
            ``hypertension``. The flags are read by truthiness.
        patient_id: Identifier interpolated into the prompt and the advice.
        biomarkers: Iterable of biomarker names. It is materialised into a
            list of strings, so tuples, sets, dict keys or generators are
            accepted as well as lists. At most the first three are used.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            truncation=True, padding=True, max_length=512)``. Its result is
            unpacked as keyword arguments into ``model``.
        model: Callable returning an object with a ``logits`` tensor holding
            one row per input text.

    Returns:
        The advice as a Markdown string. It has no leading indentation or
        leading newline and ends with the medical-validation note.

    Raises:
        KeyError: If ``patient_data`` lacks one of the keys listed above.
    """

    def _yes_no(flag):
        # French yes/no label used throughout the prompt.
        return 'Oui' if flag else 'Non'

    # Materialise once: slicing below must work for any iterable, and
    # ``str.join`` requires string items.
    biomarker_names = [str(name) for name in biomarkers]
    top_two = ', '.join(biomarker_names[:2])
    top_three = ', '.join(biomarker_names[:3])
    risk_score = patient_data['risk_score']

    # Structure the patient data as text. Built line by line so that no
    # source-code indentation leaks into the prompt.
    text = "\n".join([
        f"Patient: {patient_id}, {patient_data['sex']}, {patient_data['age']} ans.",
        f"Score de risque IRC: {risk_score:.1f}%.",
        f"Antécédents familiaux: IRC ({_yes_no(patient_data['family_history_irc'])}),",
        f"Diabète ({_yes_no(patient_data['family_history_diabetes'])}),",
        f"Hypertension ({_yes_no(patient_data['family_history_hypertension'])}).",
        f"Comorbidités: Diabète ({_yes_no(patient_data['diabetes'])}),",
        f"Hypertension ({_yes_no(patient_data['hypertension'])}).",
        f"Biomarqueurs clés: {top_three}.",
    ])

    # Tokenise the text.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Classify without tracking gradients; only the winning class index is used.
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        prediction = torch.argmax(logits, dim=1).item()

    # Advice templates: (title, state, recommendations, follow-up).
    low_risk = (
        "Risque Faible",
        "Faible probabilité de progression vers l'IRC.",
        [
            "Adopter une alimentation équilibrée, faible en sel (<2g/jour), riche en fruits et légumes.",
            "Maintenir une activité physique modérée (30 min/jour, 5 jours/semaine).",
            f"Surveiller les biomarqueurs : {top_two}.",
            "Hydratation adéquate (1,5-2L d’eau/jour).",
        ],
        "Bilan rénal annuel.",
    )
    moderate_risk = (
        "Risque Modéré",
        "Risque intermédiaire de progression vers l'IRC.",
        [
            f"Consulter un néphrologue pour évaluer {top_two}.",
            "Régime alimentaire strict : réduire les protéines animales, sodium (<1,5g/jour).",
            "Contrôler la pression artérielle (<130/80 mmHg) et la glycémie si diabétique.",
            "Éviter les AINS sauf prescription médicale.",
        ],
        "Bilan rénal trimestriel.",
    )
    high_risk = (
        "Risque Élevé",
        "Forte probabilité de progression vers l'IRC.",
        [
            "Consultation urgente avec un néphrologue.",
            f"Analyse des biomarqueurs : {top_three}.",
            "Régime rénal strict : faible en potassium, phosphore, sodium.",
            "Envisager une thérapie (ex. : inhibiteurs de l’ECA, diurétiques).",
            "Surveillance hebdomadaire de la créatinine et du DFG.",
        ],
        "Plan thérapeutique dans 1 semaine.",
    )

    # Any class other than 0 or 1 falls through to the high-risk template,
    # matching the original ``else`` branch.
    title, state, recommendations, follow_up = {
        0: low_risk,
        1: moderate_risk,
    }.get(prediction, high_risk)

    advice_lines = [
        f"**Patient {patient_id} : {title} ({risk_score:.1f}%)**",
        f"- **État** : {state}",
        "- **Recommandations** :",
    ]
    # Four spaces nest each recommendation under the bullet above.
    advice_lines.extend(f"    - {item}" for item in recommendations)
    advice_lines.append(f"- **Suivi** : {follow_up}")
    # The blank line keeps the closing note out of the bullet list.
    advice_lines.append("")
    advice_lines.append("**Note** : Ces recommandations doivent être validées par un médecin.")
    return "\n".join(advice_lines)
|
| 153 |
+
|
| 154 |
# Configuration de Streamlit
|
| 155 |
+
st.set_page_config(page_title="Analyse Multi-Omique IRC", layout="wide")
|
| 156 |
st.markdown("""
|
| 157 |
<style>
|
| 158 |
+
.main {background-color: #1e1e1e; color: #ffffff; font-family: 'Roboto', sans-serif;}
|
| 159 |
+
.stButton>button {
|
| 160 |
+
background-color: #4CAF50; color: white; border-radius: 12px; padding: 12px 24px;
|
| 161 |
+
transition: all 0.3s ease; border: none; font-weight: bold; font-size: 16px;
|
| 162 |
+
}
|
| 163 |
+
.stButton>button:hover {background-color: #45a049; transform: scale(1.05);}
|
| 164 |
+
.stSelectbox, .stFileUploader, .stTextInput, .stNumberInput, .stCheckbox {
|
| 165 |
+
background-color: #2e2e2e; color: #ffffff; border-radius: 8px; padding: 8px;
|
| 166 |
+
}
|
| 167 |
+
.sidebar .sidebar-content {background-color: #2e2e2e; color: #ffffff;}
|
| 168 |
+
.stMarkdown h1, h2, h3 {color: #4CAF50; font-weight: bold;}
|
| 169 |
+
.chat-message {padding: 12px; border-radius: 10px; margin-bottom: 12px; max-width: 80%;}
|
| 170 |
+
.user-message {background-color: #4CAF50; color: white; margin-left: 20%;}
|
| 171 |
+
.bot-message {background-color: #333333; color: white; margin-right: 20%;}
|
| 172 |
+
.stPlotlyChart {border-radius: 10px; overflow: hidden;}
|
| 173 |
</style>
|
| 174 |
""", unsafe_allow_html=True)
|
| 175 |
|
| 176 |
+
st.title("Plateforme d’Analyse Multi-Omique pour l’IRC")
|
| 177 |
+
st.markdown("""
|
| 178 |
+
**Développée par Ngoue David, Master 2 Intelligence Artificielle et Big Data**
|
| 179 |
+
Hôpital Général de Yaoundé | Approche innovante pour le suivi et la thérapie personnalisée de l’insuffisance rénale chronique.
|
| 180 |
+
""")
|
| 181 |
|
| 182 |
# Menu latéral
|
| 183 |
st.sidebar.header("Navigation")
|
| 184 |
+
page = st.sidebar.radio("Étapes du Projet", [
|
| 185 |
+
"Présentation",
|
| 186 |
"Chargement des Données",
|
| 187 |
"Analyse Exploratoire",
|
| 188 |
+
"Clustering",
|
| 189 |
"Scores de Risque",
|
| 190 |
"Analyse SHAP",
|
| 191 |
+
"Conseiller Médical",
|
| 192 |
+
"Résumé"
|
| 193 |
])
|
| 194 |
|
| 195 |
+
# Load BioBERT (tokenizer + 3-label sequence classifier).
@st.cache_resource
def load_biobert():
    """Load the BioBERT tokenizer and a three-label sequence classifier.

    Both objects come from the ``dmis-lab/biobert-v1.1`` checkpoint. The
    function is wrapped in ``st.cache_resource``.

    NOTE(review): no fine-tuned classification weights are loaded anywhere
    in this file, so the three-way head is presumably untrained -- confirm
    before relying on its predictions.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    checkpoint = "dmis-lab/biobert-v1.1"
    biobert_tok = AutoTokenizer.from_pretrained(checkpoint)
    biobert_clf = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=3)
    return biobert_tok, biobert_clf
|
| 201 |
+
|
| 202 |
+
biobert_tokenizer, biobert_model = load_biobert()
|
| 203 |
+
|
| 204 |
+
# Présentation
|
| 205 |
+
if page == "Présentation":
|
| 206 |
+
st.header("Contexte et Innovation")
|
| 207 |
st.markdown("""
|
| 208 |
+
**Projet : Thérapie Personnalisée de l’IRC via une Approche Multi-Omique**
|
| 209 |
+
Réalisé par Ngoue David dans le cadre du Master 2 Intelligence Artificielle et Big Data, ce projet ambitionne de transformer la gestion de l’insuffisance rénale chronique (IRC) à l’Hôpital Général de Yaoundé. En exploitant des données multi-omiques (génomique, transcriptomique, protéomique, métabolomique) à travers une architecture de transformers hybrides (OmicsVAE), nous proposons :
|
| 210 |
+
- **Prédiction précise** des risques de progression de l’IRC.
|
| 211 |
+
- **Thérapies sur mesure** basées sur les profils moléculaires des patients.
|
| 212 |
+
- **Suivi optimisé** grâce à un conseiller médical virtuel intelligent basé sur BioBERT.
|
| 213 |
+
|
| 214 |
+
**Impact** : Cette solution renforce la médecine de précision au Cameroun, améliore les résultats cliniques et réduit les coûts pour les patients atteints d’IRC.
|
| 215 |
+
**Explorez** les analyses et interagissez avec le conseiller médical via le menu latéral.
|
|
|
|
| 216 |
""")
|
| 217 |
|
| 218 |
# Chargement des Données
|
| 219 |
elif page == "Chargement des Données":
|
| 220 |
+
st.header("Chargement des Données et du Modèle")
|
| 221 |
+
st.markdown("Importez les fichiers omiques et le modèle VAE pré-entraîné pour initialiser l’analyse.")
|
| 222 |
|
|
|
|
| 223 |
uploaded_files = {}
|
| 224 |
omics_types = ['génomique', 'transcriptomique', 'protéomique', 'métabolomique']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
for omic in omics_types:
|
| 226 |
+
uploaded_file = st.file_uploader(f"Données {omic} (CSV)", type="csv", key=omic)
|
| 227 |
if uploaded_file:
|
| 228 |
+
uploaded_files[omic] = uploaded_file
|
|
|
|
| 229 |
|
| 230 |
+
model_file = st.file_uploader("Modèle VAE (PTH)", type="pth")
|
|
|
|
| 231 |
if model_file:
|
| 232 |
with open(os.path.join(output_dir, 'omics_vae_best_hyperparams.pth'), 'wb') as f:
|
| 233 |
f.write(model_file.getvalue())
|
| 234 |
|
| 235 |
+
if st.button("Initialiser l’Analyse") and len(uploaded_files) == len(omics_types) and model_file:
|
| 236 |
try:
|
|
|
|
| 237 |
data_dict = {}
|
| 238 |
+
for omic, file in uploaded_files.items():
|
| 239 |
+
df = pd.read_csv(file, index_col='Patient_ID')
|
| 240 |
data_dict[omic] = df.drop(columns=['Status'])
|
| 241 |
+
labels = pd.read_csv(list(uploaded_files.values())[0], index_col='Patient_ID')['Status']
|
| 242 |
le = LabelEncoder()
|
| 243 |
encoded_labels = pd.Series(le.fit_transform(labels), index=labels.index, name='Status')
|
| 244 |
common_samples = data_dict['génomique'].index
|
|
|
|
| 250 |
st.session_state['label_encoder'] = le
|
| 251 |
st.session_state['common_samples'] = common_samples
|
| 252 |
|
|
|
|
| 253 |
input_dims = [data_dict[omic].shape[1] for omic in data_dict]
|
| 254 |
model = OmicsVAE(
|
| 255 |
input_dims=input_dims,
|
|
|
|
| 262 |
st.session_state['input_dims'] = input_dims
|
| 263 |
st.success("Données et modèle chargés avec succès !")
|
| 264 |
except Exception as e:
|
| 265 |
+
st.error(f"Erreur : {str(e)}")
|
| 266 |
|
| 267 |
# Analyse Exploratoire
|
| 268 |
elif page == "Analyse Exploratoire":
|
| 269 |
st.header("Analyse Exploratoire des Données")
|
| 270 |
if 'data_dict' not in st.session_state:
|
| 271 |
+
st.warning("Chargez les données d'abord.")
|
| 272 |
else:
|
| 273 |
data_dict = st.session_state['data_dict']
|
| 274 |
labels = st.session_state['labels']
|
| 275 |
+
omic = st.selectbox("Type omique", list(data_dict.keys()), help="Sélectionnez une catégorie omique à explorer")
|
| 276 |
biomarkers = [col for col in data_dict[omic].columns if col in irc_biomarkers]
|
| 277 |
|
|
|
|
| 278 |
if biomarkers:
|
| 279 |
+
st.subheader(f"Matrice de Corrélation ({omic})")
|
| 280 |
corr_matrix = data_dict[omic][biomarkers].corr()
|
| 281 |
fig = go.Figure(data=go.Heatmap(
|
| 282 |
z=corr_matrix.values,
|
|
|
|
| 291 |
fig.update_layout(title=f'Matrice de Corrélation ({omic})', template='plotly_dark')
|
| 292 |
st.plotly_chart(fig, use_container_width=True)
|
| 293 |
|
|
|
|
| 294 |
st.subheader(f"Projection PCA 3D ({omic})")
|
| 295 |
pca = PCA(n_components=3)
|
| 296 |
pca_result = pca.fit_transform(data_dict[omic])
|
|
|
|
| 312 |
fig.update_traces(marker=dict(size=5))
|
| 313 |
st.plotly_chart(fig, use_container_width=True)
|
| 314 |
|
| 315 |
+
# Clustering
|
| 316 |
+
elif page == "Clustering":
|
| 317 |
+
st.header("Analyse de Clustering")
|
| 318 |
if 'data_dict' not in st.session_state:
|
| 319 |
+
st.warning("Chargez les données d'abord.")
|
| 320 |
else:
|
| 321 |
data_dict = st.session_state['data_dict']
|
| 322 |
labels = st.session_state['labels']
|
| 323 |
label_encoder = st.session_state['label_encoder']
|
| 324 |
+
n_clusters = st.slider("Nombre de clusters", 2, 10, 5, help="Ajustez le nombre de clusters pour KMeans")
|
| 325 |
+
if st.button("Effectuer le Clustering"):
|
| 326 |
combined_data = pd.concat([data_dict[omic] for omic in data_dict], axis=1)
|
| 327 |
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
|
| 328 |
umap_reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42)
|
|
|
|
| 334 |
st.session_state['kmeans'] = kmeans
|
| 335 |
st.session_state['umap_embedding'] = umap_embedding
|
| 336 |
|
|
|
|
| 337 |
fig = px.scatter(
|
| 338 |
umap_df, x='UMAP1', y='UMAP2', color='Cluster', symbol='Status',
|
| 339 |
title='Projection UMAP avec Clusters KMeans',
|
|
|
|
| 346 |
|
| 347 |
# Scores de Risque
|
| 348 |
elif page == "Scores de Risque":
|
| 349 |
+
st.header("Scores de Risque IRC")
|
| 350 |
if 'umap_df' not in st.session_state or 'data_dict' not in st.session_state:
|
| 351 |
+
st.warning("Effectuez le clustering et chargez les données d'abord.")
|
| 352 |
else:
|
| 353 |
umap_df = st.session_state['umap_df']
|
| 354 |
data_dict = st.session_state['data_dict']
|
|
|
|
| 356 |
label_encoder = st.session_state['label_encoder']
|
| 357 |
kmeans = st.session_state['kmeans']
|
| 358 |
umap_embedding = st.session_state['umap_embedding']
|
| 359 |
+
|
| 360 |
+
if st.button("Calculer les Scores"):
|
| 361 |
cluster_centers = kmeans.cluster_centers_
|
| 362 |
distances = np.zeros(len(umap_embedding))
|
| 363 |
for i, emb in enumerate(umap_embedding):
|
|
|
|
| 380 |
umap_df['Score de Risque (%)'] = final_risk
|
| 381 |
st.session_state['umap_df'] = umap_df
|
| 382 |
|
|
|
|
| 383 |
fig = px.scatter(
|
| 384 |
umap_df, x='UMAP1', y='UMAP2', color='Score de Risque (%)', symbol='Status',
|
| 385 |
title='Projection UMAP avec Scores de Risque IRC (%)',
|
|
|
|
| 392 |
|
| 393 |
# Analyse SHAP
|
| 394 |
elif page == "Analyse SHAP":
|
| 395 |
+
st.header("Analyse SHAP des Biomarqueurs")
|
| 396 |
if 'model' not in st.session_state or 'data_dict' not in st.session_state:
|
| 397 |
+
st.warning("Chargez le modèle et les données d'abord.")
|
| 398 |
else:
|
| 399 |
model = st.session_state['model']
|
| 400 |
data_dict = st.session_state['data_dict']
|
| 401 |
input_dims = st.session_state['input_dims']
|
| 402 |
combined_data = pd.concat([data_dict[omic] for omic in data_dict], axis=1)
|
| 403 |
feature_names = sum([data_dict[omic].columns.tolist() for omic in data_dict], [])
|
| 404 |
+
|
| 405 |
+
if st.button("Lancer l’Analyse SHAP"):
|
| 406 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 407 |
X_concat = combined_data.values
|
| 408 |
n_samples = min(100, X_concat.shape[0])
|
| 409 |
X_subset = X_concat[:n_samples]
|
| 410 |
+
|
| 411 |
class VAEWrapper:
|
| 412 |
def __init__(self, model, device):
|
| 413 |
self.model = model
|
|
|
|
| 421 |
with torch.no_grad():
|
| 422 |
_, z, _, _, _ = self.model(X_tensors)
|
| 423 |
return torch.norm(z, dim=1).cpu().numpy()
|
| 424 |
+
|
| 425 |
explainer = shap.KernelExplainer(VAEWrapper(model, device).predict, X_subset)
|
| 426 |
shap_values = explainer.shap_values(X_subset, nsamples=100)
|
| 427 |
shap_importance = pd.DataFrame({
|
| 428 |
'Biomarqueur': feature_names[:len(np.mean(np.abs(shap_values[0]), axis=0))],
|
| 429 |
'Importance SHAP': np.mean(np.abs(shap_values[0]), axis=0)
|
| 430 |
}).sort_values('Importance SHAP', ascending=False)
|
| 431 |
+
|
| 432 |
+
# Afficher le graphique directement sans sauvegarde
|
| 433 |
fig, ax = plt.subplots(figsize=(12, 8))
|
| 434 |
sns.barplot(data=shap_importance.head(20), x='Importance SHAP', y='Biomarqueur', palette='Set2')
|
| 435 |
plt.title('Top 20 Biomarqueurs par Importance SHAP')
|
| 436 |
st.pyplot(fig)
|
| 437 |
plt.close()
|
| 438 |
+
|
| 439 |
+
st.subheader("Biomarqueurs Clés")
|
| 440 |
st.dataframe(shap_importance.head(20))
|
| 441 |
|
| 442 |
+
# Conseiller Médical
|
| 443 |
+
elif page == "Conseiller Médical":
|
| 444 |
+
st.header("Conseiller Médical Virtuel")
|
| 445 |
+
st.markdown("Interagissez avec notre assistant basé sur BioBERT pour obtenir des recommandations personnalisées.")
|
| 446 |
+
|
| 447 |
+
if 'umap_df' not in st.session_state:
|
| 448 |
+
st.warning("Calculez les scores de risque d'abord.")
|
| 449 |
+
else:
|
| 450 |
+
umap_df = st.session_state['umap_df']
|
| 451 |
+
st.subheader("Informations du Patient")
|
| 452 |
+
with st.form("patient_form"):
|
| 453 |
+
patient_id = st.text_input("ID du Patient", help="Ex. Patient_001")
|
| 454 |
+
age = st.number_input("Âge", min_value=18, max_value=120, value=30)
|
| 455 |
+
sex = st.selectbox("Sexe", ["Homme", "Femme"])
|
| 456 |
+
family_history_irc = st.checkbox("Antécédents familiaux d’IRC")
|
| 457 |
+
family_history_diabetes = st.checkbox("Antécédents familiaux de diabète")
|
| 458 |
+
family_history_hypertension = st.checkbox("Antécédents familiaux d’hypertension")
|
| 459 |
+
diabetes = st.checkbox("Diabète actuel")
|
| 460 |
+
hypertension = st.checkbox("Hypertension actuelle")
|
| 461 |
+
submitted = st.form_submit_button("Soumettre")
|
| 462 |
+
|
| 463 |
+
if submitted and patient_id in umap_df.index:
|
| 464 |
+
patient_data = {
|
| 465 |
+
'risk_score': umap_df.loc[patient_id, 'Score de Risque (%)'],
|
| 466 |
+
'age': age,
|
| 467 |
+
'sex': sex,
|
| 468 |
+
'family_history_irc': family_history_irc,
|
| 469 |
+
'family_history_diabetes': family_history_diabetes,
|
| 470 |
+
'family_history_hypertension': family_history_hypertension,
|
| 471 |
+
'diabetes': diabetes,
|
| 472 |
+
'hypertension': hypertension
|
| 473 |
+
}
|
| 474 |
+
advice = generate_recommendation_with_biobert(patient_data, patient_id, irc_biomarkers, biobert_tokenizer, biobert_model)
|
| 475 |
+
st.markdown(f"<div class='bot-message'>{advice}</div>", unsafe_allow_html=True)
|
| 476 |
+
elif submitted:
|
| 477 |
+
st.error("ID du patient invalide.")
|
| 478 |
+
|
| 479 |
+
# Résumé
|
| 480 |
+
elif page == "Résumé":
|
| 481 |
st.header("Résumé des Résultats")
|
| 482 |
if 'umap_df' not in st.session_state:
|
| 483 |
+
st.warning("Complétez les étapes précédentes.")
|
| 484 |
else:
|
| 485 |
st.subheader("Scores de Risque")
|
| 486 |
st.dataframe(st.session_state['umap_df'][['Cluster', 'Status', 'Score de Risque (%)']])
|
|
|
|
| 487 |
csv = st.session_state['umap_df'].to_csv()
|
| 488 |
st.download_button(
|
| 489 |
label="Télécharger les Résultats (CSV)",
|
| 490 |
data=csv,
|
| 491 |
+
file_name=f"resultats_irc_{datetime.now().strftime('%Y%m%d')}.csv",
|
| 492 |
mime="text/csv"
|
| 493 |
)
|