JayLacoma commited on
Commit
8d06132
·
verified ·
1 Parent(s): 8b365af

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from nilearn import datasets
4
+ from nilearn.connectome import ConnectivityMeasure
5
+ from nilearn.maskers import MultiNiftiMapsMasker
6
+ import numpy as np
7
+
8
+ #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+
10
+ # Force torch to use CPU only
11
+ device = torch.device("cpu")
12
+
13
+
14
+
15
+ try:
16
+ scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location="cpu")
17
+
18
+ # If the model is wrapped in DataParallel, unwrap it
19
+ if isinstance(scripted_model, torch.nn.DataParallel):
20
+ scripted_model = scripted_model.module
21
+
22
+ scripted_model.to(device)
23
+ scripted_model.eval()
24
+ except Exception as e:
25
+ print(f"Error loading model: {str(e)}")
26
+ exit(1)
27
+
28
+ # Fetch atlas (e.g., DiFuMo)
29
+ dim = 64 # Number of ROIs
30
+ try:
31
+ difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2, legacy_format=False)
32
+ atlas_filename = difumo.maps
33
+ except Exception as e:
34
+ print(f"Error fetching atlas: {str(e)}")
35
+ exit(1)
36
+
37
+ # Create masker to extract features
38
+ masker = MultiNiftiMapsMasker(
39
+ maps_img=atlas_filename,
40
+ standardize=True,
41
+ n_jobs=-1,
42
+ verbose=0
43
+ )
44
+
45
+ # Connectivity measure
46
+ connectome_measure = ConnectivityMeasure(kind='correlation', vectorize=True, discard_diagonal=True)
47
+
48
+ # Modified feature extraction function
49
+ def extract_features_multiple(func_preproc_files):
50
+ all_features = []
51
+ if not func_preproc_files:
52
+ return all_features
53
+
54
+ # Fit the masker on the first subject
55
+ print("Fitting masker on the first subject...")
56
+ masker.fit(func_preproc_files[0])
57
+
58
+ for i, sub in enumerate(func_preproc_files):
59
+ print(f"Processing subject {i+1}...")
60
+ masked_data = masker.transform(sub)
61
+ transformed_data = connectome_measure.fit_transform([masked_data])[0]
62
+ all_features.append(transformed_data)
63
+
64
+ print("All subjects processed.")
65
+ return all_features
66
+
67
+ # Prediction function with error handling
68
+ def validate_inputs(features_tensor):
69
+ if features_tensor.shape[1] != 2016:
70
+ raise ValueError(f"Expected 2016 features but got {features_tensor.shape[1]}")
71
+ if features_tensor.dim() != 2:
72
+ raise ValueError(f"Expected 2D tensor but got {features_tensor.dim()}D")
73
+
74
+ def predict_autism(fmri_files, age, gender):
75
+ try:
76
+ if not fmri_files:
77
+ return "Please upload at least one valid .nii.gz file."
78
+
79
+ features_list = extract_features_multiple(fmri_files)
80
+ if not features_list:
81
+ return "Error: Failed to extract features from the fMRI files."
82
+
83
+ # CORRECTED TENSOR SHAPES
84
+ age_tensor = torch.tensor([float(age)], dtype=torch.float32).to(device) # Shape: [1]
85
+ gender_tensor = torch.tensor([int(gender)], dtype=torch.long).to(device) # Shape: [1]
86
+
87
+ predictions = []
88
+ for features in features_list:
89
+ features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
90
+ validate_inputs(features_tensor)
91
+
92
+ with torch.no_grad():
93
+ prediction = scripted_model(features_tensor, age_tensor, gender_tensor)
94
+ probability = torch.sigmoid(prediction).item()
95
+
96
+ result = f"Prediction: {'Autism' if probability > 0.5 else 'No Autism'} (Confidence: {probability:.2%})"
97
+ predictions.append(result)
98
+
99
+ return "\n".join(predictions)
100
+
101
+ except Exception as e:
102
+ return f"Error: {str(e)}"
103
+
104
+ # Gradio interface
105
+ iface = gr.Interface(
106
+ fn=predict_autism,
107
+ inputs=[
108
+ gr.File(label="Upload preprocessed fMRI files (.nii.gz)", file_count="multiple"),
109
+ gr.Number(label="Age", minimum=0, maximum=120),
110
+ gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)"),
111
+ ],
112
+ outputs=gr.Text(label="Prediction Result"),
113
+ title="Autism Prediction from fMRI Data",
114
+ description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
115
+ theme="default",
116
+ flagging_mode="never"
117
+ )
118
+
119
+ iface.launch()