File size: 6,003 Bytes
28a6ef1
 
551fd48
ecf5b01
28a6ef1
 
 
 
 
 
8162584
28a6ef1
 
 
 
47012b9
 
 
551fd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28a6ef1
 
47012b9
 
551fd48
 
 
 
 
28a6ef1
 
551fd48
 
28a6ef1
551fd48
47012b9
 
551fd48
47012b9
551fd48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47012b9
551fd48
 
 
 
 
 
 
 
 
 
 
28a6ef1
 
 
 
 
 
47012b9
 
28a6ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import gradio as gr
import tensorflow as tf
from tensorflow.keras import layers, Model
import keras
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt

# Location on disk of the saved autoencoder checkpoint.
MODEL_PATH = "./models/aging_score_autoencoder_fixed.keras"

# Module-level state shared by the loader and the prediction handler.
# All three start out empty and are assigned by load_resources().
model = encoder_model = load_error = None


def build_model(input_dim=18241, latent_dim=32):
    """
    Assemble the supervised autoencoder graph with freshly initialised weights.

    The network is defined in code rather than deserialized from disk, which
    sidesteps deserialization issues. It is meant to mirror the architecture
    used at training time (not verifiable from this file).

    Topology:
      * encoder: Dense(512, relu) -> BatchNormalization -> Dropout(0.3)
        -> Dense(128, relu) -> Dense(latent_dim), named "latent"
      * decoder (fed by "latent"): Dense(128, relu) -> Dense(512, relu)
        -> Dense(input_dim), named "reconstruction"
      * age head (fed by "latent"): Dense(1), named "age"

    Args:
        input_dim: Width of the input feature vector.
        latent_dim: Width of the latent bottleneck.

    Returns:
        A Keras ``Model`` with one input and the outputs
        ``[reconstruction, age]``, in that order.
    """
    # NOTE: layers are created in a fixed order (encoder, decoder, age head);
    # keep it stable, since weights may be copied into this model by position.
    features = layers.Input(shape=(input_dim,))

    # Encoder: squeeze the input down to the latent code.
    hidden = layers.Dense(512, activation="relu")(features)
    hidden = layers.BatchNormalization()(hidden)
    hidden = layers.Dropout(0.3)(hidden)
    hidden = layers.Dense(128, activation="relu")(hidden)
    latent_code = layers.Dense(latent_dim, name="latent")(hidden)

    # Decoder: expand the latent code back to the input width.
    decoded = layers.Dense(128, activation="relu")(latent_code)
    decoded = layers.Dense(512, activation="relu")(decoded)
    reconstructed = layers.Dense(input_dim, name="reconstruction")(decoded)

    # Age head: a single linear unit reading the latent code directly.
    predicted_age = layers.Dense(1, name="age")(latent_code)

    return Model(inputs=features, outputs=[reconstructed, predicted_age])


def load_resources():
    """
    Build the model, try to restore trained weights, and derive the encoder.

    Assigns the module-level globals ``model``, ``encoder_model`` and
    ``load_error``:

    * ``model`` is always a freshly built network from ``build_model()``.
    * If ``MODEL_PATH`` exists, weights are restored in two attempts: first by
      deserializing the saved model and copying its weights across, then, if
      that fails, by calling ``model.load_weights`` on the same path.
    * ``load_error`` is ``None`` when weights were restored; otherwise it
      describes why they were not (the model then keeps its random
      initialisation).
    * ``encoder_model`` maps the model input to the output of the "latent"
      layer, or is ``None`` if that sub-model could not be created.

    The function is best-effort: load failures are reported via ``print`` and
    ``load_error`` instead of being raised.
    """
    global model, encoder_model, load_error
    load_error = None
    # Reset so a failed re-run never leaves an encoder bound to an older model.
    encoder_model = None

    # Build model from architecture
    model = build_model()

    # Try to load weights from saved file
    if os.path.exists(MODEL_PATH):
        print(f"Loading weights from {MODEL_PATH}...")
        try:
            # NOTE(security): safe_mode=False allows unsafe deserialization
            # (e.g. arbitrary code in Lambda layers). Only acceptable because
            # MODEL_PATH is a file shipped with the app; never point this at
            # untrusted input.
            saved_model = keras.saving.load_model(
                MODEL_PATH,
                compile=False,
                safe_mode=False,
            )
            model.set_weights(saved_model.get_weights())
            print("Weights loaded successfully from .keras file.")
        except Exception as primary_error:
            # Keep the first failure visible instead of discarding it; it is
            # usually the more informative of the two.
            print(
                f"Warning: full-model load failed ({primary_error}); "
                "trying weights-only load."
            )
            try:
                model.load_weights(MODEL_PATH)
                print("Weights loaded successfully.")
            except Exception as fallback_error:
                load_error = (
                    f"Could not load weights: {fallback_error} "
                    f"(full-model load failed first: {primary_error})"
                )
                print(f"Warning: {load_error}")
                print("Model will run with random weights.")
    else:
        load_error = f"Model file not found at {MODEL_PATH}. Model will run with random weights."
        print(load_error)

    # Create encoder model
    try:
        latent_layer = model.get_layer("latent")
        encoder_model = Model(inputs=model.input, outputs=latent_layer.output)
        print("Encoder model created successfully.")
    except Exception as e:
        print(f"Warning: Could not create encoder model: {e}")

