Aakash-Tripathi commited on
Commit
7ccd367
·
verified ·
1 Parent(s): 91ebb5b

Delete image_processing_sybil.py

Browse files
Files changed (1) hide show
  1. image_processing_sybil.py +0 -315
image_processing_sybil.py DELETED
@@ -1,315 +0,0 @@
1
- """Image processor for Sybil CT scan preprocessing"""
2
-
3
- import numpy as np
4
- import torch
5
- from typing import Dict, List, Optional, Union, Tuple
6
- from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
7
- from transformers.utils import TensorType
8
- import pydicom
9
- from PIL import Image
10
- import torchio as tio
11
-
12
-
13
- def order_slices(dicoms: List) -> List:
14
- """Order DICOM slices by their position"""
15
- # Sort by ImagePositionPatient if available
16
- try:
17
- dicoms = sorted(dicoms, key=lambda x: float(x.ImagePositionPatient[2]))
18
- except (AttributeError, TypeError):
19
- # Fall back to InstanceNumber if ImagePositionPatient not available
20
- try:
21
- dicoms = sorted(dicoms, key=lambda x: int(x.InstanceNumber))
22
- except (AttributeError, TypeError):
23
- pass # Keep original order if neither attribute is available
24
- return dicoms
25
-
26
-
27
- class SybilImageProcessor(BaseImageProcessor):
28
- """
29
- Constructs a Sybil image processor for preprocessing CT scans.
30
-
31
- Args:
32
- voxel_spacing (`List[float]`, *optional*, defaults to `[0.703125, 0.703125, 2.5]`):
33
- Target voxel spacing for resampling (row, column, slice thickness).
34
- img_size (`List[int]`, *optional*, defaults to `[512, 512]`):
35
- Target image size after resizing.
36
- num_images (`int`, *optional*, defaults to `208`):
37
- Number of slices to use from the CT scan.
38
- windowing (`Dict[str, float]`, *optional*):
39
- Windowing parameters for CT scan visualization.
40
- Default uses lung window: center=-600, width=1500.
41
- normalize (`bool`, *optional*, defaults to `True`):
42
- Whether to normalize pixel values to [0, 1].
43
- **kwargs:
44
- Additional keyword arguments passed to the parent class.
45
- """
46
-
47
- model_input_names = ["pixel_values"]
48
-
49
- def __init__(
50
- self,
51
- voxel_spacing: List[float] = None,
52
- img_size: List[int] = None,
53
- num_images: int = 208,
54
- windowing: Dict[str, float] = None,
55
- normalize: bool = True,
56
- **kwargs
57
- ):
58
- super().__init__(**kwargs)
59
-
60
- self.voxel_spacing = voxel_spacing if voxel_spacing is not None else [0.703125, 0.703125, 2.5]
61
- self.img_size = img_size if img_size is not None else [512, 512]
62
- self.num_images = num_images
63
-
64
- # Default lung window settings
65
- self.windowing = windowing if windowing is not None else {
66
- "center": -600,
67
- "width": 1500
68
- }
69
- self.normalize = normalize
70
-
71
- # TorchIO transforms for standardization
72
- self.resample_transform = tio.transforms.Resample(target=self.voxel_spacing)
73
- # Note: Original Sybil uses 200 depth, 256x256 images
74
- self.default_depth = 200
75
- self.default_size = [256, 256]
76
- self.padding_transform = tio.transforms.CropOrPad(
77
- target_shape=(self.default_depth, *self.default_size),
78
- padding_mode=0
79
- )
80
-
81
- def load_dicom_series(self, paths: List[str]) -> Tuple[np.ndarray, Dict]:
82
- """
83
- Load a series of DICOM files.
84
-
85
- Args:
86
- paths: List of paths to DICOM files.
87
-
88
- Returns:
89
- Tuple of (volume array, metadata dict)
90
- """
91
- dicoms = []
92
- for path in paths:
93
- try:
94
- dcm = pydicom.dcmread(path, stop_before_pixels=False)
95
- dicoms.append(dcm)
96
- except Exception as e:
97
- print(f"Error reading DICOM file {path}: {e}")
98
- continue
99
-
100
- if not dicoms:
101
- raise ValueError("No valid DICOM files found")
102
-
103
- # Order slices by position
104
- dicoms = order_slices(dicoms)
105
-
106
- # Extract pixel arrays
107
- volume = np.stack([dcm.pixel_array.astype(np.float32) for dcm in dicoms])
108
-
109
- # Extract metadata
110
- metadata = {
111
- "slice_thickness": float(dicoms[0].SliceThickness) if hasattr(dicoms[0], 'SliceThickness') else None,
112
- "pixel_spacing": list(map(float, dicoms[0].PixelSpacing)) if hasattr(dicoms[0], 'PixelSpacing') else None,
113
- "manufacturer": str(dicoms[0].Manufacturer) if hasattr(dicoms[0], 'Manufacturer') else None,
114
- "num_slices": len(dicoms)
115
- }
116
-
117
- # Apply rescale if present
118
- if hasattr(dicoms[0], 'RescaleSlope') and hasattr(dicoms[0], 'RescaleIntercept'):
119
- slope = float(dicoms[0].RescaleSlope)
120
- intercept = float(dicoms[0].RescaleIntercept)
121
- volume = volume * slope + intercept
122
-
123
- return volume, metadata
124
-
125
- def load_png_series(self, paths: List[str]) -> np.ndarray:
126
- """
127
- Load a series of PNG files.
128
-
129
- Args:
130
- paths: List of paths to PNG files (must be in anatomical order).
131
-
132
- Returns:
133
- 3D volume array
134
- """
135
- images = []
136
- for path in paths:
137
- img = Image.open(path).convert('L') # Convert to grayscale
138
- images.append(np.array(img, dtype=np.float32))
139
-
140
- return np.stack(images)
141
-
142
- def apply_windowing(self, volume: np.ndarray) -> np.ndarray:
143
- """
144
- Apply windowing to CT scan for better visualization.
145
-
146
- Args:
147
- volume: 3D CT volume.
148
-
149
- Returns:
150
- Windowed volume.
151
- """
152
- center = self.windowing["center"]
153
- width = self.windowing["width"]
154
-
155
- # Calculate window boundaries
156
- lower = center - width / 2
157
- upper = center + width / 2
158
-
159
- # Apply windowing
160
- volume = np.clip(volume, lower, upper)
161
-
162
- # Normalize to [0, 1] if requested
163
- if self.normalize:
164
- volume = (volume - lower) / (upper - lower)
165
-
166
- return volume
167
-
168
- def resample_volume(
169
- self,
170
- volume: torch.Tensor,
171
- original_spacing: Optional[List[float]] = None
172
- ) -> torch.Tensor:
173
- """
174
- Resample volume to target voxel spacing.
175
-
176
- Args:
177
- volume: 3D volume tensor.
178
- original_spacing: Original voxel spacing.
179
-
180
- Returns:
181
- Resampled volume.
182
- """
183
- # Create TorchIO subject
184
- subject = tio.Subject(
185
- image=tio.ScalarImage(tensor=volume.unsqueeze(0), spacing=original_spacing)
186
- )
187
-
188
- # Apply resampling
189
- resampled = self.resample_transform(subject)
190
-
191
- return resampled['image'].data.squeeze(0)
192
-
193
- def pad_or_crop_volume(self, volume: torch.Tensor) -> torch.Tensor:
194
- """
195
- Pad or crop volume to target shape.
196
-
197
- Args:
198
- volume: 3D volume tensor.
199
-
200
- Returns:
201
- Padded/cropped volume.
202
- """
203
- # Create TorchIO subject
204
- subject = tio.Subject(
205
- image=tio.ScalarImage(tensor=volume.unsqueeze(0))
206
- )
207
-
208
- # Apply padding/cropping
209
- transformed = self.padding_transform(subject)
210
-
211
- return transformed['image'].data.squeeze(0)
212
-
213
- def preprocess(
214
- self,
215
- images: Union[List[str], np.ndarray, torch.Tensor],
216
- file_type: str = "dicom",
217
- voxel_spacing: Optional[List[float]] = None,
218
- return_tensors: Optional[Union[str, TensorType]] = None,
219
- **kwargs
220
- ) -> BatchFeature:
221
- """
222
- Preprocess CT scan images.
223
-
224
- Args:
225
- images: Either list of file paths or numpy/torch array of images.
226
- file_type: Type of input files ("dicom" or "png").
227
- voxel_spacing: Original voxel spacing (required for PNG files).
228
- return_tensors: The type of tensors to return.
229
-
230
- Returns:
231
- BatchFeature with preprocessed images.
232
- """
233
- # Load images if paths are provided
234
- if isinstance(images, list) and isinstance(images[0], str):
235
- if file_type == "dicom":
236
- volume, metadata = self.load_dicom_series(images)
237
- if voxel_spacing is None and metadata["pixel_spacing"]:
238
- voxel_spacing = metadata["pixel_spacing"] + [metadata["slice_thickness"]]
239
- elif file_type == "png":
240
- if voxel_spacing is None:
241
- raise ValueError("voxel_spacing must be provided for PNG files")
242
- volume = self.load_png_series(images)
243
- else:
244
- raise ValueError(f"Unknown file type: {file_type}")
245
- elif isinstance(images, (np.ndarray, torch.Tensor)):
246
- volume = images
247
- else:
248
- raise ValueError("Images must be file paths, numpy array, or torch tensor")
249
-
250
- # Convert to torch tensor
251
- if isinstance(volume, np.ndarray):
252
- volume = torch.from_numpy(volume).float()
253
-
254
- # Apply windowing
255
- if isinstance(volume, torch.Tensor):
256
- volume_np = volume.numpy()
257
- else:
258
- volume_np = volume
259
- volume_np = self.apply_windowing(volume_np)
260
- volume = torch.from_numpy(volume_np).float()
261
-
262
- # Resample if spacing is provided
263
- if voxel_spacing is not None:
264
- volume = self.resample_volume(volume, voxel_spacing)
265
-
266
- # Pad or crop to target shape
267
- volume = self.pad_or_crop_volume(volume)
268
-
269
- # Reshape to match original Sybil format: (D, H, W) -> (C, D, H, W)
270
- # The model expects 3 channels (RGB format), so repeat grayscale to 3 channels
271
- volume = volume.unsqueeze(0).repeat(3, 1, 1, 1) # Now (3, D, H, W)
272
-
273
- # Prepare output
274
- data = {"pixel_values": volume}
275
-
276
- # Convert to requested tensor type
277
- if return_tensors == "pt":
278
- return BatchFeature(data=data, tensor_type=TensorType.PYTORCH)
279
- elif return_tensors == "np":
280
- data = {k: v.numpy() for k, v in data.items()}
281
- return BatchFeature(data=data, tensor_type=TensorType.NUMPY)
282
- else:
283
- return BatchFeature(data=data)
284
-
285
- def __call__(
286
- self,
287
- images: Union[List[str], List[List[str]], np.ndarray, torch.Tensor],
288
- **kwargs
289
- ) -> BatchFeature:
290
- """
291
- Main method to prepare images for the model.
292
-
293
- Args:
294
- images: Images to preprocess. Can be:
295
- - List of file paths for a single series
296
- - List of lists of file paths for multiple series
297
- - Numpy array or torch tensor
298
-
299
- Returns:
300
- BatchFeature with preprocessed images ready for model input.
301
- """
302
- # Handle batch processing
303
- if isinstance(images, list) and images and isinstance(images[0], list):
304
- # Multiple series
305
- batch_volumes = []
306
- for series_paths in images:
307
- result = self.preprocess(series_paths, **kwargs)
308
- batch_volumes.append(result["pixel_values"])
309
-
310
- # Stack into batch (B, C, D, H, W)
311
- pixel_values = torch.stack(batch_volumes)
312
- return BatchFeature(data={"pixel_values": pixel_values})
313
- else:
314
- # Single series
315
- return self.preprocess(images, **kwargs)