File size: 9,476 Bytes
fa5d00b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
"""
DICOM utilities for processing medical imaging studies.
"""
import io
import zipfile
from typing import List, Tuple, Dict, Optional
import numpy as np
from PIL import Image
import pydicom


def has_pixel_data(ds: pydicom.Dataset) -> bool:
    """Return True if ``ds`` contains any pixel-data element.

    The three element keywords checked are ``PixelData``, ``FloatPixelData``
    and ``DoubleFloatPixelData``. A dataset with none of them has no image to
    render (e.g. a structured report).
    """
    pixel_keywords = ('PixelData', 'FloatPixelData', 'DoubleFloatPixelData')
    return any(keyword in ds for keyword in pixel_keywords)


def extract_dicom_from_zip(zip_bytes: bytes) -> List[Tuple[str, pydicom.Dataset]]:
    """Parse every ``.dcm`` member of an in-memory ZIP archive.

    The extension match is case-insensitive. Members that fail to parse are
    reported on stdout and skipped. Parsed datasets without pixel data (SR,
    reports, dose records, ...) are also reported and skipped.

    Args:
        zip_bytes: Raw bytes of the ZIP archive.

    Returns:
        ``(member_name, dataset)`` pairs in archive (``namelist``) order.

    Raises:
        zipfile.BadZipFile: if ``zip_bytes`` is not a readable ZIP archive.
    """
    collected = []

    with zipfile.ZipFile(io.BytesIO(zip_bytes), 'r') as archive:
        candidates = [name for name in archive.namelist() if name.lower().endswith('.dcm')]
        for member in candidates:
            try:
                dataset = pydicom.dcmread(io.BytesIO(archive.read(member)))
                if not has_pixel_data(dataset):
                    print(f"Skipping {member}: No pixel data (likely SR or report)")
                    continue
                collected.append((member, dataset))
            except Exception as e:
                # Best effort: one unreadable member must not abort the archive.
                print(f"Error reading {member}: {e}")

    return collected


def get_modality(ds: pydicom.Dataset) -> str:
    """Return the ``Modality`` attribute of ``ds``, or ``'Unknown'`` if it is absent."""
    try:
        return ds.Modality
    except AttributeError:
        return 'Unknown'


def get_study_info(ds: pydicom.Dataset, total_slices: int) -> Dict:
    """Build a study-level summary dict from one representative dataset.

    Any missing identifying attribute is reported as ``'Unknown'``. The
    ``TotalSlices`` entry is the ``total_slices`` argument, not derived from ``ds``.
    """
    placeholder = 'Unknown'
    summary = {
        key: getattr(ds, key, placeholder)
        for key in ('StudyInstanceUID', 'StudyDescription')
    }
    summary['Modality'] = get_modality(ds)
    summary['TotalSlices'] = total_slices
    for key in ('StudyDate', 'PatientID'):
        summary[key] = getattr(ds, key, placeholder)
    return summary


def get_default_window(ds: pydicom.Dataset) -> Tuple[Optional[float], Optional[float]]:
    """Get default window center and width from DICOM metadata.

    Reads ``WindowCenter`` and ``WindowWidth`` from ``ds``. When an attribute
    holds several values (any non-string iterable), only the first one is used.

    Returns:
        ``(center, width)`` as floats. Each element is ``None`` when the
        attribute is missing, empty, or not convertible to ``float``. A
        malformed header therefore means "no default window" instead of an
        exception that would abort processing of the whole study.
    """

    def _first_float(value) -> Optional[float]:
        # Multi-valued element: keep the first entry (None if the sequence is empty).
        if value is not None and hasattr(value, '__iter__') and not isinstance(value, (str, bytes)):
            value = next(iter(value), None)
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            # e.g. an empty string or a non-numeric value
            return None

    return (
        _first_float(getattr(ds, 'WindowCenter', None)),
        _first_float(getattr(ds, 'WindowWidth', None)),
    )


