File size: 13,029 Bytes
69066c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77909e8
69066c5
 
 
 
77909e8
69066c5
 
 
 
77909e8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69066c5
 
 
 
 
 
 
 
 
 
77909e8
 
69066c5
77909e8
69066c5
 
77909e8
69066c5
 
 
77909e8
0961cf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77909e8
 
 
69066c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0961cf5
69066c5
 
 
 
 
 
 
 
 
 
 
 
0961cf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69066c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
"""
Utility functions for NeuroSAM 3 application.
Helper functions for image processing, visualization, and common operations.
"""

from typing import Optional, Tuple, List, Dict, Any
import os
import re
import tempfile
import numpy as np
import pydicom
from PIL import Image
import matplotlib.pyplot as plt
from logger_config import logger


def extract_subject_id(file_path: str) -> Tuple[str, str, str]:
    """
    Extract subject/patient ID from file path.

    Strategies are tried in order of reliability; the first match wins:

    - DICOM metadata (``.dcm`` files): PatientID, else StudyInstanceUID
    - Folder name: /subject_001/image.png -> subject_001
    - Filename prefix: subject_001_slice_01.png -> subject_001
    - Patient ID in filename: patient_123_slice_5.png -> patient_123
    - Upper-case code in filename: BR001_t1.png -> BR001
    - First run of 3+ digits in the filename (may really be a slice number)
    - Filename without its extension

    Args:
        file_path: Path to file (converted with ``str()``)

    Returns:
        Tuple of (subject_id, confidence_level, source)
        confidence_level: 'high' (DICOM metadata), 'medium' (folder/filename
            pattern), 'low' (numeric or filename fallback)
        source: 'dicom_patientid', 'dicom_study', 'folder', 'filename',
            'filename_numeric', 'fallback', or 'unknown'
    """
    file_path = str(file_path)
    filename = os.path.basename(file_path)
    dir_path = os.path.dirname(file_path)

    # HIGHEST CONFIDENCE: DICOM metadata (most reliable)
    if file_path.lower().endswith('.dcm'):
        try:
            ds = pydicom.dcmread(file_path, stop_before_pixels=True)
            # Strip surrounding whitespace so padded IDs map to one subject.
            patient_id = str(getattr(ds, 'PatientID', '') or '').strip()
            if patient_id:
                return f"patient_{patient_id}", 'high', 'dicom_patientid'

            study_uid = getattr(ds, 'StudyInstanceUID', None)
            if study_uid:
                # Use full study UID as identifier (unique per study)
                return f"study_{study_uid}", 'high', 'dicom_study'
        except Exception as e:
            # Unreadable metadata: fall through to the path-based heuristics.
            logger.debug(f"Could not read DICOM metadata: {e}")

    # MEDIUM CONFIDENCE: Folder name (common in medical datasets)
    folder_name = os.path.basename(dir_path.rstrip('/'))
    if folder_name and folder_name not in ['', '.', '..']:
        # Only accept folder names that look like a subject ID
        if re.match(r'(subject|patient|sub|pat|case|id)[_-]?\d+', folder_name, re.I):
            return folder_name, 'medium', 'folder'

    # MEDIUM CONFIDENCE: Filename pattern, each with its own regex flags
    patterns = [
        # subject_001, patient-123, case7 (keyword is case-insensitive)
        (r'(subject|patient|sub|pat|case|id)[_-]?(\d+)', re.I),
        # Upper-case codes such as BR001, MR123. Deliberately case-sensitive:
        # with re.I this also matched plain words like "slice012".
        (r'([A-Z]{2,}\d+)', 0),
    ]

    for pattern, flags in patterns:
        match = re.search(pattern, filename, flags)
        if match:
            if len(match.groups()) > 1:
                return f"{match.group(1)}_{match.group(2)}", 'medium', 'filename'
            return match.group(1), 'medium', 'filename'

    # LOW CONFIDENCE: Numeric pattern (could be slice number, not patient ID)
    numeric_match = re.search(r'(\d{3,})', filename)
    if numeric_match:
        return numeric_match.group(1), 'low', 'filename_numeric'

    # LOWEST CONFIDENCE: Fallback to filename
    base_name = os.path.splitext(filename)[0]
    if base_name:
        return base_name, 'low', 'fallback'

    return "unknown", 'low', 'unknown'


