Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from sklearn.preprocessing import LabelEncoder | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import umap | |
| import shap | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| from sklearn.decomposition import PCA | |
| from sklearn.cluster import KMeans | |
| from sklearn.ensemble import RandomForestClassifier | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification | |
| from reportlab.lib.pagesizes import A4 | |
| from reportlab.pdfgen import canvas | |
| from reportlab.lib import colors | |
| import io | |
| import logging | |
| from datetime import datetime | |
| import warnings | |
| import os | |
# Silence library warnings globally (torch, umap, shap, ...) so they do not clutter the UI.
warnings.filterwarnings('ignore')
# Root logger used by the whole application.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Location of the trained OmicsVAE weights inside the Docker image.
model_path = "src/omics_vae_best_hyperparams.pth"

# Chronic kidney disease (IRC) biomarkers highlighted throughout the app.
# Order is significant: callers slice this list (e.g. the first three entries).
irc_biomarkers = [
    # apparently gene_rsID variant identifiers
    "UMOD_rs12917707",
    "APOL1_rs73885319",
    "MYH9_rs4821480",
    "HAVCR1",
    "TGFB1",
    "IL6",
    "HNF4A",
    "NPHS1",
    "AQP2",
    "B2MG",
    "Albumin",
    "NGAL",
    "Cystatin_C",
    "Uromodulin",
    "KLOTHO",
    "Kynurenine",
    "Indoxyl_Sulfate",
    "Creatinine",
    "5-MTP",
]
# OmicsVAE model definition.
class OmicsVAE(nn.Module):
    """Multi-omics variational autoencoder with a transformer encoder and a classifier head.

    Each omics layer is linearly projected to ``hidden_dim`` and treated as one
    token (plus a sinusoidal positional encoding). A transformer encoder mixes
    the tokens; the flattened result is mapped to a Gaussian latent space
    (``mu``, ``log_var``). The latent code ``z`` is decoded back to every layer
    and classified.

    Args:
        input_dims: number of features of each omics layer, in the order the
            tensors are passed to :meth:`forward`.
        hidden_dim: token width; must be divisible by ``num_heads``.
        latent_dim: size of the latent space.
        num_heads, num_layers, dropout: transformer encoder hyper-parameters.
        num_classes: number of outputs of the classification head.
    """

    def __init__(self, input_dims, hidden_dim=256, latent_dim=64, num_heads=8, num_layers=3, dropout=0.4, num_classes=2):
        super().__init__()
        self.input_dims = input_dims
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.num_omics = len(input_dims)
        self.num_classes = num_classes
        # One projection per omics layer: each layer becomes one token of width hidden_dim.
        self.input_projections = nn.ModuleList([nn.Linear(dim, hidden_dim) for dim in input_dims])
        # Non-persistent buffer: it follows model.to(device) / .half(), but it is
        # not part of the state_dict. Checkpoints saved when this was a plain
        # attribute therefore still load with strict=True.
        self.register_buffer(
            'positional_encoding',
            self.create_positional_encoding(hidden_dim, max_len=self.num_omics),
            persistent=False,
        )
        transformer_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim, nhead=num_heads,
            dim_feedforward=hidden_dim * 4, dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers)
        flat_dim = hidden_dim * self.num_omics  # all tokens concatenated
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_log_var = nn.Linear(flat_dim, latent_dim)
        self.fc_decode = nn.Linear(latent_dim, flat_dim)
        self.decoder_projections = nn.ModuleList([nn.Linear(hidden_dim, dim) for dim in input_dims])
        self.fc_classify = nn.Linear(latent_dim, num_classes)

    def create_positional_encoding(self, d_model, max_len):
        """Return a sinusoidal positional encoding of shape (1, max_len, d_model).

        Even columns hold sines and odd columns cosines. Odd ``d_model`` is
        supported: the cosine half simply has one column fewer.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # 1::2 has d_model // 2 columns, which is one fewer than div_term when d_model is odd.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe.unsqueeze(0)

    def reparameterize(self, mu, log_var):
        """Sample z ~ N(mu, exp(log_var)) in training mode; return ``mu`` in eval mode.

        Sampling at inference time made every prediction (and any explanation
        built on it) random. In eval mode the distribution mean is used instead.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * log_var)
        return mu + torch.randn_like(std) * std

    def forward(self, x_list):
        """Encode, sample, decode and classify a batch.

        Args:
            x_list: one tensor per omics layer, in ``input_dims`` order. Each is
                fed to ``input_projections[i]``, so presumably it has shape
                (batch, input_dims[i]). TODO confirm.

        Returns:
            Tuple ``(outputs, z, mu, log_var, class_logits)``, where ``outputs``
            holds one reconstruction per omics layer.
        """
        tokens = []
        for i, x in enumerate(x_list):
            # (batch, hidden) + (1, hidden) positional row for token i (broadcast over batch).
            tokens.append(self.input_projections[i](x) + self.positional_encoding[:, i, :].to(x.device))
        encoded = torch.stack(tokens, dim=1)  # (batch, num_omics, hidden)
        transformer_out = self.transformer_encoder(encoded).flatten(1)
        mu = self.fc_mu(transformer_out)
        log_var = self.fc_log_var(transformer_out)
        z = self.reparameterize(mu, log_var)
        decoded = self.fc_decode(z).view(z.size(0), self.num_omics, self.hidden_dim)
        outputs = [projection(decoded[:, i, :]) for i, projection in enumerate(self.decoder_projections)]
        class_logits = self.fc_classify(z)
        return outputs, z, mu, log_var, class_logits
