File size: 4,544 Bytes
8d06132
 
876bb97
8d06132
 
 
 
 
1ebe992
 
8d06132
876bb97
8d06132
876bb97
8d06132
 
 
 
 
 
 
 
 
 
 
876bb97
8d06132
 
 
 
 
 
 
876bb97
8d06132
 
 
 
 
 
 
 
 
 
876bb97
8d06132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
876bb97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d06132
876bb97
8d06132
 
 
876bb97
8d06132
 
 
876bb97
8d06132
876bb97
 
8d06132
 
876bb97
 
8d06132
 
876bb97
8d06132
 
 
 
 
 
876bb97
 
 
8d06132
876bb97
 
8d06132
876bb97
8d06132
 
 
 
 
 
 
 
 
876bb97
 
 
 
8d06132
 
 
 
 
 
876bb97
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import torch
import gradio as gr
import plotly.graph_objects as go
from nilearn import datasets
from nilearn.connectome import ConnectivityMeasure
from nilearn.maskers import MultiNiftiMapsMasker
import numpy as np

# Inference is pinned to the CPU; the checkpoint is remapped onto it at load time.
device = torch.device("cpu")

# Load the TorchScript model once at start-up. The app cannot serve any
# prediction without it, so a failure here terminates the process with status 1.
try:
    # torch.jit.load returns a ScriptModule, never a DataParallel wrapper,
    # so no unwrapping step is required before use.
    scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location=device)
    scripted_model.to(device)
    scripted_model.eval()
except Exception as e:
    print(f"Error loading model: {str(e)}")
    # `exit` is injected by the `site` module for interactive sessions and is
    # not guaranteed to exist (e.g. `python -S`); SystemExit is always available.
    raise SystemExit(1)

# Fetch the DiFuMo atlas with `dim` components at 2 mm resolution.
# The maps image is required by the masker below, so a failure is fatal.
dim = 64
try:
    difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2, legacy_format=False)
    atlas_filename = difumo.maps
except Exception as e:
    print(f"Error fetching atlas: {str(e)}")
    # `exit` is injected by the `site` module for interactive sessions and is
    # not guaranteed to exist (e.g. `python -S`); SystemExit is always available.
    raise SystemExit(1)

# Masker built on the atlas maps; standardizes signals and uses all available cores
masker = MultiNiftiMapsMasker(
    maps_img=atlas_filename, standardize=True, n_jobs=-1, verbose=0
)

# Connectivity estimator: correlation matrices returned as flat vectors
# with the diagonal discarded
connectome_measure = ConnectivityMeasure(
    kind="correlation",
    discard_diagonal=True,
    vectorize=True,
)

# Feature extraction function
def extract_features_multiple(func_preproc_files):
    """Turn each fMRI file into a vectorized connectivity feature array.

    The module-level ``masker`` is fitted on the first file, then every file
    is passed through ``masker.transform`` and ``connectome_measure``.

    Args:
        func_preproc_files: Sequence of fMRI inputs accepted by the masker.
            An empty or ``None`` value yields an empty list.

    Returns:
        A list with one feature array per input, in input order.
    """
    if not func_preproc_files:
        return []

    print("Fitting masker on the first subject...")
    masker.fit(func_preproc_files[0])

    connectomes = []
    for subject_number, image in enumerate(func_preproc_files, start=1):
        print(f"Processing subject {subject_number}...")
        time_series = masker.transform(image)
        # The estimator expects a list of subjects; pass one and unwrap the result.
        batch = connectome_measure.fit_transform([time_series])
        connectomes.append(batch[0])

    print("All subjects processed.")
    return connectomes