def group_images_by_subject(image_files: List[str]) -> Dict[str, Dict[str, Any]]:
    """
    Group image files by subject/patient ID.

    Each file's ID comes from ``extract_subject_id``. A group's confidence is
    the strongest confidence seen among its files ('high' > 'medium' > 'low').

    Args:
        image_files: List of file paths. A single path string is also
            accepted; ``None`` entries are ignored.

    Returns:
        Dictionary: {subject_id: {'files': [...], 'confidence': 'high/medium/low', 'sources': [...]}}
        'files' is sorted by path and 'sources' is a sorted list of the
        distinct ID sources, so the result is deterministic.
    """
    if not image_files:
        return {}

    if isinstance(image_files, str):
        image_files = [image_files]

    # Rank used to keep the strongest confidence seen for a subject.
    confidence_rank = {'low': 0, 'medium': 1, 'high': 2}

    subject_groups: Dict[str, Dict[str, Any]] = {}
    for file_path in image_files:
        if file_path is None:
            continue
        subject_id, confidence, source = extract_subject_id(file_path)

        group = subject_groups.setdefault(
            subject_id, {'files': [], 'confidence': confidence, 'sources': set()}
        )
        group['files'].append(file_path)
        group['sources'].add(source)

        if confidence_rank.get(confidence, 0) > confidence_rank.get(group['confidence'], 0):
            group['confidence'] = confidence

    for group in subject_groups.values():
        group['files'].sort()
        # sorted() instead of list(): set iteration order of strings varies
        # between interpreter runs (hash randomization).
        group['sources'] = sorted(group['sources'])

    return subject_groups


def _mask_to_2d_bool(mask: np.ndarray) -> np.ndarray:
    """Reduce one non-empty mask array to a 2D boolean array (helper for ``combine_masks``)."""
    if mask.ndim == 3:
        if mask.shape[0] == 3:
            # (3, H, W): average the leading channel axis, threshold at 0.5.
            mask = np.mean(mask, axis=0) > 0.5
        elif mask.shape[2] == 3:
            # (H, W, 3): average the trailing channel axis, threshold at 0.5.
            mask = np.mean(mask, axis=2) > 0.5
        else:
            # Otherwise keep the first slice along the smaller outer axis.
            mask = mask[0] if mask.shape[0] < mask.shape[2] else mask[:, :, 0]
    elif mask.ndim > 3:
        mask = mask.squeeze()
        if mask.ndim != 2:
            # Raises if the remaining size is not H*W; combine_masks catches it.
            mask = mask.reshape(mask.shape[-2], mask.shape[-1])

    if mask.dtype != bool:
        # Values within [0, 1] are cast directly (any non-zero -> True); larger
        # ranges such as 0/255 are thresholded at half the maximum.
        mask = mask.astype(bool) if mask.max() <= 1 else (mask > mask.max() / 2)
    return mask


