# Commit 1ebe992 "Update app.py" by JayLacoma (verified).
import torch
import gradio as gr
import plotly.graph_objects as go
from nilearn import datasets
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import MultiNiftiMapsMasker
import numpy as np
# Inference runs on the CPU. To use a GPU when one is available, switch to:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

# Load the TorchScript model once at import time. The app cannot serve
# predictions without it, so a failure aborts start-up.
try:
    scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location=device)
    # Defensive unwrap in case the saved object is a DataParallel wrapper.
    if isinstance(scripted_model, torch.nn.DataParallel):
        scripted_model = scripted_model.module
    scripted_model.to(device)
    # Evaluation mode (affects layers such as dropout / batch-norm, if present).
    scripted_model.eval()
except Exception as e:
    # `exit()` is a convenience injected by the `site` module for interactive
    # use and is not guaranteed to exist in programs. Raising SystemExit with a
    # string prints the message to stderr and exits with status 1.
    raise SystemExit(f"Error loading model: {e}") from e
# Fetch the DiFuMo atlas through nilearn's dataset fetcher.
# `dim` selects the number of atlas components. With the vectorized,
# diagonal-free correlation features used below, each subject yields
# dim * (dim - 1) / 2 values, which presumably must match the model's input
# size — confirm against how the model was trained.
dim = 64
try:
    difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2, legacy_format=False)
    atlas_filename = difumo.maps  # atlas maps handed to the masker below
except Exception as e:
    # `exit()` is a convenience injected by the `site` module for interactive
    # use and is not guaranteed to exist in programs. Raising SystemExit with a
    # string prints the message to stderr and exits with status 1.
    raise SystemExit(f"Error fetching atlas: {e}") from e
# Masker that extracts one time series per atlas component from a 4D image.
# Signals are standardized; n_jobs=-1 lets nilearn use all available cores.
masker = MultiNiftiMapsMasker(
    maps_img=atlas_filename,
    n_jobs=-1,
    standardize=True,
    verbose=0,
)

# Correlation connectivity between atlas components. Each subject's matrix is
# flattened to its lower triangle with the diagonal excluded.
connectome_measure = ConnectivityMeasure(
    kind="correlation",
    discard_diagonal=True,
    vectorize=True,
)
# Feature extraction function
def extract_features_multiple(func_preproc_files):
    """Convert each preprocessed fMRI image into a connectivity feature vector.

    The module-level ``masker`` is fitted once, on the first image, and then
    extracts the atlas time series of every image. ``connectome_measure``
    turns each subject's time series into one flattened correlation vector.

    Args:
        func_preproc_files: indexable sequence of images accepted by
            ``masker`` (file paths from the Gradio upload — confirm for other
            callers).

    Returns:
        A list with one feature vector per input, in input order. The list is
        empty when no files are given.
    """
    if not func_preproc_files:
        return []

    print("Fitting masker on the first subject...")
    masker.fit(func_preproc_files[0])

    features = []
    for subject_number, image in enumerate(func_preproc_files, start=1):
        print(f"Processing subject {subject_number}...")
        time_series = masker.transform(image)
        # fit_transform expects a list of subjects; keep the single result.
        connectivity = connectome_measure.fit_transform([time_series])[0]
        features.append(connectivity)

    print("All subjects processed.")
    return features
# Function to generate a Plotly probability plot
def plot_probability(probability):
    """Build a dark-themed two-bar chart of the predicted class probabilities.

    Args:
        probability: model probability for the "Autism" class. The
            "No Autism" bar shows its complement, ``1 - probability``.

    Returns:
        A plotly ``go.Figure`` with one bar per class, labelled with
        one-decimal percentages.
    """
    no_autism = 1 - probability
    bar = go.Bar(
        x=["No Autism", "Autism"],
        y=[no_autism, probability],
        marker=dict(color=["#6a0dad", "#d896ff"]),  # dark purple, light purple
        text=[f"{no_autism*100:.1f}%", f"{probability*100:.1f}%"],
        textposition="auto",
    )

    fig = go.Figure(data=[bar])
    fig.update_layout(
        title="Autism Prediction Probability",
        paper_bgcolor="black",
        plot_bgcolor="black",
        font=dict(color="white"),
        xaxis=dict(title="Diagnosis", showgrid=False),
        yaxis=dict(title="Probability", showgrid=True, gridcolor="gray"),
    )
    return fig
# Prediction function
def predict_autism(fmri_files, age, gender):
try:
if not fmri_files:
return "Please upload at least one valid .nii.gz file.", None
features_list = extract_features_multiple(fmri_files)
if not features_list:
return "Error: Failed to extract features from the fMRI files.", None
age_tensor = torch.tensor([float(age)], dtype=torch.float32).to(device)
gender_tensor = torch.tensor([int(gender)], dtype=torch.long).to(device)
predictions = []
plots = []
for features in features_list:
features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
with torch.no_grad():
prediction = scripted_model(features_tensor, age_tensor, gender_tensor)
probability = torch.sigmoid(prediction).item()
result = f"Prediction: {'Autism' if probability > 0.5 else 'No Autism'} (Confidence: {probability:.2%})"
predictions.append(result)
# Generate Plotly probability plot
plots.append(plot_probability(probability))
return "\n".join(predictions), plots[0] # Return text and Plotly figure
except Exception as e:
return f"Error: {str(e)}", None
# Gradio interface: uploaded fMRI files + age + gender in; prediction text and
# probability plot out. Both are produced by predict_autism.
iface = gr.Interface(
    fn=predict_autism,
    inputs=[
        gr.File(label="Upload preprocessed fMRI files (.nii.gz)", file_count="multiple"),
        gr.Number(label="Age", minimum=0, maximum=120),
        gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)"),
    ],
    outputs=[
        gr.Text(label="Prediction Result"),
        gr.Plot(label="Prediction Probability Plot"),
    ],
    title="Autism Prediction from fMRI Data",
    description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
    theme="default",
    flagging_mode="never"
)

# Start the server only when run as a script (e.g. `python app.py`).
# Importing this module then builds `iface` without launching it.
if __name__ == "__main__":
    iface.launch()