# Function to generate a Plotly probability plot
def plot_probability(probability):
    """Build a two-bar Plotly chart of the "No Autism" / "Autism" probabilities.

    Args:
        probability: Probability of the "Autism" class; the "No Autism" bar
            shows ``1 - probability``.

    Returns:
        A ``go.Figure`` holding a single bar trace on a black background.
    """
    class_probs = {"No Autism": 1 - probability, "Autism": probability}

    bar = go.Bar(
        x=list(class_probs),
        y=list(class_probs.values()),
        # Dark purple for "No Autism", light purple for "Autism"
        marker={"color": ["#6a0dad", "#d896ff"]},
        text=[f"{p*100:.1f}%" for p in class_probs.values()],
        textposition="auto",
    )

    layout = {
        "title": "Autism Prediction Probability",
        "paper_bgcolor": "black",
        "plot_bgcolor": "black",
        "font": {"color": "white"},
        "xaxis": {"title": "Diagnosis", "showgrid": False},
        "yaxis": {"title": "Probability", "showgrid": True, "gridcolor": "gray"},
    }

    figure = go.Figure()
    figure.add_trace(bar)
    figure.update_layout(**layout)
    return figure

# Prediction function
def predict_autism(fmri_files, age, gender):
    """Classify each uploaded fMRI file and summarise the results for the UI.

    Args:
        fmri_files: Uploaded fMRI files, forwarded unchanged to
            ``extract_features_multiple``.
        age: Subject age; must be convertible with ``float()``.
        gender: Subject gender code; must be convertible with ``int()``
            (the UI offers "0" for female and "1" for male).

    Returns:
        A ``(message, figure)`` tuple. On success ``message`` contains one
        prediction line per file and ``figure`` is the probability plot of the
        first file (only one figure is returned). On any failure ``message``
        describes the problem and ``figure`` is ``None``.
    """
    if not fmri_files:
        return "Please upload at least one valid .nii.gz file.", None

    # Validate the scalar inputs before the expensive feature extraction so a
    # missing value (e.g. None) yields a clear message instead of a raw
    # conversion error.
    try:
        age_value = float(age)
    except (TypeError, ValueError):
        return "Please enter a valid age.", None
    try:
        gender_value = int(gender)
    except (TypeError, ValueError):
        return "Please select a gender (0: Female, 1: Male).", None

    # UI boundary: any failure below is reported as text rather than raised.
    try:
        features_list = extract_features_multiple(fmri_files)
        if not features_list:
            return "Error: Failed to extract features from the fMRI files.", None

        age_tensor = torch.tensor([age_value], dtype=torch.float32).to(device)
        gender_tensor = torch.tensor([gender_value], dtype=torch.long).to(device)

        predictions = []
        first_plot = None

        for features in features_list:
            # unsqueeze(0) adds a leading dimension of size one in front of the features
            features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)

            with torch.no_grad():
                prediction = scripted_model(features_tensor, age_tensor, gender_tensor)
                probability = torch.sigmoid(prediction).item()

            is_autism = probability > 0.5
            label = "Autism" if is_autism else "No Autism"
            # Confidence refers to the predicted class, not always to "Autism";
            # otherwise a confident "No Autism" would read as low confidence.
            confidence = probability if is_autism else 1.0 - probability
            predictions.append(f"Prediction: {label} (Confidence: {confidence:.2%})")

            # Only the first subject's plot is returned, so only build that one.
            if first_plot is None:
                first_plot = plot_probability(probability)

        return "\n".join(predictions), first_plot

    except Exception as e:
        return f"Error: {str(e)}", None

# Gradio interface: input widgets, in the order predict_autism receives them
fmri_upload = gr.File(label="Upload preprocessed fMRI files (.nii.gz)", file_count="multiple")
age_field = gr.Number(label="Age", minimum=0, maximum=120)
gender_choice = gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)")

# Output widgets, matching the (text, figure) pair returned by predict_autism
result_text = gr.Text(label="Prediction Result")
probability_plot = gr.Plot(label="Prediction Probability Plot")

iface = gr.Interface(
    fn=predict_autism,
    inputs=[fmri_upload, age_field, gender_choice],
    outputs=[result_text, probability_plot],
    title="Autism Prediction from fMRI Data",
    description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
    theme="default",
    flagging_mode="never",
)

# Start the web server for the interface
iface.launch()