def combine_masks(masks) -> Optional[np.ndarray]:
    """
    Combine multiple mask arrays into a single boolean mask (logical OR).

    Args:
        masks: List/tuple (or other iterable) of mask arrays or array-likes,
            a single numpy array, or None. ``None`` entries and scalar or
            empty arrays inside the collection are skipped.

    Returns:
        Combined boolean mask, or None if there are no usable masks or
        combining fails. A non-empty multi-dimensional numpy array passed
        directly is returned unchanged (treated as one mask, not a stack).
    """
    if masks is None:
        return None

    if isinstance(masks, np.ndarray):
        if masks.ndim == 0 or masks.size == 0:
            # Scalars and empty arrays (e.g. a (0, H, W) "no masks" result)
            # hold no mask.
            return None
        if masks.ndim > 1:
            # Multi-dimensional array: returned as-is as a single mask.
            return masks
        masks = [masks]  # non-empty 1D array -> one-element list

    if not isinstance(masks, (list, tuple)):
        # Any other iterable (e.g. a generator); non-iterables give None.
        try:
            masks = list(masks)
        except Exception:
            return None

    mask_arrays = []
    for mask in masks:
        if mask is None:
            continue
        if not isinstance(mask, np.ndarray):
            try:
                mask = np.array(mask)
            except Exception as e:
                logger.debug(f"Could not convert mask to numpy: {e}")
                continue
        # Scalars and empty arrays cannot be OR-ed; keeping them would make
        # max() raise and abort the combine for every other valid mask.
        if mask.ndim == 0 or mask.size == 0:
            continue
        mask_arrays.append(mask)

    if not mask_arrays:
        return None

    try:
        masks_2d = [_mask_to_2d_bool(m) for m in mask_arrays]

        # Nearest-neighbour resize every mask to the first mask's shape.
        target_shape = masks_2d[0].shape
        for i, m in enumerate(masks_2d):
            if m.shape != target_shape:
                from scipy.ndimage import zoom
                zoom_factors = (
                    target_shape[0] / m.shape[0],
                    target_shape[1] / m.shape[1],
                )
                masks_2d[i] = zoom(m.astype(float), zoom_factors, order=0) > 0.5

        return np.any(masks_2d, axis=0)
    except Exception as e:
        logger.error(f"Error combining masks: {e}", exc_info=True)
        return None


def _prepare_overlay_mask(mask: np.ndarray) -> np.ndarray:
    """Reduce a mask array to a 2D boolean array for the ``imshow`` overlay."""
    if mask.ndim == 3:
        if mask.shape[0] == 3:
            # (3, H, W): average across the leading channel axis.
            mask = np.mean(mask, axis=0)
        elif mask.shape[2] == 3:
            # (H, W, 3): average across the trailing channel axis.
            mask = np.mean(mask, axis=2)
        else:
            # Otherwise keep the first slice along the smaller outer axis.
            mask = mask[0] if mask.shape[0] < mask.shape[2] else mask[:, :, 0]
    elif mask.ndim > 3:
        # Drop singleton dimensions, then force the last two axes to 2D.
        mask = mask.squeeze()
        if mask.ndim != 2:
            logger.warning(f"Mask has {mask.ndim} dimensions, expected 2D. Flattening...")
            mask = mask.reshape(mask.shape[-2], mask.shape[-1])

    if mask.dtype != bool:
        # Values within [0, 1] are cast directly (any non-zero -> True); larger
        # ranges such as 0/255 are thresholded at half the maximum.
        mask = mask.astype(bool) if mask.max() <= 1 else (mask > mask.max() / 2)
    return mask


def create_output_image(
    pil_image: Image.Image,
    mask: Optional[np.ndarray],
    prompt_text: str,
    colormap: str = 'spring',
    transparency: float = 0.5,
    title: Optional[str] = None
) -> str:
    """
    Create output visualization image with mask overlay.

    The pyplot figure is always closed, even when rendering or saving fails,
    so repeated calls do not accumulate open figures.

    Args:
        pil_image: Base PIL image
        mask: Optional mask to overlay; numpy arrays of any rank are reduced
            to 2D boolean, other values are passed to ``imshow`` unchanged
        prompt_text: Prompt text used in the default title
        colormap: Matplotlib colormap name for the overlay
        transparency: Overlay alpha (0.0-1.0)
        title: Optional custom title (overrides the prompt-based one)

    Returns:
        Path to the saved PNG in the system temp directory; the caller owns
        (and should eventually delete) the file.
    """
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(pil_image)

        if mask is not None:
            if isinstance(mask, np.ndarray):
                mask = _prepare_overlay_mask(mask)
            # NOTE(review): False pixels are also drawn (in the colormap's low
            # color), tinting the whole image; confirm this is intended.
            plt.imshow(mask, alpha=transparency, cmap=colormap)

        plt.axis('off')
        display_title = title or f"Segmentation: {prompt_text}"
        plt.title(display_title, fontsize=12, pad=10)

        # Imported before the temp file is created so a missing config
        # does not leave an orphan file.
        from config import OUTPUT_DPI

        output_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        output_path = output_file.name
        output_file.close()

        try:
            fig.savefig(output_path, bbox_inches='tight', pad_inches=0, dpi=OUTPUT_DPI)
        except Exception:
            # Do not leave an empty temp file behind when saving fails.
            if os.path.exists(output_path):
                os.remove(output_path)
            raise
    finally:
        # pyplot keeps figures alive until explicitly closed.
        plt.close(fig)

    return output_path