# Advisor: turn a patient profile into textual care recommendations (BioBERT).
def _format_biomarker_value(value):
    """Return ``value`` with two decimals when it is numeric, otherwise its ``str()``."""
    try:
        return f"{float(value):.2f}"
    except (TypeError, ValueError):
        return str(value)


def _lookup_biomarker(data_dict, patient_id, biomarker):
    """Return the first finite numeric value of ``biomarker`` for ``patient_id`` across the omics tables, or None."""
    for omic_df in data_dict.values():
        if biomarker not in omic_df.columns or patient_id not in omic_df.index:
            continue
        try:
            value = float(omic_df.loc[patient_id, biomarker])
        except (TypeError, ValueError):
            continue
        if not np.isnan(value):
            return value
    return None


def generate_recommendation_with_biobert(patient_data, patient_id, biomarkers, tokenizer, model, data_dict):
    """Build markdown care recommendations for one patient.

    The model receives a short French summary of the patient: demographics,
    risk score, history, comorbidities, and the values of the first three
    entries of ``biomarkers``. It classifies that text, and the argmax class
    (0, 1 or 2) selects a low/moderate/high-risk advice template. The template
    is then adjusted for diabetes, hypertension, family history of IRC, and a
    creatinine value above 1.5.

    Args:
        patient_data: dict with 'sex', 'age', 'risk_score' (percent) and the
            boolean keys 'family_history_irc', 'family_history_diabetes',
            'family_history_hypertension', 'diabetes', 'hypertension'.
        patient_id: row label looked up in the tables of ``data_dict``.
        biomarkers: candidate biomarker column names. Only the first three are
            quoted in the text; 'Creatinine', if listed, drives the alert.
        tokenizer, model: Hugging Face-style tokenizer and sequence classifier
            whose output exposes ``.logits``. A class index outside 0-2 raises
            ``KeyError``.
        data_dict: mapping omic name -> DataFrame indexed by patient id.

    Returns:
        The recommendation as a markdown string.
    """
    # Values quoted in the prompt: the first three entries of `biomarkers`,
    # taken from every omics table that contains them.
    biomarker_values = [
        f"{biomarker}: {_format_biomarker_value(data_dict[omic].loc[patient_id, biomarker])}"
        for omic in data_dict
        for biomarker in biomarkers[:3]
        if biomarker in data_dict[omic].columns
    ]
    # Structured input text for the classifier.
    text = f"""
Patient: {patient_id}, {patient_data['sex']}, {patient_data['age']} ans.
Score de risque IRC: {patient_data['risk_score']:.1f}%.
Antécédents familiaux: IRC ({'Oui' if patient_data['family_history_irc'] else 'Non'}),
Diabète ({'Oui' if patient_data['family_history_diabetes'] else 'Non'}),
Hypertension ({'Oui' if patient_data['family_history_hypertension'] else 'Non'}).
Comorbidités: Diabète ({'Oui' if patient_data['diabetes'] else 'Non'}),
Hypertension ({'Oui' if patient_data['hypertension'] else 'Non'}).
Biomarqueurs: {', '.join(biomarker_values)}.
"""
    # Tokenize and classify.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    # Advice templates indexed by predicted class (0 = low, 1 = moderate, 2 = high risk).
    base_advice = {
        0: {
            'title': f"Patient {patient_id} : Risque Faible ({patient_data['risk_score']:.1f}%)",
            'state': "Faible probabilité de progression vers l'IRC.",
            'lifestyle': "Adopter une alimentation faible en sel (<2g/jour), riche en fruits et légumes. Activité physique modérée (30 min/jour, 5 jours/semaine).",
            'monitoring': f"Bilan rénal annuel, surveiller {', '.join(biomarker_values[:2])}. Hydratation adéquate (1,5-2L/jour).",
            'therapy': "Aucune thérapie spécifique requise."
        },
        1: {
            'title': f"Patient {patient_id} : Risque Modéré ({patient_data['risk_score']:.1f}%)",
            'state': "Risque intermédiaire de progression vers l'IRC.",
            'lifestyle': "Régime strict : réduire protéines animales, sodium (<1,5g/jour). Contrôler pression artérielle (<130/80 mmHg).",
            'monitoring': f"Consultation néphrologue trimestrielle, évaluer {', '.join(biomarker_values[:2])}. Éviter AINS sauf prescription.",
            'therapy': "Considérer un contrôle glycémique strict si diabétique."
        },
        2: {
            'title': f"Patient {patient_id} : Risque Élevé ({patient_data['risk_score']:.1f}%)",
            'state': "Forte probabilité de progression vers l'IRC.",
            'lifestyle': "Régime rénal strict : faible en potassium, phosphore, sodium.",
            'monitoring': f"Consultation néphrologue urgente, surveillance hebdomadaire (créatinine, DFG). Analyser {', '.join(biomarker_values[:3])}. ",
            'therapy': "Envisager inhibiteurs de l’ECA ou diurétiques, après évaluation."
        }
    }
    advice = base_advice[prediction]
    if patient_data['diabetes']:
        advice['therapy'] += " Contrôle strict de la glycémie (HbA1c <7%)."
    if patient_data['hypertension']:
        advice['therapy'] += " Médicaments antihypertenseurs (ex. : losartan) sous supervision."
    if patient_data['family_history_irc']:
        advice['monitoring'] += " Surveillance accrue des antécédents familiaux."
    # Read creatinine from the data itself. It is usually not among the three
    # biomarkers quoted above, so scanning `biomarker_values` would never see it.
    if 'Creatinine' in biomarkers:
        creatinine = _lookup_biomarker(data_dict, patient_id, 'Creatinine')
        if creatinine is not None and creatinine > 1.5:
            advice['monitoring'] += f" Attention : Créatinine élevée ({creatinine:.2f} mg/dL), suivi rapproché recommandé."
    formatted_advice = f"""
**{advice['title']}**
- **État** : {advice['state']}
- **Mode de vie** : {advice['lifestyle']}
- **Suivi** : {advice['monitoring']}
- **Thérapie** : {advice['therapy']}
**Avertissement** : Ces recommandations doivent être validées par un médecin.
"""
    return formatted_advice
