Ghaithhmz's picture
merge
6ec966c
import gradio as gr
import tensorflow as tf
from tensorflow.keras import layers, Model
import keras
import numpy as np
import pandas as pd
import os
import matplotlib.pyplot as plt
# Serialized autoencoder whose weights are copied into the rebuilt network.
MODEL_PATH = "./models/aging_score_autoencoder_fixed.keras"

# Runtime state, (re)populated by load_resources():
#   model         - full network with [reconstruction, age] outputs
#   encoder_model - sub-model mapping the input to the "latent" layer
#   load_error    - message explaining why saved weights were not applied
model = encoder_model = load_error = None
def build_model(input_dim=18241, latent_dim=32):
    """
    Construct the supervised autoencoder in code rather than deserializing it.

    The network has one input of width ``input_dim`` and two outputs:
    a reconstruction of the input (layer "reconstruction") and a scalar
    age regression (layer "age"), both fed from a ``latent_dim``-wide
    bottleneck layer named "latent". The original author states this
    architecture matches the training notebook.

    NOTE: layers are created in a fixed order on purpose; saved weights
    are copied into this model positionally (``set_weights``), so the
    creation order must not change.
    """
    features = layers.Input(shape=(input_dim,))

    # Encoder: wide projection, batch-normalized and dropout-regularized,
    # then narrowed down to the latent code.
    hidden = layers.Dense(512, activation="relu")(features)
    hidden = layers.BatchNormalization()(hidden)
    hidden = layers.Dropout(0.3)(hidden)
    hidden = layers.Dense(128, activation="relu")(hidden)
    code = layers.Dense(latent_dim, name="latent")(hidden)

    # Decoder: widen back out (128 -> 512) and project to the input width.
    hidden = code
    for width in (128, 512):
        hidden = layers.Dense(width, activation="relu")(hidden)
    recon = layers.Dense(input_dim, name="reconstruction")(hidden)

    # Age head reads the same latent code as the decoder.
    age = layers.Dense(1, name="age")(code)

    return Model(inputs=features, outputs=[recon, age])
def load_resources():
    """Rebuild the network and try to populate it with the saved weights.

    Sets three module globals:
      * ``model`` - always a freshly built network from ``build_model()``;
        it keeps its random initial weights if nothing could be loaded.
      * ``encoder_model`` - sub-model from the input to the "latent" layer,
        or ``None`` if it could not be created.
      * ``load_error`` - ``None`` on success, otherwise a human-readable
        explanation of why the saved weights were not applied.
    """
    global model, encoder_model, load_error
    load_error = None
    # Never keep an encoder that is bound to a previously built model.
    encoder_model = None

    # Rebuild the architecture in code to avoid deserialization issues.
    model = build_model()

    if not os.path.exists(MODEL_PATH):
        load_error = f"Model file not found at {MODEL_PATH}. Model will run with random weights."
        print(load_error)
    else:
        print(f"Loading weights from {MODEL_PATH}...")
        try:
            # Preferred path: load the full saved model, then copy its
            # weights positionally into the rebuilt one.
            # SECURITY NOTE: safe_mode=False permits arbitrary code in the
            # file to run on load; acceptable only because MODEL_PATH is a
            # file shipped with the app. Never point this at untrusted files.
            saved_model = keras.saving.load_model(
                MODEL_PATH,
                compile=False,
                safe_mode=False,
            )
            model.set_weights(saved_model.get_weights())
            print("Weights loaded successfully from .keras file.")
        except Exception as full_model_err:
            # Fallback: read the weights straight into the rebuilt model.
            try:
                model.load_weights(MODEL_PATH)
                print("Weights loaded successfully.")
            except Exception as weights_err:
                # Report both causes; the first failure is often the
                # informative one (e.g. a shape mismatch in set_weights).
                load_error = (
                    f"Could not load weights: {weights_err} "
                    f"(full-model load also failed: {full_model_err})"
                )
                print(f"Warning: {load_error}")
                print("Model will run with random weights.")

    # Sub-model exposing the latent code, used for the latent aging score.
    try:
        latent_layer = model.get_layer("latent")
        encoder_model = Model(inputs=model.input, outputs=latent_layer.output)
        print("Encoder model created successfully.")
    except Exception as e:
        print(f"Warning: Could not create encoder model: {e}")


