JayLacoma commited on
Commit
876bb97
·
verified ·
1 Parent(s): 8d06132

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -30
app.py CHANGED
@@ -1,21 +1,17 @@
1
  import torch
2
  import gradio as gr
 
3
  from nilearn import datasets
4
  from nilearn.connectome import ConnectivityMeasure
5
  from nilearn.maskers import MultiNiftiMapsMasker
6
  import numpy as np
7
 
8
- #device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
-
10
- # Force torch to use CPU only
11
- device = torch.device("cpu")
12
-
13
-
14
 
 
15
  try:
16
- scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location="cpu")
17
 
18
- # If the model is wrapped in DataParallel, unwrap it
19
  if isinstance(scripted_model, torch.nn.DataParallel):
20
  scripted_model = scripted_model.module
21
 
@@ -26,7 +22,7 @@ except Exception as e:
26
  exit(1)
27
 
28
  # Fetch atlas (e.g., DiFuMo)
29
- dim = 64 # Number of ROIs
30
  try:
31
  difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2, legacy_format=False)
32
  atlas_filename = difumo.maps
@@ -34,7 +30,7 @@ except Exception as e:
34
  print(f"Error fetching atlas: {str(e)}")
35
  exit(1)
36
 
37
- # Create masker to extract features
38
  masker = MultiNiftiMapsMasker(
39
  maps_img=atlas_filename,
40
  standardize=True,
@@ -45,13 +41,12 @@ masker = MultiNiftiMapsMasker(
45
  # Connectivity measure
46
  connectome_measure = ConnectivityMeasure(kind='correlation', vectorize=True, discard_diagonal=True)
47
 
48
- # Modified feature extraction function
49
  def extract_features_multiple(func_preproc_files):
50
  all_features = []
51
  if not func_preproc_files:
52
  return all_features
53
 
54
- # Fit the masker on the first subject
55
  print("Fitting masker on the first subject...")
56
  masker.fit(func_preproc_files[0])
57
 
@@ -64,42 +59,66 @@ def extract_features_multiple(func_preproc_files):
64
  print("All subjects processed.")
65
  return all_features
66
 
67
- # Prediction function with error handling
68
- def validate_inputs(features_tensor):
69
- if features_tensor.shape[1] != 2016:
70
- raise ValueError(f"Expected 2016 features but got {features_tensor.shape[1]}")
71
- if features_tensor.dim() != 2:
72
- raise ValueError(f"Expected 2D tensor but got {features_tensor.dim()}D")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
 
74
  def predict_autism(fmri_files, age, gender):
75
  try:
76
  if not fmri_files:
77
- return "Please upload at least one valid .nii.gz file."
78
 
79
  features_list = extract_features_multiple(fmri_files)
80
  if not features_list:
81
- return "Error: Failed to extract features from the fMRI files."
82
 
83
- # CORRECTED TENSOR SHAPES
84
- age_tensor = torch.tensor([float(age)], dtype=torch.float32).to(device) # Shape: [1]
85
- gender_tensor = torch.tensor([int(gender)], dtype=torch.long).to(device) # Shape: [1]
86
 
87
  predictions = []
 
 
88
  for features in features_list:
89
  features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
90
- validate_inputs(features_tensor)
91
-
92
  with torch.no_grad():
93
  prediction = scripted_model(features_tensor, age_tensor, gender_tensor)
94
  probability = torch.sigmoid(prediction).item()
95
 
96
  result = f"Prediction: {'Autism' if probability > 0.5 else 'No Autism'} (Confidence: {probability:.2%})"
97
  predictions.append(result)
 
 
 
98
 
99
- return "\n".join(predictions)
100
-
101
  except Exception as e:
102
- return f"Error: {str(e)}"
103
 
104
  # Gradio interface
105
  iface = gr.Interface(
@@ -109,11 +128,14 @@ iface = gr.Interface(
109
  gr.Number(label="Age", minimum=0, maximum=120),
110
  gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)"),
111
  ],
112
- outputs=gr.Text(label="Prediction Result"),
 
 
 
113
  title="Autism Prediction from fMRI Data",
114
  description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
115
  theme="default",
116
  flagging_mode="never"
117
  )
118
 
119
- iface.launch()
 
1
  import torch
2
  import gradio as gr
3
+ import plotly.graph_objects as go
4
  from nilearn import datasets
5
  from nilearn.connectome import ConnectivityMeasure
6
  from nilearn.maskers import MultiNiftiMapsMasker
7
  import numpy as np
8
 
9
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
10
 
11
+ # Load the model
12
  try:
13
+ scripted_model = torch.jit.load("fmri_encoder_commercial.pt", map_location=device)
14
 
 
15
  if isinstance(scripted_model, torch.nn.DataParallel):
16
  scripted_model = scripted_model.module
17
 
 
22
  exit(1)
23
 
24
  # Fetch atlas (e.g., DiFuMo)
25
+ dim = 64
26
  try:
27
  difumo = datasets.fetch_atlas_difumo(dimension=dim, resolution_mm=2, legacy_format=False)
28
  atlas_filename = difumo.maps
 
30
  print(f"Error fetching atlas: {str(e)}")
31
  exit(1)
32
 
33
+ # Create masker
34
  masker = MultiNiftiMapsMasker(
35
  maps_img=atlas_filename,
36
  standardize=True,
 
41
  # Connectivity measure
42
  connectome_measure = ConnectivityMeasure(kind='correlation', vectorize=True, discard_diagonal=True)
43
 
44
+ # Feature extraction function
45
  def extract_features_multiple(func_preproc_files):
46
  all_features = []
47
  if not func_preproc_files:
48
  return all_features
49
 
 
50
  print("Fitting masker on the first subject...")
51
  masker.fit(func_preproc_files[0])
52
 
 
59
  print("All subjects processed.")
60
  return all_features
61
 
62
+ # Function to generate a Plotly probability plot
63
+ def plot_probability(probability):
64
+ labels = ["No Autism", "Autism"]
65
+ probs = [1 - probability, probability]
66
+ colors = ["#6a0dad", "#d896ff"] # Dark purple and light purple
67
+
68
+ fig = go.Figure()
69
+
70
+ fig.add_trace(go.Bar(
71
+ x=labels,
72
+ y=probs,
73
+ marker=dict(color=colors),
74
+ text=[f"{(1-probability)*100:.1f}%", f"{probability*100:.1f}%"],
75
+ textposition="auto",
76
+ ))
77
+
78
+ fig.update_layout(
79
+ title="Autism Prediction Probability",
80
+ paper_bgcolor="black",
81
+ plot_bgcolor="black",
82
+ font=dict(color="white"),
83
+ xaxis=dict(title="Diagnosis", showgrid=False),
84
+ yaxis=dict(title="Probability", showgrid=True, gridcolor="gray"),
85
+ )
86
+
87
+ return fig
88
 
89
+ # Prediction function
90
  def predict_autism(fmri_files, age, gender):
91
  try:
92
  if not fmri_files:
93
+ return "Please upload at least one valid .nii.gz file.", None
94
 
95
  features_list = extract_features_multiple(fmri_files)
96
  if not features_list:
97
+ return "Error: Failed to extract features from the fMRI files.", None
98
 
99
+ age_tensor = torch.tensor([float(age)], dtype=torch.float32).to(device)
100
+ gender_tensor = torch.tensor([int(gender)], dtype=torch.long).to(device)
 
101
 
102
  predictions = []
103
+ plots = []
104
+
105
  for features in features_list:
106
  features_tensor = torch.tensor(features, dtype=torch.float32).unsqueeze(0).to(device)
107
+
 
108
  with torch.no_grad():
109
  prediction = scripted_model(features_tensor, age_tensor, gender_tensor)
110
  probability = torch.sigmoid(prediction).item()
111
 
112
  result = f"Prediction: {'Autism' if probability > 0.5 else 'No Autism'} (Confidence: {probability:.2%})"
113
  predictions.append(result)
114
+
115
+ # Generate Plotly probability plot
116
+ plots.append(plot_probability(probability))
117
 
118
+ return "\n".join(predictions), plots[0] # Return text and Plotly figure
119
+
120
  except Exception as e:
121
+ return f"Error: {str(e)}", None
122
 
123
  # Gradio interface
124
  iface = gr.Interface(
 
128
  gr.Number(label="Age", minimum=0, maximum=120),
129
  gr.Radio(["0", "1"], label="Gender (0: Female, 1: Male)"),
130
  ],
131
+ outputs=[
132
+ gr.Text(label="Prediction Result"),
133
+ gr.Plot(label="Prediction Probability Plot"),
134
+ ],
135
  title="Autism Prediction from fMRI Data",
136
  description="Upload one or more preprocessed fMRI files (.nii.gz) and enter the subject's age and gender to predict autism.",
137
  theme="default",
138
  flagging_mode="never"
139
  )
140
 
141
+ iface.launch()