# In-memory PDF report generation.
def generate_pdf_report(patient_id, patient_data, advice, umap_df, shap_importance=None):
    """Render the patient's report as a PDF and return it as a rewound ``io.BytesIO``.

    Page 1 holds:
    - the patient summary;
    - the recommendations, with markdown ``**`` removed, wrapped to the page
      width and continued on new pages when needed;
    - when ``umap_df`` has a ``'Score de Risque (%)'`` column, a UMAP scatter of
      the risk scores placed below the text.

    A separate page with the top-10 SHAP bar chart follows when
    ``shap_importance`` is given. If the static export of the UMAP chart fails
    (it needs plotly's optional image engine), the chart is logged and omitted
    instead of aborting the report.

    Args:
        patient_id: identifier printed in the title.
        patient_data: dict with 'age', 'sex', 'risk_score' and the boolean keys
            'family_history_irc', 'family_history_diabetes',
            'family_history_hypertension', 'diabetes', 'hypertension'.
        advice: markdown recommendation text.
        umap_df: DataFrame with 'UMAP1', 'UMAP2', 'Status' and the risk
            column, or None.
        shap_importance: DataFrame with 'Biomarqueur' and 'Importance SHAP'
            columns, or None.
    """
    # canvas.drawImage takes a file name or an ImageReader, so in-memory PNGs
    # are wrapped in ImageReader rather than passed as a bare BytesIO.
    from reportlab.lib.utils import ImageReader, simpleSplit

    page_height = A4[1]
    margin = 50
    text_width = A4[0] - 2 * margin
    chart_width, chart_height = 500, 300
    body_font, body_size, leading = "Helvetica", 12, 14
    risk_column = 'Score de Risque (%)'

    buffer = io.BytesIO()
    c = canvas.Canvas(buffer, pagesize=A4)

    def ensure_room(cursor, needed):
        """Return a y with ``needed`` points free above the bottom margin, starting a new page if necessary."""
        if cursor - needed >= margin:
            return cursor
        c.showPage()
        c.setFont(body_font, body_size)
        return page_height - margin

    # Title and date.
    c.setFont("Helvetica-Bold", 16)
    c.drawString(50, 800, f"Rapport IRC - Patient {patient_id}")
    c.setFont("Helvetica", 12)
    c.drawString(50, 770, f"Date: {datetime.now().strftime('%Y-%m-%d')}")
    # Patient information.
    c.setFont("Helvetica-Bold", 14)
    c.drawString(50, 740, "Informations du Patient")
    c.setFont("Helvetica", 12)
    y = 720
    c.drawString(50, y, f"Âge: {patient_data['age']} ans")
    c.drawString(50, y - 20, f"Sexe: {patient_data['sex']}")
    c.drawString(50, y - 40, f"Score de risque: {patient_data['risk_score']:.1f}%")
    c.drawString(50, y - 60, f"Antécédents familiaux: IRC ({'Oui' if patient_data['family_history_irc'] else 'Non'}), "
                             f"Diabète ({'Oui' if patient_data['family_history_diabetes'] else 'Non'}), "
                             f"Hypertension ({'Oui' if patient_data['family_history_hypertension'] else 'Non'})")
    c.drawString(50, y - 80, f"Comorbidités: Diabète ({'Oui' if patient_data['diabetes'] else 'Non'}), "
                             f"Hypertension ({'Oui' if patient_data['hypertension'] else 'Non'})")
    # Recommendations: drop markdown bold markers and wrap long lines to the page width.
    c.setFont("Helvetica-Bold", 14)
    c.drawString(50, y - 110, "Recommandations")
    c.setFont(body_font, body_size)
    cursor = y - 130
    for raw_line in advice.split('\n'):
        wrapped = simpleSplit(raw_line.replace('**', ''), body_font, body_size, text_width) or ['']
        for line in wrapped:
            cursor = ensure_room(cursor, 0)
            c.drawString(margin, cursor, line)
            cursor -= leading
    # UMAP chart, placed below the text (a new page if it does not fit).
    if umap_df is not None and risk_column in umap_df.columns:
        try:
            fig = px.scatter(
                umap_df, x='UMAP1', y='UMAP2', color=risk_column, symbol='Status',
                title='Projection UMAP',
                color_continuous_scale='RdYlGn_r',
                template='plotly_dark'
            )
            fig.update_traces(marker=dict(size=12))
            umap_png = io.BytesIO()
            fig.write_image(umap_png, format='png', width=chart_width, height=chart_height)
            umap_png.seek(0)
        except Exception:  # report boundary: a failed chart export must not abort the PDF
            logging.exception("UMAP chart export failed; chart omitted from the PDF report")
        else:
            cursor = ensure_room(cursor - 10, chart_height)
            c.drawImage(ImageReader(umap_png), margin, cursor - chart_height, width=chart_width, height=chart_height)
    # SHAP chart on its own page.
    if shap_importance is not None:
        c.showPage()
        c.setFont("Helvetica-Bold", 14)
        c.drawString(50, 800, "Analyse SHAP")
        fig, ax = plt.subplots(figsize=(6, 4))
        try:
            sns.barplot(data=shap_importance.head(10), x='Importance SHAP', y='Biomarqueur', palette='Set2', ax=ax)
            ax.set_title('Top 10 Biomarqueurs')
            shap_png = io.BytesIO()
            fig.savefig(shap_png, format='png', bbox_inches='tight')
        finally:
            plt.close(fig)
        shap_png.seek(0)
        # drawImage anchors the bottom-left corner; keep the top edge (y + height)
        # under the heading at 800 pt. A4 is only ~842 pt tall.
        c.drawImage(ImageReader(shap_png), 50, 800 - 20 - chart_height, width=chart_width, height=chart_height)
    c.save()
    buffer.seek(0)
    return buffer
