ShahdReda commited on
Commit
bb92fcc
·
verified ·
1 Parent(s): 2f18ba9

Upload 4 files

Browse files
Files changed (4) hide show
  1. aal_mask_pad.nii.gz +3 -0
  2. app.py +131 -0
  3. requirements.txt +5 -0
  4. svm_pipeline.pkl +3 -0
aal_mask_pad.nii.gz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:230ad628e55f722a2ebaba02bebd913e7d523ab3725fc8aa2891f052ec9ec43f
3
+ size 12050
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import nibabel as nib
3
+ import numpy as np
4
+ import os
5
+ import shutil
6
+ import pickle
7
+ import pandas as pd
8
+
9
+ # Function to load the model from the pickle file
10
+ def load_model():
11
+ with open('svm_pipeline.pkl', 'rb') as f:
12
+ return pickle.load(f)
13
+
14
+ # Load the trained model
15
+ model = load_model()
16
+
17
+ # Function to load image data from a filepath
18
+ def get_image_data(filepath):
19
+ '''
20
+ Access the floating point data of an image
21
+
22
+ Input: Filepath to the image
23
+
24
+ Output: The image's floating point data
25
+ '''
26
+ img = nib.load(filepath)
27
+ data = img.get_fdata()
28
+ return data
29
+
30
+ # Function to create a vector from a region by time matrix from an image using the atlas
31
+ def image_to_vector(image_data, atlas_data):
32
+ '''
33
+ Create a vector from a region by time matrix from an image using the atlas
34
+
35
+ Input:
36
+ - Data for the image to take points of
37
+ - Data from the atlas to apply to the image data
38
+
39
+ Output: A vector of the image's region by time matrix
40
+ '''
41
+ # Assuming the time dimension is the last dimension in the image data
42
+ time_dim = image_data.shape[-1]
43
+ column_names = [f'time_{i}' for i in range(time_dim)]
44
+ region_names = [f'region_{region}' for region in np.unique(atlas_data)]
45
+
46
+ # Reshape the image data to 2D (voxels x time)
47
+ reshaped_image_data = image_data.reshape(-1, time_dim)
48
+
49
+ # Create DataFrame with image data
50
+ df_times = pd.DataFrame(reshaped_image_data, columns=column_names)
51
+
52
+ # Reshape the atlas data to 1D (voxels)
53
+ reshaped_atlas_data = atlas_data.reshape(-1)
54
+
55
+ # Combine atlas regions with image data
56
+ df_full = pd.concat([pd.Series(reshaped_atlas_data, name='atlas_region'), df_times], axis=1)
57
+
58
+ # Group by atlas region and compute mean over time
59
+ regions_x_time = df_full.groupby('atlas_region').mean()
60
+ regions_x_time.index = region_names
61
+
62
+ # Flatten the region x time matrix to a vector
63
+ regions_x_time_vector = regions_x_time.to_numpy().reshape(-1)
64
+ return regions_x_time_vector
65
+
66
+ # Function to preprocess the input image and extract features
67
+ def preprocess_and_extract_features(nifti_data, atlas_data):
68
+ '''
69
+ Preprocess the input image data and extract features using the atlas.
70
+
71
+ Input:
72
+ - nifti_data: The NIfTI image data
73
+ - atlas_data: The atlas data
74
+
75
+ Output: Extracted feature vector
76
+ '''
77
+ features = image_to_vector(nifti_data, atlas_data)
78
+ num_required_features = 116
79
+
80
+ # If fewer features are found, pad with zeros; if more, truncate
81
+ if features.size < num_required_features:
82
+ features = np.pad(features, (0, num_required_features - features.size), 'constant')
83
+ else:
84
+ features = features[:num_required_features]
85
+
86
+ return features.reshape(1, -1)
87
+
88
+ def predict_region(input_file):
89
+ temp_file_path = None # Initialize temp_file_path to None
90
+ try:
91
+ # Create a temporary file with the correct extension
92
+ temp_file_path = input_file.name + ".nii.gz"
93
+ shutil.copy(input_file.name, temp_file_path)
94
+
95
+ # Load the NIfTI file and the atlas
96
+ img = nib.load(temp_file_path)
97
+ data = img.get_fdata()
98
+
99
+ # Path to the atlas file
100
+ atlas_filepath = 'aal_mask_pad.nii.gz' # Corrected file extension
101
+ if not os.path.exists(atlas_filepath):
102
+ raise FileNotFoundError(f"Atlas file not found at: {atlas_filepath}")
103
+
104
+ atlas_data = get_image_data(atlas_filepath)
105
+
106
+ # Preprocess and extract features
107
+ features = preprocess_and_extract_features(data, atlas_data)
108
+
109
+ # Predict using the loaded model
110
+ prediction = model.predict(features)
111
+ return str(prediction[0])
112
+ except Exception as e:
113
+ return f"Error: {e}"
114
+ finally:
115
+ # Clean up the temporary file
116
+ if temp_file_path and os.path.exists(temp_file_path):
117
+ os.remove(temp_file_path)
118
+
119
+
120
+ # Create Gradio interface
121
+ interface = gr.Interface(
122
+ fn=predict_region,
123
+ inputs=gr.File(label="Region Image (NIfTI file)"),
124
+ outputs="text",
125
+ title="Region Prediction",
126
+ description="Upload a region image in NIfTI format to get the prediction.",
127
+ allow_flagging="never" # Disable flagging
128
+ )
129
+
130
+ # Launch the Gradio interface
131
+ interface.launch(share=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ nibabel
3
+ numpy
4
+ scikit-learn
5
+ pandas
svm_pipeline.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2cabd035b158794778139cfbefa402a17f4038ab2a8e26f6144956d68b8c8b8c
3
+ size 351512