"""Gradio app estimating biological age from a transcriptomic profile.

A supervised autoencoder (reconstruction head + age-regression head on the
latent space) is rebuilt in code, its trained weights are loaded from
``MODEL_PATH``, and the UI reports the predicted biological age, a latent
"aging score", and the gap with the user-supplied chronological age.
"""
import os

import gradio as gr
import keras
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib.figure import Figure
from tensorflow.keras import Model, layers

# Path to the trained model (.keras archive, or a weights file as fallback).
MODEL_PATH = "./models/aging_score_autoencoder_fixed.keras"

# Number of gene-expression features the network takes per sample.
INPUT_DIM = 18241

# Non-gene columns that may appear in an uploaded table; everything else is
# treated as a gene-expression feature.
META_COLS = ["sample_id", "subject_id", "tissue", "sex", "age", "death_time", "estimated_age"]

# Populated by load_resources().
model = None          # full network: outputs [reconstruction, age]
encoder_model = None  # input -> "latent" layer output, or None if unavailable
load_error = None     # human-readable reason weights could not be loaded


def build_model(input_dim=18241, latent_dim=32):
    """
    Rebuild the model architecture from scratch to avoid deserialization issues.
    Architecture matches the training notebook exactly.

    Encoder: Dense(512, relu) -> BatchNorm -> Dropout(0.3) -> Dense(128, relu)
    -> linear Dense(latent_dim) named "latent".
    Decoder: Dense(128, relu) -> Dense(512, relu) -> linear Dense(input_dim)
    named "reconstruction".
    Age head: linear Dense(1) named "age" on the latent vector.

    Returns a functional ``Model`` whose outputs are ``[reconstruction, age]``.
    """
    inputs = layers.Input(shape=(input_dim,))

    # Encoder
    x = layers.Dense(512, activation="relu")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation="relu")(x)
    latent = layers.Dense(latent_dim, name="latent")(x)

    # Decoder
    x = layers.Dense(128, activation="relu")(latent)
    x = layers.Dense(512, activation="relu")(x)
    reconstruction = layers.Dense(input_dim, name="reconstruction")(x)

    # Age prediction head
    age_pred = layers.Dense(1, name="age")(latent)

    return Model(inputs=inputs, outputs=[reconstruction, age_pred])


def load_resources():
    """Build the network, load trained weights into it and derive the encoder.

    Sets the module globals ``model``, ``encoder_model`` and ``load_error``.
    When no weights can be loaded, ``model`` keeps its random initialisation
    and ``load_error`` explains why, so the UI can warn the user.
    """
    global model, encoder_model, load_error
    load_error = None
    # Reset so a failed rebuild never leaves an encoder tied to a stale model.
    encoder_model = None

    model = build_model(input_dim=INPUT_DIM)

    if not os.path.exists(MODEL_PATH):
        load_error = f"Model file not found at {MODEL_PATH}. Model will run with random weights."
        print(load_error)
    else:
        print(f"Loading weights from {MODEL_PATH}...")
        try:
            # Preferred path: open the full .keras archive and copy its weights
            # into the freshly built model (variable order must match
            # build_model). NOTE: safe_mode=False allows arbitrary code stored
            # in the archive to run; only acceptable because MODEL_PATH is a
            # local, trusted file.
            saved_model = keras.saving.load_model(
                MODEL_PATH,
                compile=False,
                safe_mode=False,
            )
            model.set_weights(saved_model.get_weights())
            print("Weights loaded successfully from .keras file.")
        except Exception as archive_err:
            # Fallback: the file may hold only weights (e.g. HDF5).
            try:
                model.load_weights(MODEL_PATH)
                print("Weights loaded successfully.")
            except Exception as weights_err:
                # Report both failures; the first one is often the informative one.
                load_error = (
                    f"Could not load weights: {weights_err} "
                    f"(full-model load also failed: {archive_err})"
                )
                print(f"Warning: {load_error}")
                print("Model will run with random weights.")

    # Encoder sub-model exposing the latent layer for the aging score.
    try:
        encoder_model = Model(inputs=model.input, outputs=model.get_layer("latent").output)
        print("Encoder model created successfully.")
    except Exception as e:
        print(f"Warning: Could not create encoder model: {e}")


# Initial load (runs at import so the UI is ready as soon as it starts).
load_resources()


def _read_table(path):
    """Read an uploaded table: CSV when the extension is .csv (any case), Parquet otherwise."""
    if str(path).lower().endswith(".csv"):
        return pd.read_csv(path)
    return pd.read_parquet(path)