# --- Streamlit page setup ---------------------------------------------------
st.set_page_config(layout="wide", page_title="Analyse Multi-Omique IRC")

# Dark theme and chat-bubble classes, injected as raw CSS.
_CUSTOM_CSS = """
<style>
.main {background-color: #1e1e1e; color: #ffffff; font-family: 'Roboto', sans-serif;}
.stButton>button {
background-color: #4CAF50; color: white; border-radius: 12px; padding: 12px 24px;
transition: all 0.3s ease; border: none; font-weight: bold; font-size: 16px;
}
.stButton>button:hover {background-color: #45a049; transform: scale(1.05);}
.stSelectbox, .stTextInput, .stNumberInput, .stCheckbox, .stFileUploader {
background-color: #2e2e2e; color: #ffffff; border-radius: 8px; padding: 8px;
}
.sidebar .sidebar-content {background-color: #2e2e2e; color: #ffffff;}
.stMarkdown h1, h2, h3 {color: #4CAF50; font-weight: bold;}
.chat-message {padding: 12px; border-radius: 10px; margin-bottom: 12px; max-width: 80%;}
.bot-message {background-color: #333333; color: white; margin-right: 20%;}
.stPlotlyChart {border-radius: 10px; overflow: hidden;}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

st.title("Plateforme d’Analyse Multi-Omique pour l’IRC")
st.markdown("**Ngoue David, M2 Intelligence Artificielle et Big Data** | Hôpital Général de Yaoundé")

# --- Sidebar navigation: one entry per workflow step ------------------------
_NAVIGATION_STEPS = [
    "Présentation",
    "Chargement des Données",
    "Analyse Exploratoire",
    "Clustering",
    "Scores de Risque",
    "Analyse SHAP",
    "Conseiller Médical",
    "Résumé",
]
st.sidebar.header("Navigation")
page = st.sidebar.radio("Étapes", _NAVIGATION_STEPS)
# BioBERT loading, cached across Streamlit reruns.
@st.cache_resource
def load_biobert():
    """Load the BioBERT tokenizer and a 3-label sequence-classification model.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` keeps one tokenizer/model pair per server process
    instead of reloading the weights on each rerun.

    Returns:
        ``(tokenizer, model)``, with the model in eval mode.

    NOTE(review): ``dmis-lab/biobert-v1.1`` appears to be a pretrained,
    not fine-tuned, checkpoint. If so, ``transformers`` initialises the 3-label
    classification head randomly and its predictions carry no clinical meaning
    until a fine-tuned checkpoint is used. Confirm.
    """
    tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
    model = AutoModelForSequenceClassification.from_pretrained("dmis-lab/biobert-v1.1", num_labels=3)
    model.eval()
    return tokenizer, model


biobert_tokenizer, biobert_model = load_biobert()
# OmicsVAE loading.
def load_model(input_dims, num_classes):
    """Build an OmicsVAE, load the trained weights from ``model_path`` and return it in eval mode.

    Args:
        input_dims: features per omics layer, in the order used for training.
        num_classes: number of classes of the classification head.

    Raises:
        FileNotFoundError: if ``model_path`` does not exist.
    """
    model = OmicsVAE(input_dims=input_dims, num_classes=num_classes)
    try:
        # map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only host.
        state_dict = torch.load(model_path, map_location='cpu')
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Modèle {model_path} non trouvé dans l’image Docker.") from err
    model.load_state_dict(state_dict)
    model.eval()
    return model
# --- Page: project presentation ---------------------------------------------
if page == "Présentation":
    _presentation_md = """
**Projet : Thérapie Personnalisée de l’IRC**
Réalisé par Ngoue David, ce projet révolutionne la prise en charge de l’IRC à l’Hôpital Général de Yaoundé via une approche multi-omique. Une architecture de transformers hybrides (OmicsVAE) permet :
- **Prédiction** des risques de progression de l’IRC.
- **Thérapies sur mesure** basées sur les profils moléculaires.
- **Suivi intelligent** avec un conseiller BioBERT.
**Impact** : Médecine de précision pour le Cameroun, meilleurs résultats, coûts réduits.
**Explorez** via le menu latéral.
"""
    st.header("Contexte et Innovation")
    st.markdown(_presentation_md)
| # Chargement des Données | |
| elif page == "Chargement des Données": | |
| st.header("Chargement des Données") | |
| st.markdown("Uploadez les fichiers omiques (CSV) pour l’analyse.") | |
| uploaded_files = {} | |
| omics_types = ['génomique', 'transcriptomique', 'protéomique', 'métabolomique'] | |
| for omic in omics_types: | |
| uploaded_file = st.file_uploader(f"Données {omic} (CSV)", type="csv", key=omic) | |
| if uploaded_file: | |
| uploaded_files[omic] = uploaded_file | |
| if st.button("Initialiser l’Analyse") and len(uploaded_files) == len(omics_types): | |
| try: | |
| data_dict = {} | |
| for omic, file in uploaded_files.items(): | |
| df = pd.read_csv(file, index_col='Patient_ID') | |
| if 'Status' not in df.columns: | |
| raise ValueError(f"Le fichier {omic} doit contenir une colonne 'Status'.") | |
| data_dict[omic] = df.drop(columns=['Status']) | |
| labels = pd.read_csv(uploaded_files['génomique'], index_col='Patient_ID')['Status'] | |
| le = LabelEncoder() | |
| encoded_labels = pd.Series(le.fit_transform(labels), index=labels.index, name='Status') | |
| common_samples = data_dict['génomique'].index | |
| for omic in data_dict: | |
| data_dict[omic] = data_dict[omic].loc[common_samples] | |
| labels = encoded_labels.loc[common_samples] | |
| input_dims = [data_dict[omic].shape[1] for omic in data_dict] | |
| model = load_model(input_dims, len(np.unique(encoded_labels))) | |
| st.session_state['data_dict'] = data_dict | |
| st.session_state['labels'] = labels | |
| st.session_state['label_encoder'] = le | |
| st.session_state['common_samples'] = common_samples | |
| st.session_state['model'] = model | |
| st.session_state['input_dims'] = input_dims | |
| st.success("Données et modèle chargés avec succès !") | |
| except Exception as e: | |
| st.error(f"Erreur : {str(e)}") | |
| # Analyse Exploratoire | |
| elif page == "Analyse Exploratoire": | |
| st.header("Analyse Exploratoire") | |
| if 'data_dict' not in st.session_state: | |
| st.warning("Chargez les données d'abord.") | |
| else: | |
| data_dict = st.session_state['data_dict'] | |
| labels = st.session_state['labels'] | |
| omic = st.selectbox("Type omique", list(data_dict.keys())) | |
| biomarkers = [col for col in data_dict[omic].columns if col in irc_biomarkers] | |
| if biomarkers: | |
| st.subheader(f"Matrice de Corrélation ({omic})") | |
| corr_matrix = data_dict[omic][biomarkers].corr() | |
| fig = go.Figure(data=go.Heatmap( | |
| z=corr_matrix.values, | |
| x=corr_matrix.columns, | |
| y=corr_matrix.columns, | |
| colorscale='Magma', | |
| zmin=-1, zmax=1, | |
| text=np.round(corr_matrix.values, 2), | |
| texttemplate="%{text}", | |
| hovertemplate='Biomarqueur 1: %{x}<br>Biomarqueur 2: %{y}<br>Corrélation: %{z:.2f}<extra></extra>' | |
| )) | |
| fig.update_layout(title=f'Matrice de Corrélation ({omic})', template='plotly_dark') | |
| st.plotly_chart(fig, use_container_width=True) | |
| st.subheader(f"Projection PCA 3D ({omic})") | |
| pca = PCA(n_components=3) | |
| pca_result = pca.fit_transform(data_dict[omic]) | |
| pca_df = pd.DataFrame(pca_result, columns=['PC1', 'PC2', 'PC3'], index=data_dict[omic].index) | |
| pca_df['Status'] = labels | |
| explained_variance = pca.explained_variance_ratio_ | |
| fig = px.scatter_3d( | |
| pca_df, x='PC1', y='PC2', z='PC3', color='Status', | |
| title=f'Projection PCA 3D ({omic}) - Variance : {explained_variance.sum():.2%}', | |
| labels={'PC1': f'PC1 ({explained_variance[0]:.2%})', 'PC2': f'PC2 ({explained_variance[1]:.2%})', 'PC3': f'PC3 ({explained_variance[2]:.2%})'}, | |
| color_continuous_scale='Viridis', | |
| opacity=0.7, | |
| template='plotly_dark' | |
| ) | |
| fig.update_traces(marker=dict(size=5)) | |
| st.plotly_chart(fig, use_container_width=True) | |
| # Clustering | |
| elif page == "Clustering": | |
| st.header("Clustering") | |
| if 'data_dict' not in st.session_state: | |
| st.warning("Chargez les données d'abord.") | |
| else: | |
| n_clusters = st.slider("Nombre de clusters", 2, 10, 5) | |
| if st.button("Lancer le Clustering"): | |
| data_dict = st.session_state['data_dict'] | |
| labels = st.session_state['labels'] | |
| common_samples = st.session_state['common_samples'] | |
| combined_data = pd.concat([data_dict[omic] for omic in data_dict], axis=1) | |
| kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10) | |
| umap_reducer = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42) | |
| umap_embedding = umap_reducer.fit_transform(combined_data) | |
| umap_df = pd.DataFrame(umap_embedding, columns=['UMAP1', 'UMAP2'], index=common_samples) | |
| umap_df['Cluster'] = kmeans.fit_predict(umap_embedding) | |
| umap_df['Status'] = st.session_state['label_encoder'].inverse_transform(labels) | |
| st.session_state['umap_df'] = umap_df | |
| st.session_state['kmeans'] = kmeans | |
| st.session_state['umap_embedding'] = umap_embedding | |
| fig = px.scatter( | |
| umap_df, x='UMAP1', y='UMAP2', color='Cluster', symbol='Status', | |
| title='Projection UMAP avec Clusters KMeans', | |
| color_continuous_scale='Viridis', | |
| labels={'Cluster': 'Cluster', 'Status': 'Status'}, | |
| template='plotly_dark' | |
| ) | |
| fig.update_traces(marker=dict(size=10)) | |
| st.plotly_chart(fig, use_container_width=True) | |
| # Scores de Risque | |
| elif page == "Scores de Risque": | |
| st.header("Scores de Risque") | |
| if 'umap_df' not in st.session_state: | |
| st.warning("Effectuez le clustering d'abord.") | |
| else: | |
| umap_df = st.session_state['umap_df'] | |
| kmeans = st.session_state['kmeans'] | |
| umap_embedding = st.session_state['umap_embedding'] | |
| data_dict = st.session_state['data_dict'] | |
| labels = st.session_state['labels'] | |
| label_encoder = st.session_state['label_encoder'] | |
| if st.button("Calculer les Scores"): | |
| cluster_centers = kmeans.cluster_centers_ | |
| distances = np.zeros(len(umap_embedding)) | |
| for i, emb in enumerate(umap_embedding): | |
| distances[i] = np.linalg.norm(emb - cluster_centers[umap_df['Cluster'].iloc[i]]) | |
| base_risk = (distances - distances.min()) / (distances.max() - distances.min()) | |
| biomarker_weights = {omic: {col: 2.0 if col in irc_biomarkers else 1.0 for col in data_dict[omic].columns} for omic in data_dict} | |
| weighted_risk = base_risk.copy() | |
| for omic in data_dict: | |
| for col, weight in biomarker_weights[omic].items(): | |
| deviation = np.abs(data_dict[omic][col].values - data_dict[omic][col].mean()) | |
| weighted_risk += weight * deviation | |
| weighted_risk = (weighted_risk - weighted_risk.min()) / (weighted_risk.max() - weighted_risk.min()) | |
| classifier = RandomForestClassifier(n_estimators=100, random_state=42, class_weight='balanced') | |
| combined_data = pd.concat([data_dict[omic] for omic in data_dict], axis=1) | |
| classifier.fit(combined_data, labels) | |
| irc_class = label_encoder.transform(['IRC'])[0] if 'IRC' in label_encoder.classes_ else np.argmax(np.bincount(labels)) | |
| class_probs = classifier.predict_proba(combined_data)[:, irc_class] | |
| final_risk = weighted_risk * 0.5 + class_probs * 0.5 | |
| final_risk = final_risk * 100 | |
| umap_df['Score de Risque (%)'] = final_risk | |
| st.session_state['umap_df'] = umap_df | |
| fig = px.scatter( | |
| umap_df, x='UMAP1', y='UMAP2', color='Score de Risque (%)', symbol='Status', | |
| title='Projection UMAP avec Scores de Risque', | |
| color_continuous_scale='RdYlGn_r', | |
| labels={'Score de Risque (%)': 'Score de Risque (%)', 'Status': 'Status'}, | |
| template='plotly_dark' | |
| ) | |
| fig.update_traces(marker=dict(size=12)) | |
| st.plotly_chart(fig, use_container_width=True) | |
| # Analyse SHAP | |
| elif page == "Analyse SHAP": | |
| st.header("Analyse SHAP") | |
| if 'model' not in st.session_state: | |
| st.warning("Chargez les données d'abord.") | |
| else: | |
| model = st.session_state['model'] | |
| data_dict = st.session_state['data_dict'] | |
| input_dims = st.session_state['input_dims'] | |
| combined_data = pd.concat([data_dict[omic] for omic in data_dict], axis=1) | |
| feature_names = sum([data_dict[omic].columns.tolist() for omic in data_dict], []) | |
| if st.button("Lancer SHAP"): | |
| device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
| X_concat = combined_data.values | |
| n_samples = min(100, X_concat.shape[0]) | |
| X_subset = X_concat[:n_samples] | |
| class VAEWrapper: | |
| def __init__(self, model, device): | |
| self.model = model | |
| self.device = device | |
| def predict(self, X): | |
| X_tensors = [] | |
| start = 0 | |
| for dim in input_dims: | |
| X_tensors.append(torch.tensor(X[:, start:start + dim], dtype=torch.float32).to(self.device)) | |
| start += dim | |
| with torch.no_grad(): | |
| _, z, _, _, _ = self.model(X_tensors) | |
| return torch.norm(z, dim=1).cpu().numpy() | |
| explainer = shap.KernelExplainer(VAEWrapper(model, device).predict, X_subset) | |
| shap_values = explainer.shap_values(X_subset, nsamples=100) | |
| shap_importance = pd.DataFrame({ | |
| 'Biomarqueur': feature_names[:len(np.mean(np.abs(shap_values[0]), axis=0))], | |
| 'Importance SHAP': np.mean(np.abs(shap_values[0]), axis=0) | |
| }).sort_values('Importance SHAP', ascending=False) | |
| st.session_state['shap_importance'] = shap_importance | |
| fig, ax = plt.subplots(figsize=(6, 4)) | |
| sns.barplot(data=shap_importance.head(10), x='Importance SHAP', y='Biomarqueur', palette='Set2') | |
| plt.title('Top 10 Biomarqueurs') | |
| st.pyplot(fig) | |
| plt.close() | |
| st.subheader("Biomarqueurs Clés") | |
| st.dataframe(shap_importance.head(10)) | |
| # Conseiller Médical | |
| elif page == "Conseiller Médical": | |
| st.header("Conseiller Médical BioBERT") | |
| st.markdown("Entrez les informations du patient pour des recommandations personnalisées et un rapport.") | |
| if 'umap_df' not in st.session_state or 'data_dict' not in st.session_state: | |
| st.warning("Calculez les scores de risque et chargez les données d'abord.") | |
| else: | |
| umap_df = st.session_state['umap_df'] | |
| data_dict = st.session_state['data_dict'] | |
| with st.form("patient_form"): | |
| patient_id = st.text_input("ID du Patient", help="Ex. Patient_001") | |
| age = st.number_input("Âge", min_value=18, max_value=120, value=30) | |
| sex = st.selectbox("Sexe", ["Homme", "Femme"]) | |
| family_history_irc = st.checkbox("Antécédents familiaux d’IRC") | |
| family_history_diabetes = st.checkbox("Antécédents familiaux de diabète") | |
| family_history_hypertension = st.checkbox("Antécédents familiaux d’hypertension") | |
| diabetes = st.checkbox("Diabète actuel") | |
| hypertension = st.checkbox("Hypertension actuelle") | |
| submitted = st.form_submit_button("Obtenir Recommandations et Rapport") | |
| if submitted and patient_id in umap_df.index: | |
| patient_data = { | |
| 'risk_score': umap_df.loc[patient_id, 'Score de Risque (%)'], | |
| 'age': age, | |
| 'sex': sex, | |
| 'family_history_irc': family_history_irc, | |
| 'family_history_diabetes': family_history_diabetes, | |
| 'family_history_hypertension': family_history_hypertension, | |
| 'diabetes': diabetes, | |
| 'hypertension': hypertension | |
| } | |
| advice = generate_recommendation_with_biobert( | |
| patient_data, patient_id, irc_biomarkers, biobert_tokenizer, biobert_model, data_dict | |
| ) | |
| st.markdown(f"<div class='bot-message'>{advice}</div>", unsafe_allow_html=True) | |
| # Générer et proposer le rapport PDF | |
| shap_importance = st.session_state.get('shap_importance', None) | |
| pdf_buffer = generate_pdf_report(patient_id, patient_data, advice, umap_df, shap_importance) | |
| st.download_button( | |
| label="Télécharger Rapport PDF", | |
| data=pdf_buffer, | |
| file_name=f"rapport_irc_{patient_id}_{datetime.now().strftime('%Y%m%d')}.pdf", | |
| mime="application/pdf" | |
| ) | |
| elif submitted: | |
| st.error("ID du patient invalide.") | |
| # Résumé | |
| elif page == "Résumé": | |
| st.header("Résumé") | |
| if 'umap_df' not in st.session_state: | |
| st.warning("Complétez les étapes précédentes.") | |
| else: | |
| st.subheader("Scores de Risque") | |
| st.dataframe(st.session_state['umap_df'][['Cluster', 'Status', 'Score de Risque (%)']]) | |
| csv_buffer = io.StringIO() | |
| st.session_state['umap_df'].to_csv(csv_buffer, index=True) | |
| csv_buffer.seek(0) | |
| st.download_button( | |
| label="Télécharger Résultats (CSV)", | |
| data=csv_buffer.getvalue(), | |
| file_name=f"resultats_irc_{datetime.now().strftime('%Y%m%d')}.csv", | |
| mime="text/csv" | |
| ) |