tp53(ashish) commited on
Commit
231a8da
·
1 Parent(s): d53ae27

Add ZeroGPU decorator for GPU function detection

Browse files
Files changed (1) hide show
  1. app.py +74 -57
app.py CHANGED
@@ -83,72 +83,89 @@ def keep_largest_component(mask):
83
  return mask
84
 
85
 
86
- def segment_slice_api(
87
- nifti_b64: str,
88
- slice_idx: int,
89
- text_prompt: str = "tumor",
90
- checkpoint: str = "brain",
91
- ):
92
- """
93
- Segment a single slice from a NIfTI volume.
94
- Currently uses fallback segmentation (SAM3 to be integrated).
95
- """
96
- start_time = time.time()
97
-
98
- try:
99
- import nibabel as nib
100
-
101
- # Decode NIfTI
102
- nifti_bytes = base64.b64decode(nifti_b64)
103
-
104
- with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as f:
105
- f.write(nifti_bytes)
106
- temp_path = f.name
107
-
108
- nii = nib.load(temp_path)
109
- volume = nii.get_fdata().astype(np.float32)
110
- os.unlink(temp_path)
111
 
112
- logger.info(
113
- f"Loaded volume shape: {volume.shape}, segmenting slice {slice_idx}"
114
- )
115
-
116
- # Validate slice index
117
- if slice_idx < 0 or slice_idx >= volume.shape[0]:
118
- return {
119
- "success": False,
120
- "error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
121
- }
122
 
123
- # Extract slice and segment
124
- slice_2d = volume[slice_idx]
125
- mask = fallback_segment(slice_2d)
126
- mask = keep_largest_component(mask)
127
 
128
- # Extract contours
129
- contours = find_contours(mask)
 
130
 
131
- # Encode mask as base64
132
- mask_b64 = base64.b64encode(mask.tobytes()).decode()
 
133
 
134
- inference_time = int((time.time() - start_time) * 1000)
135
- logger.info(
136
- f"Segmented slice {slice_idx} in {inference_time}ms, mask sum: {mask.sum()}"
137
- )
138
 
 
 
139
  return {
140
- "success": True,
141
- "backend": "fallback",
142
- "mask_b64": mask_b64,
143
- "mask_shape": list(mask.shape),
144
- "contours": contours,
145
- "slice_idx": slice_idx,
146
- "inference_time_ms": inference_time,
147
  }
148
 
149
- except Exception as e:
150
- logger.error(f"Segmentation failed: {e}")
151
- return {"success": False, "error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
 
154
  def load_and_display_nifti(file):
 
83
  return mask
84
 
85
 
86
+ def _do_segment(nifti_b64, slice_idx, text_prompt, checkpoint):
87
+ """Core segmentation logic."""
88
+ import nibabel as nib
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
+ start_time = time.time()
 
 
 
 
 
 
 
 
 
91
 
92
+ # Decode NIfTI
93
+ nifti_bytes = base64.b64decode(nifti_b64)
 
 
94
 
95
+ with tempfile.NamedTemporaryFile(suffix=".nii.gz", delete=False) as f:
96
+ f.write(nifti_bytes)
97
+ temp_path = f.name
98
 
99
+ nii = nib.load(temp_path)
100
+ volume = nii.get_fdata().astype(np.float32)
101
+ os.unlink(temp_path)
102
 
103
+ logger.info(f"Loaded volume shape: {volume.shape}, segmenting slice {slice_idx}")
 
 
 
104
 
105
+ # Validate slice index
106
+ if slice_idx < 0 or slice_idx >= volume.shape[0]:
107
  return {
108
+ "success": False,
109
+ "error": f"Slice index {slice_idx} out of range [0, {volume.shape[0]})",
 
 
 
 
 
110
  }
111
 
112
+ # Extract slice and segment
113
+ slice_2d = volume[slice_idx]
114
+ mask = fallback_segment(slice_2d)
115
+ mask = keep_largest_component(mask)
116
+
117
+ # Extract contours
118
+ contours = find_contours(mask)
119
+
120
+ # Encode mask as base64
121
+ mask_b64 = base64.b64encode(mask.tobytes()).decode()
122
+
123
+ inference_time = int((time.time() - start_time) * 1000)
124
+ logger.info(
125
+ f"Segmented slice {slice_idx} in {inference_time}ms, mask sum: {mask.sum()}"
126
+ )
127
+
128
+ return {
129
+ "success": True,
130
+ "backend": "fallback",
131
+ "mask_b64": mask_b64,
132
+ "mask_shape": list(mask.shape),
133
+ "contours": contours,
134
+ "slice_idx": slice_idx,
135
+ "inference_time_ms": inference_time,
136
+ }
137
+
138
+
139
+ # Create GPU-decorated and non-GPU versions of the API
140
+ if ZEROGPU_AVAILABLE:
141
+
142
+ @spaces.GPU(duration=30)
143
+ def segment_slice_api(
144
+ nifti_b64: str,
145
+ slice_idx: int,
146
+ text_prompt: str = "tumor",
147
+ checkpoint: str = "brain",
148
+ ):
149
+ """Segment a single slice (GPU-accelerated when available)."""
150
+ try:
151
+ return _do_segment(nifti_b64, slice_idx, text_prompt, checkpoint)
152
+ except Exception as e:
153
+ logger.error(f"Segmentation failed: {e}")
154
+ return {"success": False, "error": str(e)}
155
+ else:
156
+
157
+ def segment_slice_api(
158
+ nifti_b64: str,
159
+ slice_idx: int,
160
+ text_prompt: str = "tumor",
161
+ checkpoint: str = "brain",
162
+ ):
163
+ """Segment a single slice."""
164
+ try:
165
+ return _do_segment(nifti_b64, slice_idx, text_prompt, checkpoint)
166
+ except Exception as e:
167
+ logger.error(f"Segmentation failed: {e}")
168
+ return {"success": False, "error": str(e)}
169
 
170
 
171
  def load_and_display_nifti(file):