# Build the model once at import time so the UI handler can use it.
load_resources()
def predict_aging(input_file, chron_age):
    """Estimate biological age and a latent aging score for one sample.

    Args:
        input_file: Value of the ``gr.File`` input. Depending on the Gradio
            version this is a path string or an object whose ``.name``
            attribute holds the path; both are accepted. Names ending in
            ``.csv`` (any case) are read with ``pd.read_csv``, anything else
            with ``pd.read_parquet``.
        chron_age: Chronological age in years from the ``gr.Number`` input.

    Returns:
        ``(markdown_text, figure)``, one value per Gradio output. On failure
        the text is an error message and the figure is ``None``.

    Only the first row (sample) of the uploaded table is reported.
    """
    # Build the figure without pyplot: pyplot keeps every figure alive in a
    # global registry, which leaks memory in a long-running server.
    from matplotlib.figure import Figure

    # Every return must be a 2-tuple: the click handler has two outputs.
    if model is None:
        if load_error:
            return f"Error: {load_error}", None
        return "Error: Model not found.", None
    if input_file is None:
        return "Erreur : aucun fichier fourni.", None
    if chron_age is None:
        return "Erreur : âge chronologique manquant.", None

    try:
        # Load data. Gradio >= 4 passes a path string; older versions pass
        # a tempfile wrapper exposing the path as ``.name``.
        path = str(getattr(input_file, "name", input_file))
        if path.lower().endswith(".csv"):
            df = pd.read_csv(path)
        else:
            df = pd.read_parquet(path)

        # Feature selection: every non-metadata column is treated as a gene.
        meta_cols = {"sample_id", "subject_id", "tissue", "sex", "age", "death_time", "estimated_age"}
        gene_cols = [c for c in df.columns if c not in meta_cols]
        X = df[gene_cols].to_numpy(dtype=float)

        if X.shape[0] == 0:
            return "Erreur : le fichier ne contient aucun échantillon.", None
        n_expected = model.inputs[0].shape[-1]
        if n_expected is not None and X.shape[1] != n_expected:
            return (
                f"Erreur : {X.shape[1]} colonnes de gènes trouvées, "
                f"{n_expected} attendues.",
                None,
            )

        # Preprocessing: log1p, then standardize with the mean/std of the
        # whole matrix (one global statistic, not per gene or per sample).
        X_scaled = np.log1p(X)
        X_scaled = (X_scaled - np.mean(X_scaled)) / (np.std(X_scaled) + 1e-8)

        # Inference: model.predict returns [reconstruction, age_prediction].
        _, age_pred = model.predict(X_scaled)
        biological_age = float(age_pred[0][0])

        # Latent aging score: the mean of the first sample's latent vector,
        # used as a proxy for aging intensity. It is formatted here so the
        # "N/A" fallback never reaches a numeric format spec.
        aging_score_text = "N/A"
        if encoder_model is not None:
            latent_vector = encoder_model.predict(X_scaled)
            aging_score_text = f"{float(np.mean(latent_vector[0])):.4f}"

        # Interpretation: years between biological and chronological age;
        # within +/-2 years counts as normal.
        rhythm = biological_age - chron_age
        if rhythm > 2:
            status = "Vieillissement Accéléré ⚠️"
        elif rhythm < -2:
            status = "Vieillissement Ralenti ✅"
        else:
            status = "Vieillissement Normal 🆗"

        # Summary, built line by line so no source indentation leaks into
        # the Markdown (indented lines would render as a code block).
        summary_lines = [
            "### Résultats d'Analyse",
            f"- **Âge Chronologique :** {chron_age} ans",
            f"- **Âge Biologique (Estimé) :** {biological_age:.2f} ans",
            f"- **Score de Vieillissement (Latent) :** {aging_score_text}",
            f"- **Statut :** {status}",
        ]
        if load_error:
            # The network may be running with random weights; say so rather
            # than presenting its output as trustworthy.
            summary_lines += ["", f"> ⚠️ Attention : {load_error}"]
        res_text = "\n".join(summary_lines) + "\n"

        # Plot of the bio - chrono age difference.
        fig = Figure(figsize=(6, 2))
        ax = fig.add_subplot(111)
        ax.barh(["Rythme"], [rhythm], color="#3498db")
        ax.axvline(0, color="black", linestyle="--")
        ax.set_title("Différentiel de Vieillissement (Bio - Chrono)")
        ax.set_xlim(-15, 15)
        return res_text, fig
    except Exception as e:
        return f"Erreur : {str(e)}", None
# Gradio interface: inputs (file + chronological age) in the left column,
# results (Markdown summary + plot) in the right column.
demo = gr.Blocks(theme=gr.themes.Soft())
with demo:
    gr.Markdown("# 🧠 Aging Score Bio-Predictor")
    gr.Markdown("Analyse du rythme de vieillissement biologique via Autoencoder supervisé.")
    with gr.Row():
        # Inputs
        with gr.Column():
            input_file = gr.File(label="Données Transcriptomiques (18k gènes)")
            chron_age = gr.Number(value=40, label="Âge Chronologique Réel")
            btn = gr.Button("Calculer l'Aging Score", variant="primary")
        # Outputs
        with gr.Column():
            output_text = gr.Markdown()
            output_plot = gr.Plot()
    # The handler returns exactly one value per output component.
    btn.click(
        predict_aging,
        inputs=[input_file, chron_age],
        outputs=[output_text, output_plot],
    )

demo.launch()