def create_demo_dicom_file(output_path: str = "demo_brain_mri.dcm") -> bool:
    """
    Create a demo DICOM file for testing.

    Prefers copying pydicom's ``MR_small.dcm`` test file. If that is not
    available, writes a synthetic 256x256 16-bit MONOCHROME2 MR image
    (random noise plus a brighter central disc) with PatientID ``DEMO001``.

    Args:
        output_path: Path where to save the demo file

    Returns:
        True if successful, False otherwise
    """
    try:
        from pydicom.data import get_testdata_file
        test_file = get_testdata_file("MR_small.dcm")
        if test_file and os.path.exists(test_file):
            import shutil
            shutil.copy(test_file, output_path)
            logger.info(f"Demo file ready: {output_path}")
            return True
    except Exception as e:
        # Best effort: fall back to building a synthetic file below.
        logger.debug(f"Could not copy test DICOM file: {e}")

    try:
        # Create synthetic DICOM file
        from pydicom.dataset import FileDataset, FileMetaDataset
        from pydicom.uid import generate_uid

        # Noise in [0, 254] with a disc (radius 100) brightened by 50.
        synthetic_image = np.random.randint(0, 255, (256, 256), dtype=np.uint16)
        center_x, center_y = 128, 128
        y, x = np.ogrid[:256, :256]
        disc = (x - center_x) ** 2 + (y - center_y) ** 2 <= 100 ** 2
        synthetic_image[disc] = np.clip(synthetic_image[disc] + 50, 0, 255)

        mr_image_storage = '1.2.840.10008.5.1.4.1.1.4'  # MR Image Storage SOP Class
        sop_instance_uid = generate_uid()

        file_meta = FileMetaDataset()
        file_meta.MediaStorageSOPClassUID = mr_image_storage
        file_meta.MediaStorageSOPInstanceUID = sop_instance_uid
        file_meta.TransferSyntaxUID = '1.2.840.10008.1.2.1'  # Explicit VR Little Endian

        ds = FileDataset(output_path, {}, file_meta=file_meta, preamble=b"\x00" * 128)
        # SOP Common module: dataset SOP Class/Instance UIDs must match the
        # file meta information; study/series UIDs identify the demo series.
        ds.SOPClassUID = mr_image_storage
        ds.SOPInstanceUID = sop_instance_uid
        ds.StudyInstanceUID = generate_uid()
        ds.SeriesInstanceUID = generate_uid()
        ds.PatientName = "Demo^Patient"
        ds.PatientID = "DEMO001"
        ds.Modality = "MR"
        ds.Rows = 256
        ds.Columns = 256
        ds.BitsAllocated = 16
        ds.BitsStored = 16
        ds.HighBit = 15
        ds.SamplesPerPixel = 1
        ds.PixelRepresentation = 0
        ds.PhotometricInterpretation = "MONOCHROME2"
        ds.PixelSpacing = [1.0, 1.0]
        ds.RescaleIntercept = "0"
        ds.RescaleSlope = "1"
        # Force little-endian bytes to match the transfer syntax on any host.
        ds.PixelData = synthetic_image.astype('<u2').tobytes()

        ds.save_as(output_path, write_like_original=False)
        logger.info(f"Synthetic demo file created: {output_path}")
        return True

    except Exception as e:
        logger.warning(f"Could not create demo file: {e}")
        return False