def _preprocess(df):
    """Select gene columns from *df* and apply log1p followed by z-scoring.

    Returns a float64 array of shape (n_samples, INPUT_DIM).

    Raises:
        ValueError: if the number of gene columns differs from INPUT_DIM, or a
            gene column cannot be converted to float.
    """
    gene_cols = [c for c in df.columns if c not in META_COLS]
    if len(gene_cols) != INPUT_DIM:
        raise ValueError(
            f"le fichier contient {len(gene_cols)} colonnes de gènes, "
            f"mais le modèle en attend {INPUT_DIM}."
        )
    # NOTE(review): columns are fed in file order; presumably the file must use
    # the same gene order as the training data — confirm upstream.
    X = np.log1p(df[gene_cols].to_numpy(dtype=np.float64))
    # NOTE(review): mean/std are taken over the whole uploaded matrix (all
    # samples and genes together); confirm this matches training preprocessing.
    return (X - np.mean(X)) / (np.std(X) + 1e-8)


def _aging_status(rhythm):
    """Classify the biological-minus-chronological age gap (years) with a ±2-year band."""
    if rhythm > 2:
        return "Vieillissement Accéléré ⚠️"
    if rhythm < -2:
        return "Vieillissement Ralenti ✅"
    return "Vieillissement Normal 🆗"


def _make_plot(rhythm):
    """Return a horizontal bar chart of the biological-minus-chronological age gap."""
    # Standalone Figure (not pyplot) so figures do not pile up in pyplot's
    # global registry across requests in a long-running server.
    fig = Figure(figsize=(6, 2))
    ax = fig.subplots()
    ax.barh(["Rythme"], [rhythm], color="#3498db")
    ax.axvline(0, color="black", linestyle="--")
    ax.set_title("Différentiel de Vieillissement (Bio - Chrono)")
    # Keep the default ±15-year window, widened so large gaps are not clipped.
    limit = max(15.0, abs(rhythm) * 1.1)
    ax.set_xlim(-limit, limit)
    return fig


def predict_aging(input_file, chron_age):
    """Gradio callback: estimate biological age for the first sample of a file.

    Args:
        input_file: upload from ``gr.File`` — an object exposing ``.name`` (the
            temp-file path) or a plain path string; None if nothing uploaded.
        chron_age: chronological age in years from ``gr.Number`` (None if empty).

    Returns:
        ``(markdown_text, figure_or_None)`` matching the two UI outputs. Errors
        are reported as text with a None figure instead of raising.
    """
    if model is None:
        if load_error:
            return f"Error: {load_error}", None
        return "Error: Model not found.", None
    if input_file is None:
        return "Erreur : veuillez importer un fichier de données.", None
    if chron_age is None:
        return "Erreur : veuillez saisir l'âge chronologique.", None

    try:
        path = getattr(input_file, "name", input_file)
        df = _read_table(path)
        if len(df) == 0:
            return "Erreur : le fichier ne contient aucun échantillon.", None

        X_scaled = _preprocess(df)

        # model.predict returns [reconstruction, age_prediction]; only the
        # first sample's prediction is reported.
        _, age_pred = model.predict(X_scaled)
        biological_age = float(age_pred[0][0])

        # Mean of the latent vector used as a proxy for aging-score intensity.
        score_text = "N/A"
        if encoder_model is not None:
            latent_vector = encoder_model.predict(X_scaled)
            score_text = f"{float(np.mean(latent_vector[0])):.4f}"

        rhythm = biological_age - chron_age
        status = _aging_status(rhythm)

        lines = [
            "### Résultats d'Analyse",
            f"- **Âge Chronologique :** {chron_age} ans",
            f"- **Âge Biologique (Estimé) :** {biological_age:.2f} ans",
            f"- **Score de Vieillissement (Latent) :** {score_text}",
            f"- **Statut :** {status}",
        ]
        if len(df) > 1:
            lines.append(f"- **Échantillons dans le fichier :** {len(df)} (seul le premier est analysé)")
        if load_error:
            # Weights failed to load: predictions come from a random network.
            lines.append("")
            lines.append(f"> ⚠️ Attention : {load_error}")
        res_text = "\n".join(lines)

        return res_text, _make_plot(rhythm)

    except Exception as e:
        return f"Erreur : {str(e)}", None


# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Aging Score Bio-Predictor")
    gr.Markdown("Analyse du rythme de vieillissement biologique via Autoencoder supervisé.")

    with gr.Row():
        with gr.Column():
            file_input = gr.File(
                label="Données Transcriptomiques (18k gènes)",
                file_types=[".csv", ".parquet"],
            )
            age_input = gr.Number(label="Âge Chronologique Réel", value=40)
            btn = gr.Button("Calculer l'Aging Score", variant="primary")

        with gr.Column():
            output_text = gr.Markdown()
            output_plot = gr.Plot()

    btn.click(fn=predict_aging, inputs=[file_input, age_input], outputs=[output_text, output_plot])


if __name__ == "__main__":
    demo.launch()