def apply_windowing(
    pixel_array: np.ndarray,
    ds: pydicom.Dataset,
    window_center: Optional[float] = None,
    window_width: Optional[float] = None
) -> np.ndarray:
    """Convert stored pixel values into an 8-bit display array.

    Steps:
      1. Rescale: ``value * RescaleSlope + RescaleIntercept`` (converts to HU
         for CT). A missing or empty slope/intercept is treated as 1 / 0.
      2. Window: explicit ``window_center``/``window_width`` arguments take
         precedence over the dataset defaults from ``get_default_window``. If
         both are available and the width is positive, values are clipped to
         ``[center - width/2, center + width/2]`` and scaled linearly to 0..255.
         Otherwise the array's own min..max range is stretched to 0..255
         (all zeros for a constant array).
      3. Polarity: for ``PhotometricInterpretation == 'MONOCHROME1'`` (lowest
         value displayed as white), the result is inverted. Every returned
         array then uses "higher value = brighter".

    Args:
        pixel_array: Stored pixel values (any numeric dtype, any shape).
        ds: Dataset supplying rescale, window and photometric attributes.
        window_center: Optional override of the window center.
        window_width: Optional override of the window width.

    Returns:
        ``uint8`` array with the same shape as ``pixel_array``.
    """
    # Step 1: rescale slope/intercept. A present-but-empty element reads as None,
    # which would otherwise break the arithmetic.
    slope = getattr(ds, 'RescaleSlope', None)
    intercept = getattr(ds, 'RescaleIntercept', None)
    slope = 1.0 if slope is None else float(slope)
    intercept = 0.0 if intercept is None else float(intercept)
    values = pixel_array.astype(np.float32) * slope + intercept

    # Step 2: fill in whichever window parameter was not supplied.
    if window_center is None or window_width is None:
        default_wc, default_ww = get_default_window(ds)
        if window_center is None:
            window_center = default_wc
        if window_width is None:
            window_width = default_ww

    if window_center is not None and window_width is not None and window_width > 0:
        min_val = window_center - window_width / 2
        max_val = window_center + window_width / 2
        values = np.clip(values, min_val, max_val)
        normalized = ((values - min_val) / (max_val - min_val) * 255).astype(np.uint8)
    else:
        # Fallback: normalize to the array's full range
        pixel_min = values.min()
        pixel_max = values.max()
        if pixel_max > pixel_min:
            normalized = ((values - pixel_min) / (pixel_max - pixel_min) * 255).astype(np.uint8)
        else:
            normalized = np.zeros_like(values, dtype=np.uint8)

    # Step 3: MONOCHROME1 stores inverted grey levels; flip so the output is not a negative.
    photometric = str(getattr(ds, 'PhotometricInterpretation', '') or '').strip().upper()
    if photometric == 'MONOCHROME1':
        normalized = np.uint8(255) - normalized

    return normalized


def dicom_to_pil(
    ds: pydicom.Dataset,
    size: Tuple[int, int] = (896, 896),
    window_center: Optional[float] = None,
    window_width: Optional[float] = None
) -> Image.Image:
    """Convert a DICOM dataset to an RGB PIL image of the given size.

    Pixel data is decoded via ``ds.pixel_array`` and windowed with
    ``apply_windowing`` into uint8. For multi-frame data only the first frame
    is rendered. Array layouts handled:

    * ``(rows, cols)``: greyscale.
    * ``(rows, cols, c)`` with ``c <= 4``: ``c >= 3`` keeps the first three
      channels as RGB, ``c`` of 1 or 2 keeps channel 0 as greyscale.
    * ``(frames, rows, cols)`` (3-D with last axis > 4): first frame, greyscale.
    * ``(frames, rows, cols, c)``: first frame, then as above.

    Greyscale output is converted to RGB. The image is resized to ``size``
    with Lanczos resampling; aspect ratio is not preserved.

    Raises:
        ValueError: if the pixel array has a rank that cannot be rendered.
    """
    normalized = apply_windowing(ds.pixel_array, ds, window_center, window_width)

    frame = normalized
    if frame.ndim == 4:
        # Multi-frame colour data: render the first frame only.
        frame = frame[0]
    elif frame.ndim == 3 and frame.shape[2] > 4:
        # A last axis longer than 4 cannot be a channel axis: treat as multi-frame greyscale.
        frame = frame[0]

    # `normalized` is uint8, so PIL infers 'L' for 2-D and 'RGB' for (h, w, 3) input.
    # The explicit `mode=` argument of fromarray is deprecated in recent Pillow.
    if frame.ndim == 2:
        pil_image = Image.fromarray(np.ascontiguousarray(frame))
    elif frame.ndim == 3 and frame.shape[2] >= 3:
        pil_image = Image.fromarray(np.ascontiguousarray(frame[:, :, :3]))
    elif frame.ndim == 3 and frame.shape[2] >= 1:
        pil_image = Image.fromarray(np.ascontiguousarray(frame[:, :, 0]))
    else:
        raise ValueError(f"Unsupported pixel array shape {normalized.shape}")

    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')

    return pil_image.resize(size, Image.LANCZOS)


def organize_by_series(dicom_files: List[Tuple[str, pydicom.Dataset]]) -> Dict[str, List[Tuple[str, pydicom.Dataset]]]:
    """Group ``(filename, dataset)`` pairs by ``SeriesInstanceUID``.

    Datasets lacking the UID are grouped under ``'Unknown'``. Series appear in
    order of first occurrence, and each series keeps its files in input order.
    """
    grouped = {}
    for entry in dicom_files:
        name, dataset = entry
        uid = getattr(dataset, 'SeriesInstanceUID', 'Unknown')
        grouped.setdefault(uid, []).append((name, dataset))
    return grouped


