"""
Collection of various utils
"""

import numpy as np

import imageio.v3 as iio
from PIL import Image
# We may have very large images (e.g. panoramic SEM images); raise the limit so they can be read without warnings
Image.MAX_IMAGE_PIXELS = 933120000

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.lines import Line2D
import logging
import math
from typing import Any, List, Tuple, Union


###
### load SEM images (Note: Not directly used with Gradio gr.Image(type="pil"))
###
def load_image(filename: str) -> np.ndarray:
    """Load an SEM image

    Args:
        filename (str): full path and name of the image file to be loaded

    Returns:
        np.ndarray: file as numpy ndarray
    """
    image = iio.imread(filename, mode='F')

    return image
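
# A minimal usage sketch for load_image (the file name below is hypothetical):
#
#   panorama = load_image("panorama_sem.tif")
#   print(panorama.shape, panorama.dtype)   # 2D float array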


###
### show SEM image with boxes in various colours around each damage site
###
def show_boxes(image: np.ndarray, damage_sites: dict, box_size=[250, 250],
               save_image=False, image_path: str = None):
    """
    Shows an SEM image with colored boxes around identified damage sites.

    Args:
        image (np.ndarray or PIL.Image.Image): SEM image to be shown.
        damage_sites (dict): Dictionary mapping centroid coordinates (row, column) to the damage label.
        box_size (list, optional): Size [height, width] of the rectangle drawn around each centroid. Defaults to [250, 250].
        save_image (bool, optional): Whether to save the annotated image. Defaults to False.
        image_path (str, optional): Full path and name of the output file to be saved.

    Returns:
        np.ndarray: The rendered figure as an RGB array of shape (H, W, 3).
    """
    logging.debug(f"show_boxes: Input image type: {type(image)}") 

    # Ensure image is a NumPy array of appropriate type for matplotlib
    if isinstance(image, Image.Image):
        image_to_plot = np.array(image.convert('L')) # Convert to grayscale NumPy array
        logging.debug("show_boxes: Converted PIL Image to grayscale NumPy array for plotting.")
    elif isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[2] in [3, 4]:  # RGB or RGBA NumPy array
            image_to_plot = np.mean(image[:, :, :3], axis=2).astype(image.dtype)  # Average the colour channels only (ignore alpha)
            logging.debug("show_boxes: Converted multi-channel NumPy array to grayscale for plotting.")
        else: # Assume grayscale already
            image_to_plot = image
            logging.debug("show_boxes: Image is already a grayscale NumPy array.")
    else:
        logging.error("show_boxes: Unsupported image format received.")
        image_to_plot = np.zeros((100,100), dtype=np.uint8) # Fallback to black image


    _, ax = plt.subplots(1)
    ax.imshow(image_to_plot, cmap='gray')  # show image on correct axis
    ax.set_xticks([])
    ax.set_yticks([])

    for key, label in damage_sites.items():
        position = [key[0], key[1]] # Assuming key[0] is y (row) and key[1] is x (column)
        
        edgecolor = {
            'Inclusion': 'b',
            'Interface': 'g',
            'Martensite': 'r',
            'Notch': 'y',
            'Shadowing': 'm',
            'Not Classified': 'k'
        }.get(label, 'k')  # default: black

        # Ensure box_size elements are floats for division
        half_box_w = box_size[1] / 2.0
        half_box_h = box_size[0] / 2.0

        # x-coordinate of the rectangle's anchor corner
        rect_x = position[1] - half_box_w
        # y-coordinate of the anchor corner (imshow places the image origin at the top-left by default)
        rect_y = position[0] - half_box_h

        rect = patches.Rectangle((rect_x, rect_y),
                                 box_size[1], box_size[0],
                                 linewidth=1, edgecolor=edgecolor, facecolor='none')
        ax.add_patch(rect)

    legend_elements = [
        Line2D([0], [0], color='b', lw=4, label='Inclusion'),
        Line2D([0], [0], color='g', lw=4, label='Interface'),
        Line2D([0], [0], color='r', lw=4, label='Martensite'),
        Line2D([0], [0], color='y', lw=4, label='Notch'),
        Line2D([0], [0], color='m', lw=4, label='Shadow'),
        Line2D([0], [0], color='k', lw=4, label='Not Classified')
    ]
    ax.legend(handles=legend_elements, bbox_to_anchor=(1.04, 1), loc="upper left")

    fig = ax.figure
    fig.tight_layout(pad=0)

    if save_image and image_path:
        fig.savefig(image_path, dpi=1200, bbox_inches='tight')

    canvas = fig.canvas
    canvas.draw()

    data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8).reshape(
        canvas.get_width_height()[::-1] + (4,))
    data = data[:, :, :3]  # RGB only

    plt.close(fig)

    return data
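
# A minimal usage sketch for show_boxes, assuming hypothetical (row, column)
# centroid coordinates; the returned RGB array can be displayed directly,
# e.g. via a Gradio gr.Image component.
#
#   damage_sites = {(1200, 3400): 'Inclusion', (800, 5100): 'Martensite'}
#   annotated = show_boxes(panorama, damage_sites, box_size=[250, 250],
#                          save_image=True, image_path='annotated.png')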

##
## orig
##

# ###
# ### cut out small images from panorama, append colour information
# ###
# def prepare_classifier_input(panorama, centroids: list, window_size=[250, 250]) -> list: # Removed np.ndarray type hint for panorama
#     """
#     Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
    
#     Each extracted patch is resized to the specified window size and converted into a 3-channel (RGB-like)
#     normalized image suitable for use with classification neural networks that expect color input.

#     Parameters
#     ----------
#     panorama : PIL.Image.Image or np.ndarray
#         Input SEM image. Should be a 2D array (H, W) or a 3D array (H, W, 1) representing grayscale data,
#         or a PIL Image object.
    
#     centroids : list of [int, int]
#         List of (y, x) coordinates marking the centers of regions of interest. These are typically damage sites
#         identified in preprocessing (e.g., clustering).

#     window_size : list of int, optional
#         Size [height, width] of each extracted image patch. Defaults to [250, 250].

#     Returns
#     -------
#     list of np.ndarray
#         List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
#         centroids that allow full window extraction within image bounds are used.
#     """
#     logging.info(f"prepare_classifier_input: Input panorama type: {type(panorama)}") # Added logging

#     # --- MINIMAL FIX START ---
#     # Convert PIL Image to NumPy array if necessary
#     if isinstance(panorama, Image.Image):
#         # Convert to grayscale NumPy array as your original code expects this structure for processing
#         if panorama.mode == 'RGB':
#             panorama_array = np.array(panorama.convert('L'))
#             logging.info("prepare_classifier_input: Converted RGB PIL Image to grayscale NumPy array.")
#         else:
#             panorama_array = np.array(panorama)
#             logging.info("prepare_classifier_input: Converted PIL Image to grayscale NumPy array.")
#     elif isinstance(panorama, np.ndarray):
#         # Ensure it's treated as a grayscale array for consistency with original logic
#         if panorama.ndim == 3 and panorama.shape[2] in [3, 4]: # RGB or RGBA NumPy array
#             panorama_array = np.mean(panorama, axis=2).astype(panorama.dtype) # Convert to grayscale
#             logging.info("prepare_classifier_input: Converted multi-channel NumPy array to grayscale.")
#         else:
#             panorama_array = panorama # Assume it's already grayscale 2D or (H,W,1)
#             logging.info("prepare_classifier_input: Panorama is already a suitable NumPy array.")
#     else:
#         logging.error("prepare_classifier_input: Unsupported panorama format received. Expected PIL Image or NumPy array.")
#         raise ValueError("Unsupported panorama format for classifier input.")

#     # Now, ensure panorama_array has a channel dimension if it's 2D for consistency
#     if panorama_array.ndim == 2:
#         panorama_array = np.expand_dims(panorama_array, axis=-1)  # (H, W, 1)
#         logging.info("prepare_classifier_input: Expanded 2D panorama to 3D (H,W,1).")
#     # --- MINIMAL FIX END ---

#     H, W, _ = panorama_array.shape # Use panorama_array here
#     win_h, win_w = window_size
#     images = []

#     for (cy, cx) in centroids:
#         # Ensure coordinates are integers
#         cy, cx = int(round(cy)), int(round(cx))

#         x1 = int(cx - win_w / 2)
#         y1 = int(cy - win_h / 2)
#         x2 = x1 + win_w
#         y2 = y1 + win_h

#         # Skip if patch would go out of bounds
#         if x1 < 0 or y1 < 0 or x2 > W or y2 > H:
#             logging.warning(f"prepare_classifier_input: Skipping centroid ({cy},{cx}) as patch is out of bounds.") # Added warning
#             continue

#         # Extract and normalize patch
#         patch = panorama_array[y1:y2, x1:x2, 0].astype(np.float32) # Use panorama_array
#         patch = patch * 2. / 255. - 1. # Keep your original normalization

#         # Replicate grayscale channel to simulate RGB
#         patch_color = np.repeat(patch[:, :, np.newaxis], 3, axis=2)
#         images.append(patch_color)

#     return images

###
### refactored
###

def prepare_classifier_input(
    panorama: Union[Image.Image, np.ndarray], 
    centroids: List[Tuple[int, int]], 
    window_size: List[int] = [250, 250]
) -> List[np.ndarray]:
    """
    Extracts square image patches centered at each given centroid from a grayscale panoramic SEM image.
    
    Each extracted patch has the specified window size and is converted into a 3-channel (RGB-like)
    normalized image suitable for use with classification neural networks that expect color input.

    Parameters
    ----------
    panorama : PIL.Image.Image or np.ndarray
        Input SEM image. Should be a 2D array (H, W) or a 3D array (H, W, 1) representing grayscale data,
        or a PIL Image object.
    
    centroids : list of [int, int]
        List of (y, x) coordinates marking the centers of regions of interest. These are typically damage sites
        identified in preprocessing (e.g., clustering).

    window_size : list of int, optional
        Size [height, width] of each extracted image patch. Defaults to [250, 250].

    Returns
    -------
    list of np.ndarray
        List of extracted and normalized 3-channel image patches, each with shape (height, width, 3). Only
        centroids that allow full window extraction within image bounds are used.
    """
    logging.debug(f"prepare_classifier_input: Input panorama type: {type(panorama)}")

    # Convert input to standardized NumPy array format
    panorama_array = _convert_to_grayscale_array(panorama)
    
    # Ensure we have the correct dimensions
    if panorama_array.ndim == 2:
        H, W = panorama_array.shape
        logging.debug("prepare_classifier_input: Working with 2D grayscale array.")
    elif panorama_array.ndim == 3:
        H, W, C = panorama_array.shape
        if C == 1:
            # Squeeze the single channel dimension for easier processing
            panorama_array = panorama_array.squeeze(axis=2)
            H, W = panorama_array.shape
            logging.debug("prepare_classifier_input: Squeezed single channel dimension.")
        else:
            logging.error(f"prepare_classifier_input: Unexpected number of channels: {C}")
            raise ValueError(f"Expected 1 channel, got {C}")
    else:
        logging.error(f"prepare_classifier_input: Unexpected array dimensions: {panorama_array.ndim}")
        raise ValueError(f"Expected 2D or 3D array, got {panorama_array.ndim}D")

    win_h, win_w = window_size
    images = []
    
    logging.info(f"prepare_classifier_input: Image dimensions: {H}x{W}, Window size: {win_h}x{win_w}")
    logging.info(f"prepare_classifier_input: Processing {len(centroids)} centroids")

    for i, (cy, cx) in enumerate(centroids):
        # Ensure coordinates are integers
        cy, cx = int(round(cy)), int(round(cx))
        
        # Calculate patch boundaries
        half_h, half_w = win_h // 2, win_w // 2
        y1 = cy - half_h
        y2 = y1 + win_h
        x1 = cx - half_w
        x2 = x1 + win_w

        # Check bounds more explicitly
        if y1 < 0 or x1 < 0 or y2 > H or x2 > W:
            logging.warning(
                f"prepare_classifier_input: Skipping centroid {i+1}/{len(centroids)} "
                f"at ({cy},{cx}) - patch bounds ({y1}:{y2}, {x1}:{x2}) exceed image bounds (0:{H}, 0:{W})"
            )
            continue

        try:
            # Extract patch with explicit bounds checking
            patch = panorama_array[y1:y2, x1:x2].astype(np.float32)
            
            # Verify patch dimensions
            if patch.shape != (win_h, win_w):
                logging.warning(
                    f"prepare_classifier_input: Patch {i+1} has unexpected shape {patch.shape}, "
                    f"expected ({win_h}, {win_w}). Skipping."
                )
                continue
            
            # Normalize patch: [0, 255] -> [-1, 1]
            patch_normalized = (patch * 2.0 / 255.0) - 1.0
            
            # Convert to 3-channel RGB-like format
            patch_rgb = np.stack([patch_normalized] * 3, axis=2)
            
            images.append(patch_rgb)
            logging.debug(f"prepare_classifier_input: Successfully processed centroid {i+1} at ({cy},{cx})")
            
        except Exception as e:
            logging.error(
                f"prepare_classifier_input: Error processing centroid {i+1} at ({cy},{cx}): {e}"
            )
            continue

    logging.info(f"prepare_classifier_input: Successfully extracted {len(images)} patches from {len(centroids)} centroids")
    
    # Add diagnostic information about the output
    if images:
        sample_shape = images[0].shape
        sample_dtype = images[0].dtype
        sample_min = images[0].min()
        sample_max = images[0].max()
        logging.info(f"prepare_classifier_input: Output patches - Shape: {sample_shape}, Dtype: {sample_dtype}, Range: [{sample_min:.3f}, {sample_max:.3f}]")
        
        # Verify all patches have consistent shapes
        shapes = [img.shape for img in images]
        if not all(shape == sample_shape for shape in shapes):
            logging.warning("prepare_classifier_input: Inconsistent patch shapes detected!")
            for i, shape in enumerate(shapes):
                if shape != sample_shape:
                    logging.warning(f"  Patch {i}: {shape} (expected {sample_shape})")
    else:
        logging.warning("prepare_classifier_input: No valid patches were extracted!")
    
    return images
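
# A minimal usage sketch, assuming `panorama` and `centroids` come from the
# preprocessing step (e.g. clustering of damage sites). Each returned patch is
# a float32 array of shape (window_size[0], window_size[1], 3) normalized to [-1, 1].
#
#   centroids = [(1200, 3400), (800, 5100)]
#   patches = prepare_classifier_input(panorama, centroids, window_size=[250, 250])
#   print(len(patches), patches[0].shape if patches else None)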


def _convert_to_grayscale_array(panorama: Union[Image.Image, np.ndarray]) -> np.ndarray:
    """
    Helper function to convert various input formats to a standardized grayscale NumPy array.
    
    Parameters
    ----------
    panorama : PIL.Image.Image or np.ndarray
        Input image in various formats
        
    Returns
    -------
    np.ndarray
        Standardized grayscale array
    """
    if isinstance(panorama, Image.Image):
        if panorama.mode in ['RGB', 'RGBA']:
            # Convert to grayscale
            panorama_array = np.array(panorama.convert('L'))
            logging.debug("_convert_to_grayscale_array: Converted RGB/RGBA PIL Image to grayscale.")
        elif panorama.mode == 'L':
            panorama_array = np.array(panorama)
            logging.debug("_convert_to_grayscale_array: Converted grayscale PIL Image to NumPy array.")
        else:
            # Handle other modes by converting to grayscale
            panorama_array = np.array(panorama.convert('L'))
            logging.debug(f"_convert_to_grayscale_array: Converted PIL Image mode '{panorama.mode}' to grayscale.")
            
    elif isinstance(panorama, np.ndarray):
        if panorama.ndim == 2:
            # Already grayscale
            panorama_array = panorama.copy()
            logging.debug("_convert_to_grayscale_array: Using existing 2D grayscale array.")
        elif panorama.ndim == 3:
            if panorama.shape[2] in [3, 4]:  # RGB or RGBA
                # Convert to grayscale using luminance weights
                if panorama.shape[2] == 3:  # RGB
                    panorama_array = np.dot(panorama, [0.299, 0.587, 0.114]).astype(panorama.dtype)
                else:  # RGBA
                    panorama_array = np.dot(panorama[:, :, :3], [0.299, 0.587, 0.114]).astype(panorama.dtype)
                logging.debug("_convert_to_grayscale_array: Converted multi-channel NumPy array to grayscale using luminance weights.")
            elif panorama.shape[2] == 1:
                # Already single channel
                panorama_array = panorama.copy()
                logging.debug("_convert_to_grayscale_array: Using existing single-channel array.")
            else:
                raise ValueError(f"Unsupported number of channels: {panorama.shape[2]}")
        else:
            raise ValueError(f"Unsupported array dimensions: {panorama.ndim}")
    else:
        raise ValueError(f"Unsupported panorama type: {type(panorama)}")
    
    return panorama_array
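
# A small sketch of the helper on synthetic data: an RGB NumPy array and a PIL
# image with the same content should both come back as 2D grayscale arrays.
#
#   rgb = np.random.randint(0, 255, size=(64, 64, 3), dtype=np.uint8)
#   gray_from_array = _convert_to_grayscale_array(rgb)                  # shape (64, 64)
#   gray_from_pil = _convert_to_grayscale_array(Image.fromarray(rgb))   # shape (64, 64)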


##
##  debug
## 

def debug_classification_input(patches: List[np.ndarray], model: Any = None) -> None:
    """
    Debug function to help identify issues in the classification pipeline.
    Call this right before your classification step.
    
    Parameters
    ----------
    patches : List[np.ndarray]
        List of image patches from prepare_classifier_input
    model : Any, optional
        Your classification model (for additional debugging)
    """
    logging.info("=== CLASSIFICATION DEBUG INFO ===")
    logging.info(f"Number of patches: {len(patches)}")
    
    if not patches:
        logging.error("No patches provided for classification!")
        return
    
    for i, patch in enumerate(patches):
        logging.info(f"Patch {i}:")
        logging.info(f"  Shape: {patch.shape}")
        logging.info(f"  Dtype: {patch.dtype}")
        logging.info(f"  Range: [{patch.min():.3f}, {patch.max():.3f}]")
        logging.info(f"  Memory layout: {patch.flags}")
        
        # Check for common issues
        if np.isnan(patch).any():
            logging.warning(f"  Contains NaN values: {np.isnan(patch).sum()}")
        if np.isinf(patch).any():
            logging.warning(f"  Contains infinite values: {np.isinf(patch).sum()}")
        
        # Check if patch is contiguous (some models require this)
        if not patch.flags.c_contiguous:
            logging.warning(f"  Patch {i} is not C-contiguous")
    
    # Test conversion to common formats
    try:
        patches_array = np.array(patches)
        logging.info(f"Stacked array shape: {patches_array.shape}")
        logging.info(f"Stacked array dtype: {patches_array.dtype}")
    except Exception as e:
        logging.error(f"Failed to stack patches into array: {e}")
    
    # Test batch preparation (common source of slice errors)
    try:
        if len(patches) > 0:
            # Common preprocessing steps that might cause issues
            test_batch = np.stack(patches, axis=0)  # Shape: (batch_size, height, width, channels)
            logging.info(f"Test batch shape: {test_batch.shape}")
            
            # Test various indexing operations that might cause slice errors
            test_slice = test_batch[0]  # Should work
            logging.info(f"Single item slice shape: {test_slice.shape}")
            
            test_batch_slice = test_batch[:]  # Should work
            logging.info(f"Full batch slice shape: {test_batch_slice.shape}")
            
    except Exception as e:
        logging.error(f"Error during batch preparation testing: {e}")
        logging.error(f"Error type: {type(e)}")
        import traceback
        logging.error(f"Traceback: {traceback.format_exc()}")
    
    logging.info("=== END CLASSIFICATION DEBUG ===")


def safe_classify_patches(patches: List[np.ndarray], classify_func, **kwargs) -> Any:
    """
    Wrapper function to safely run classification with better error handling.
    
    Parameters
    ----------
    patches : List[np.ndarray]
        List of image patches
    classify_func : callable
        Your classification function
    **kwargs
        Additional arguments for classify_func
        
    Returns
    -------
    Any
        Classification results or None if error occurred
    """
    try:
        logging.debug("Starting safe classification...")
        
        # Debug the input
        debug_classification_input(patches)
        
        # Ensure patches are properly formatted
        if not patches:
            logging.error("No patches to classify")
            return None
        
        # Make sure all patches are contiguous arrays
        patches_clean = []
        for i, patch in enumerate(patches):
            if not patch.flags.c_contiguous:
                patch_clean = np.ascontiguousarray(patch)
                logging.debug(f"Made patch {i} contiguous")
            else:
                patch_clean = patch
            patches_clean.append(patch_clean)
        
        # Call the actual classification function
        logging.debug("Calling classification function...")
        result = classify_func(patches_clean, **kwargs)
        logging.debug("Classification completed successfully")
        
        return result
        
    except Exception as e:
        logging.error(f"Error in safe_classify_patches: {e}")
        logging.error(f"Error type: {type(e)}")
        import traceback
        logging.error(f"Full traceback: {traceback.format_exc()}")
        return None


# Example usage function
def example_usage():
    """
    Example of how to use the debug functions in your pipeline
    """
    # Your existing code that calls prepare_classifier_input
    # patches = prepare_classifier_input(panorama, centroids, window_size)
    
    # Add debugging before classification
    # debug_classification_input(patches)
    
    # Use safe wrapper for classification
    # results = safe_classify_patches(patches, your_classify_function, model=your_model)
    
    pass
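
# A hedged end-to-end sketch of how these helpers could be chained; `classifier`,
# `classify_patches` and `CLASS_NAMES` are hypothetical and stand in for the
# actual model and label list used in the application:
#
#   patches = prepare_classifier_input(panorama, centroids, window_size=[250, 250])
#   debug_classification_input(patches)
#   raw_output = safe_classify_patches(patches, classify_patches, model=classifier)
#   if raw_output is not None:
#       predictions = extract_predictions_from_tfsm(raw_output)
#       labels = [CLASS_NAMES[int(np.argmax(p))] for p in predictions]
#       damage_sites = dict(zip(centroids, labels))
#       annotated = show_boxes(panorama, damage_sites)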


########################################
##
##  TFSMLayer output handling
##
########################################
def extract_predictions_from_tfsm(model_output):
    """
    Helper function to extract predictions from TFSMLayer output.
    TFSMLayer often returns a dictionary with multiple outputs.
    """
    logging.debug(f"Model output type: {type(model_output)}")
    logging.debug(f"Model output keys: {model_output.keys() if isinstance(model_output, dict) else 'Not a dict'}")
    
    if isinstance(model_output, dict):
        # Try common output key names
        possible_keys = ['output', 'predictions', 'dense', 'logits', 'probabilities']
        
        # First, log all available keys
        available_keys = list(model_output.keys())
        logging.debug(f"Available output keys: {available_keys}")
        
        # Try to find the right output
        for key in possible_keys:
            if key in model_output:
                logging.debug(f"Using output key: {key}")
                return model_output[key].numpy()
        
        # If no standard key found, use the first available key
        if available_keys:
            first_key = available_keys[0]
            logging.debug(f"Using first available key: {first_key}")
            return model_output[first_key].numpy()
        else:
            raise ValueError("No output keys found in model response")
    else:
        # If it's not a dictionary, assume it's already the tensor we need
        logging.debug("Model output is not a dictionary, using directly")
        return model_output.numpy() if hasattr(model_output, 'numpy') else np.array(model_output)
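
# A minimal usage sketch, assuming a Keras 3 TFSMLayer wrapping a TensorFlow
# SavedModel (the model path and batch preparation below are hypothetical):
#
#   import keras
#   classifier = keras.layers.TFSMLayer("models/classifier", call_endpoint="serving_default")
#   batch = np.stack(patches, axis=0)          # (N, 250, 250, 3), float32 in [-1, 1]
#   raw_output = classifier(batch)             # often a dict of named output tensors
#   predictions = extract_predictions_from_tfsm(raw_output)
#   predicted_classes = np.argmax(predictions, axis=1)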