Spaces:
Sleeping
Sleeping
File size: 4,544 Bytes
8d06132 876bb97 8d06132 1ebe992 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 8d06132 876bb97 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import torch
import gradio as gr
import plotly.graph_objects as go
from nilearn import datasets
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import MultiNiftiMapsMasker
import numpy as np
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Forced to CPU (GPU line above is deliberately commented out — presumably
# for a CPU-only deployment environment; confirm before re-enabling CUDA).
device = torch.device("cpu")
# Load the model
# TorchScript model loaded from disk; any failure (missing file, version
# mismatch) aborts the whole app since nothing works without the model.
try:
    scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location=device)
    # Unwrap DataParallel in case the model was scripted while wrapped
    # (NOTE(review): torch.jit.load normally returns a ScriptModule, so this
    # branch is likely defensive dead code — confirm against the export step).
    if isinstance(scripted_model, torch.nn.DataParallel):
        scripted_model = scripted_model.module
    scripted_model.to(device)
    # Inference-only usage: disable dropout/batch-norm training behavior.
    scripted_model.eval()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    exit(1)
# Fetch atlas (e.g., DiFuMo)
# 64-component DiFuMo functional atlas at 2 mm resolution; downloaded on
# first run, cached by nilearn afterwards. The model was presumably trained
# on features from this exact atlas, so dim must match training.
dim = 64
try:
    difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2, legacy_format=False)
    atlas_filename = difumo.maps
except Exception as e:
    print(f"Error fetching atlas: {str(e)}")
    exit(1)
# Create masker
# Projects 4D fMRI images onto the atlas components, yielding one
# standardized time series per component; n_jobs=-1 uses all CPU cores.
masker = MultiNiftiMapsMasker(
    maps_img=atlas_filename,
    standardize=True,
    n_jobs=-1,
    verbose=0
)
# Connectivity measure
# Pearson correlation between component time series, vectorized to a flat
# feature vector with the (constant) diagonal dropped — the model's input.
connectome_measure = ConnectivityMeasure(kind='correlation', vectorize=True, discard_diagonal=True)
# Feature extraction function
def extract_features_multiple(func_preproc_files):
    """Extract one vectorized connectivity feature vector per subject.

    The module-level ``masker`` is fit once, on the first subject's image;
    every image is then transformed into atlas time series and reduced to a
    vectorized correlation connectome via ``connectome_measure``.

    Parameters
    ----------
    func_preproc_files : list
        Preprocessed fMRI images (paths or niimg-like); may be empty.

    Returns
    -------
    list
        One 1-D feature array per input subject ([] for empty input).
    """
    features = []
    if not func_preproc_files:
        return features
    # Fit only once — subsequent subjects reuse the same fitted masker.
    print("Fitting masker on the first subject...")
    masker.fit(func_preproc_files[0])
    for idx, func_file in enumerate(func_preproc_files, start=1):
        print(f"Processing subject {idx}...")
        time_series = masker.transform(func_file)
        # fit_transform expects a list of subjects; unwrap the single result.
        connectivity = connectome_measure.fit_transform([time_series])[0]
        features.append(connectivity)
    print("All subjects processed.")
    return features
# Function to generate a Plotly probability plot
def plot_probability(probability):
    """Return a dark-themed Plotly bar chart of P(no autism) vs P(autism).

    Parameters
    ----------
    probability : float
        Model probability of autism, in [0, 1].

    Returns
    -------
    plotly.graph_objects.Figure
        Two-bar figure with percentage labels on each bar.
    """
    class_names = ["No Autism", "Autism"]
    class_probs = [1 - probability, probability]
    bar_colors = ["#6a0dad", "#d896ff"]  # Dark purple and light purple
    bars = go.Bar(
        x=class_names,
        y=class_probs,
        marker=dict(color=bar_colors),
        text=[f"{(1-probability)*100:.1f}%", f"{probability*100:.1f}%"],
        textposition="auto",
    )
    figure = go.Figure()
    figure.add_trace(bars)
    # Dark background to match the app's theme.
    figure.update_layout(
        title="Autism Prediction Probability",
        paper_bgcolor="black",
        plot_bgcolor="black",
        font=dict(color="white"),
        xaxis=dict(title="Diagnosis", showgrid=False),
        yaxis=dict(title="Probability", showgrid=True, gridcolor="gray"),
    )
    return figure
# Prediction function
def predict_autism(fmri_files, age, gender):
    """Run the model on uploaded fMRI files and return (text, figure).

    Parameters
    ----------
    fmri_files : list
        Uploaded preprocessed .nii.gz files from the Gradio File input.
    age : number-like
        Subject age; converted with float().
    gender : str or int
        "0" (female) or "1" (male); converted with int().

    Returns
    -------
    tuple
        (newline-joined prediction strings for all subjects,
         Plotly figure for the FIRST subject) — or (error message, None).
    """
    try:
        if not fmri_files:
            return "Please upload at least one valid .nii.gz file.", None
        features_list = extract_features_multiple(fmri_files)
        if not features_list:
            return "Error: Failed to extract features from the fMRI files.", None
        # The same age/gender tensors are shared across all uploaded subjects.
        age_tensor = torch.tensor([float(age)], dtype=torch.float32).to(device)
        gender_tensor = torch.tensor([int(gender)], dtype=torch.long).to(device)
        predictions = []
        plots = []
        for features in features_list:
            features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
            with torch.no_grad():
                logit = scripted_model(features_tensor, age_tensor, gender_tensor)
            probability = torch.sigmoid(logit).item()
            is_autism = probability > 0.5
            # BUG FIX: report the probability of the *predicted* class.
            # Previously p=0.10 was shown as "No Autism (Confidence: 10.00%)"
            # when the model is actually 90% confident in "No Autism".
            confidence = probability if is_autism else 1 - probability
            predictions.append(
                f"Prediction: {'Autism' if is_autism else 'No Autism'} (Confidence: {confidence:.2%})"
            )
            # Generate Plotly probability plot (one per subject).
            plots.append(plot_probability(probability))
        # The interface exposes a single Plot output, so only the first
        # subject's figure can be displayed alongside all text results.
        return "\n".join(predictions), plots[0]
    except Exception as e:
        # UI boundary: surface the error message instead of crashing the app.
        return f"Error: {str(e)}", None
# Gradio interface
# Gradio interface
# Wires predict_autism to a simple form: multiple file upload + age + gender
# radio ("0"/"1" strings — predict_autism converts with int()). Outputs are a
# text summary and a single probability plot (first subject only).
iface = gr.Interface(
    fn=predict_autism,
    inputs=[
        gr.File(label="Upload preprocessed fMRI files (.nii.gz)", file_count="multiple"),
        gr.Number(label="Age", minimum=0, maximum=120),
        gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)"),
    ],
    outputs=[
        gr.Text(label="Prediction Result"),
        gr.Plot(label="Prediction Probability Plot"),
    ],
    title="Autism Prediction from fMRI Data",
    description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
    theme="default",
    flagging_mode="never"
)
# Blocking call: starts the local web server and serves the app.
iface.launch()
|