def sort_slices_by_position(series_files: List[Tuple[str, pydicom.Dataset]]) -> List[Tuple[str, pydicom.Dataset]]:
    """Return a new list of the slices in a stable display order.

    Slices are ranked by the first attribute they carry, in priority order:

    1. ``InstanceNumber`` (compared as ``int``)
    2. ``SliceLocation`` (compared as ``float``)
    3. neither: the filename

    All slices with an instance number come first, then those with only a
    slice location, then the rest. Ties keep their input order.
    """
    ordering_rules = (('InstanceNumber', int), ('SliceLocation', float))

    def _rank(item):
        name, dataset = item
        for priority, (attribute, convert) in enumerate(ordering_rules):
            raw = getattr(dataset, attribute, None)
            if raw is not None:
                return (priority, convert(raw))
        return (len(ordering_rules), name)

    ordered = list(series_files)
    ordered.sort(key=_rank)
    return ordered


def sample_slices_evenly(all_slices: List[Tuple[str, pydicom.Dataset]], max_slices: int = 500) -> List[Tuple[str, pydicom.Dataset]]:
    """Pick at most ``max_slices`` items spread evenly across ``all_slices``, keeping order.

    * If the input already fits (``len(all_slices) <= max_slices``), it is
      returned unchanged (the same list object).
    * ``max_slices >= 2``: item ``i`` is taken from index
      ``floor(i * (len - 1) / (max_slices - 1))``, so the first and last items
      are always included.
    * ``max_slices == 1``: the middle item is returned. The formula above would
      divide by zero.
    * ``max_slices <= 0``: an empty list is returned.
    """
    total = len(all_slices)
    if total <= max_slices:
        return all_slices
    if max_slices <= 0:
        return []
    if max_slices == 1:
        # A single representative sample is taken from the centre of the stack.
        return [all_slices[(total - 1) // 2]]

    # Integer floor division gives exact indices with no float rounding at the endpoints.
    last_index = total - 1
    spacing = max_slices - 1
    return [all_slices[i * last_index // spacing] for i in range(max_slices)]


def _select_slices(
    series_dict: Dict[str, List[Tuple[str, pydicom.Dataset]]],
    max_slices: int,
    max_slices_per_series: Optional[int],
) -> List[Tuple[str, pydicom.Dataset]]:
    """Sort every series once, then sample per series or across the concatenated study."""
    sorted_series = [sort_slices_by_position(files) for files in series_dict.values()]

    if max_slices_per_series is not None:
        # Each series is capped independently at max_slices_per_series.
        selected = []
        for series in sorted_series:
            selected.extend(sample_slices_evenly(series, max_slices_per_series))
        return selected

    # Global cap: concatenate series in encounter order, then sample the combined stack.
    combined = []
    for series in sorted_series:
        combined.extend(series)
    return sample_slices_evenly(combined, max_slices)


def process_dicom_study(
    zip_bytes: bytes,
    max_slices: int = 500,
    max_slices_per_series: Optional[int] = None,
    image_size: int = 896,
    window_center: Optional[float] = None,
    window_width: Optional[float] = None
) -> Tuple[str, List[Image.Image], Dict]:
    """
    Process a DICOM study from a ZIP file.

    Args:
        zip_bytes: ZIP file contents
        max_slices: Maximum total slices across all series (used if max_slices_per_series is None)
        max_slices_per_series: If set, sample this many slices evenly from each series
        image_size: Output image size (square, e.g., 896 for 896x896)
        window_center: Window center for display (None = use DICOM default or auto)
        window_width: Window width for display (None = use DICOM default or auto)

    Returns:
        ``(modality, images, study_info)``:

        * ``modality``: modality of the first readable DICOM file.
        * ``images``: RGB PIL images in sampled order. Slices that fail to
          convert are reported on stdout and omitted.
        * ``study_info``: metadata dict (see ``get_study_info``) extended with
          series/slice counts, image size, the first file's default window
          and the number of processed images.

    Raises:
        ValueError: if the archive contains no readable DICOM file with pixel data.
    """
    dicom_files = extract_dicom_from_zip(zip_bytes)
    if not dicom_files:
        raise ValueError("No valid DICOM files found in the ZIP archive")

    # Study-level metadata and the reported default window come from the first file.
    first_ds = dicom_files[0][1]
    modality = get_modality(first_ds)
    default_wc, default_ww = get_default_window(first_ds)

    series_dict = organize_by_series(dicom_files)
    total_original_slices = sum(len(files) for files in series_dict.values())

    sampled_slices = _select_slices(series_dict, max_slices, max_slices_per_series)
    sampled_count = len(sampled_slices)

    study_info = get_study_info(first_ds, sampled_count)
    study_info.update({
        'SeriesCount': len(series_dict),
        'TotalOriginalSlices': total_original_slices,
        'SampledSlices': sampled_count,
        'ImageSize': image_size,
        'DefaultWindowCenter': default_wc,
        'DefaultWindowWidth': default_ww,
    })
    if max_slices_per_series is not None:
        study_info['MaxSlicesPerSeries'] = max_slices_per_series

    images = []
    for filename, ds in sampled_slices:
        try:
            images.append(dicom_to_pil(
                ds,
                size=(image_size, image_size),
                window_center=window_center,
                window_width=window_width
            ))
        except Exception as e:
            # Best effort: one undecodable slice should not abort the whole study.
            print(f"Error converting {filename}: {e}")

    study_info['ProcessedImages'] = len(images)

    return modality, images, study_info