# Initial load: executed once when the module is run/imported, so the
# module-level model state is set before the UI below is defined.
load_resources()

def predict_aging(input_file, chron_age):
    """
    Estimate biological age from an uploaded expression table.

    The table is read as CSV when the path ends in ".csv" (case-insensitive)
    and as Parquet otherwise. Known metadata columns are dropped, and the
    remaining columns are log1p-transformed and standardised with the mean and
    standard deviation of the whole matrix. Only the first row of the table is
    reported.

    Args:
        input_file: Uploaded file from the Gradio ``File`` component; either an
            object exposing the path as ``.name`` or a plain path string.
        chron_age: Chronological age in years, used as the comparison baseline.

    Returns:
        A ``(markdown_text, figure)`` pair matching the two UI outputs. On any
        failure the text is an error message and the figure is ``None``.
    """
    # Every return below yields exactly two values: the click handler is wired
    # to two output components, so a 3-tuple would be rejected by Gradio.
    if model is None:
        if load_error:
            return f"Error: {load_error}", None
        return "Error: Model not found.", None

    # Explicit guards give a readable message instead of an AttributeError /
    # TypeError surfacing from deep inside the pipeline.
    if input_file is None:
        return "Erreur : aucun fichier de données fourni.", None
    if chron_age is None:
        return "Erreur : l'âge chronologique est requis.", None

    try:
        # Load data (accept both an object with `.name` and a bare path).
        path = getattr(input_file, "name", input_file)
        if str(path).lower().endswith(".csv"):
            df = pd.read_csv(path)
        else:
            df = pd.read_parquet(path)

        # Feature selection (genes only): drop known metadata columns.
        meta_cols = {"sample_id", "subject_id", "tissue", "sex", "age", "death_time", "estimated_age"}
        gene_cols = [c for c in df.columns if c not in meta_cols]
        X = df[gene_cols].values

        # Preprocessing: log1p + standardisation over the whole matrix.
        # NOTE(review): statistics are global, not per-gene, and computed from
        # the uploaded data; confirm this matches how the model was trained.
        X_scaled = np.log1p(X)
        X_scaled = (X_scaled - np.mean(X_scaled)) / (np.std(X_scaled) + 1e-8)

        # Inference: model.predict returns [reconstruction, age_prediction].
        _, age_pred = model.predict(X_scaled)
        biological_age = float(age_pred[0][0])

        # Latent aging score: mean of the first sample's latent vector, used as
        # a proxy for intensity. Formatted here so the "N/A" fallback is never
        # passed through a float format spec.
        score_text = "N/A"
        if encoder_model is not None:
            latent_vector = encoder_model.predict(X_scaled)
            score_text = f"{float(np.mean(latent_vector[0])):.4f}"

        # Interpretation: a +/-2 year band around zero counts as normal.
        rhythm = biological_age - chron_age
        if rhythm > 2:
            status = "Vieillissement Accéléré ⚠️"
        elif rhythm < -2:
            status = "Vieillissement Ralenti ✅"
        else:
            status = "Vieillissement Normal 🆗"

        # Summary
        res_text = f"""
        ### Résultats d'Analyse
        - **Âge Chronologique :** {chron_age} ans
        - **Âge Biologique (Estimé) :** {biological_age:.2f} ans
        - **Score de Vieillissement (Latent) :** {score_text}
        - **Statut :** {status}
        """

        # If trained weights could not be restored, the numbers above come from
        # a randomly initialised network; say so instead of presenting them as
        # real results. Indentation matches the block above.
        if load_error:
            res_text += f"\n        > ⚠️ **Avertissement :** {load_error}\n"

        # Plot
        fig, ax = plt.subplots(figsize=(6, 2))
        ax.barh(['Rythme'], [rhythm], color='#3498db')
        ax.axvline(0, color='black', linestyle='--')
        ax.set_title("Différentiel de Vieillissement (Bio - Chrono)")
        ax.set_xlim(-15, 15)
        # Unregister the figure from pyplot so figures do not pile up across
        # requests; the Figure object itself stays usable for rendering.
        plt.close(fig)

        return res_text, fig

    except Exception as e:
        return f"Erreur : {e}", None

# Gradio Interface
# One row with two columns: inputs plus the trigger button in the first
# column, the Markdown summary and the plot in the second.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Aging Score Bio-Predictor")
    gr.Markdown("Analyse du rythme de vieillissement biologique via Autoencoder supervisé.")
    
    with gr.Row():
        with gr.Column():
            # Inputs, passed to predict_aging in this order: file, then age.
            input_file = gr.File(label="Données Transcriptomiques (18k gènes)")
            chron_age = gr.Number(label="Âge Chronologique Réel", value=40)
            btn = gr.Button("Calculer l'Aging Score", variant="primary")
            
        with gr.Column():
            # Outputs, filled from predict_aging's return values in this order.
            output_text = gr.Markdown()
            output_plot = gr.Plot()

    # The handler must return exactly one value per component listed in `outputs`.
    btn.click(fn=predict_aging, inputs=[input_file, chron_age], outputs=[output_text, output_plot])

# Starts the app at module level (there is no `__main__` guard).
demo.launch()