Spaces:
Sleeping
Sleeping
| """Image loading and preprocessing utilities for Anton's pipeline.""" | |
| from pathlib import Path | |
| from typing import Union, Tuple, Optional, List | |
| import numpy as np | |
| from PIL import Image | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class ImageLoader:
    """Load, preprocess, and export images for microscopy analysis.

    The loader remembers the most recently loaded image, its path, and a
    small metadata dictionary; all other methods operate on the arrays
    passed to them.
    """

    def __init__(self):
        """Initialize ImageLoader with no image loaded."""
        # Array returned by the most recent successful ``load`` call.
        self.current_image = None
        # ``Path`` of the most recently loaded image.
        self.current_image_path = None
        # Keys written by ``load``: shape, dtype, path, format, mode.
        self.metadata = {}

    def load(self, image_path: Union[str, Path]) -> np.ndarray:
        """Load an image from disk into a numpy array.

        On success the array, its path, and basic metadata are stored on
        the instance (``current_image``, ``current_image_path``,
        ``metadata``).

        Args:
            image_path: Path to the image file.

        Returns:
            Numpy array of the loaded image.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any error raised while opening or decoding the file
                is logged and re-raised unchanged.
        """
        try:
            image_path = Path(image_path)
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")

            # Open inside a context manager so the underlying file handle
            # is released once the pixel data and header fields are read.
            with Image.open(image_path) as pil_image:
                image_array = np.array(pil_image)
                image_format = pil_image.format
                image_mode = pil_image.mode

            # Store for later use
            self.current_image = image_array
            self.current_image_path = image_path

            # Extract basic metadata
            self.metadata = {
                'shape': image_array.shape,
                'dtype': str(image_array.dtype),
                'path': str(image_path),
                'format': image_format,
                'mode': image_mode,
            }

            logger.info(f"Loaded image: {image_path}, shape: {image_array.shape}")
            return image_array
        except Exception as e:
            logger.error(f"Failed to load image {image_path}: {e}")
            raise

    def preprocess(self, image: np.ndarray, normalize: bool = True,
                   channels: Optional[List[int]] = None) -> np.ndarray:
        """Preprocess an image for analysis.

        Args:
            image: Input image array. It is never modified; work is done
                on a copy.
            normalize: Whether to normalize intensity values (see
                ``_normalize_image``).
            channels: Indices to select along the last axis. Ignored for
                arrays with fewer than three dimensions.

        Returns:
            Preprocessed image array.
        """
        try:
            processed = image.copy()

            # The last axis is treated as the channel axis for any array
            # with three or more dimensions.
            if channels is not None and image.ndim > 2:
                processed = processed[..., channels]

            if normalize:
                processed = self._normalize_image(processed)

            return processed
        except Exception as e:
            logger.error(f"Failed to preprocess image: {e}")
            raise

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Normalize image intensity values to the 0-1 range.

        ``uint8`` and ``uint16`` inputs are divided by their full-scale
        value (255 / 65535) and returned as ``float32``. Every other dtype
        is min-max scaled.

        Args:
            image: Input image array.

        Returns:
            Normalized array. An empty input, or an input whose minimum
            equals its maximum, is returned unchanged because min-max
            scaling is undefined for it.
        """
        if image.dtype == np.uint8:
            return image.astype(np.float32) / 255.0
        if image.dtype == np.uint16:
            return image.astype(np.float32) / 65535.0

        # ``min``/``max`` raise on zero-size arrays.
        if image.size == 0:
            return image

        # Integer arithmetic in the input dtype can wrap around (e.g. int8
        # 127 - (-128)) and boolean subtraction is rejected by numpy, so
        # non-float inputs are widened to float64 before scaling.
        if np.issubdtype(image.dtype, np.inexact):
            values = image
        else:
            values = image.astype(np.float64)

        min_val = values.min()
        max_val = values.max()
        if max_val > min_val:
            return (values - min_val) / (max_val - min_val)
        return image

    def extract_channel(self, image: np.ndarray, channel: int) -> np.ndarray:
        """Extract a specific channel from a multi-channel image.

        Args:
            image: 2-D (grayscale) or 3-D image array with channels on the
                last axis.
            channel: Channel index to extract. Negative indices count from
                the end.

        Returns:
            Single-channel image array. A 2-D input is returned as is.

        Raises:
            ValueError: If ``channel`` is out of range or the image is
                neither 2-D nor 3-D.
        """
        try:
            if image.ndim == 2:
                # Grayscale image
                return image
            if image.ndim != 3:
                raise ValueError(f"Unsupported image shape: {image.shape}")

            n_channels = image.shape[2]
            # Check both bounds so an out-of-range negative index reports
            # the same ValueError as an out-of-range positive one.
            if not -n_channels <= channel < n_channels:
                raise ValueError(f"Channel {channel} not available in image with {n_channels} channels")
            return image[:, :, channel]
        except Exception as e:
            logger.error(f"Failed to extract channel {channel}: {e}")
            raise

    def convert_to_8bit(self, image: np.ndarray) -> np.ndarray:
        """Convert an image to 8-bit for display/export.

        Args:
            image: Input image array.

        Returns:
            ``uint8`` image array. A ``uint8`` input is returned unchanged.
        """
        try:
            if image.dtype == np.uint8:
                return image

            normalized = self._normalize_image(image)

            # ``_normalize_image`` returns constant images unscaled, so
            # clip before scaling to keep the uint8 cast from wrapping.
            return (np.clip(normalized, 0, 1) * 255).astype(np.uint8)
        except Exception as e:
            logger.error(f"Failed to convert to 8-bit: {e}")
            raise

    def save_image(self, image: np.ndarray, output_path: Union[str, Path],
                   format: str = 'PNG') -> None:
        """Save an image to file.

        Missing parent directories are created. Non-``uint8`` input is
        converted with ``convert_to_8bit`` before saving.

        Args:
            image: Image array to save.
            output_path: Output file path.
            format: Image format (PNG, TIFF, etc.).
        """
        try:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            # Convert to 8-bit if needed
            if image.dtype != np.uint8:
                image = self.convert_to_8bit(image)

            # Create PIL image and save
            pil_image = Image.fromarray(image)
            pil_image.save(output_path, format=format)

            logger.info(f"Saved image to: {output_path}")
        except Exception as e:
            logger.error(f"Failed to save image to {output_path}: {e}")
            raise

    def get_image_stats(self, image: Optional[np.ndarray] = None) -> dict:
        """Get basic statistics about an image.

        Args:
            image: Image array (uses ``current_image`` if None).

        Returns:
            Dictionary with shape, dtype, min, max, mean and std, plus
            ``channels`` (size of the last axis) for arrays with more than
            two dimensions. Empty if no image is available or the
            statistics cannot be computed.
        """
        if image is None:
            image = self.current_image
        if image is None:
            return {}

        try:
            stats = {
                'shape': image.shape,
                'dtype': str(image.dtype),
                'min': float(image.min()),
                'max': float(image.max()),
                'mean': float(image.mean()),
                'std': float(image.std()),
            }
            if image.ndim > 2:
                stats['channels'] = image.shape[-1]
            return stats
        except Exception as e:
            # Best effort: statistics are informational, so failures are
            # logged and reported as an empty result instead of raising.
            logger.error(f"Failed to compute image statistics: {e}")
            return {}

    def create_rgb_composite(self, channels: List[np.ndarray],
                             colors: Optional[List[Tuple[float, float, float]]] = None) -> np.ndarray:
        """Create an RGB composite from multiple channels.

        Each channel is normalized, multiplied by its color, and summed
        into the composite, which is finally clipped to the 0-1 range.

        Args:
            channels: List of single-channel images, all of the same shape.
            colors: List of RGB colors, one per channel (default: R, G, B).
                Channels without a matching color are left out of the
                composite and a warning is logged.

        Returns:
            ``float32`` RGB composite with a trailing axis of length 3.

        Raises:
            ValueError: If no channels are given or their shapes differ.
        """
        try:
            if not channels:
                raise ValueError("No channels provided")

            # Default colors (R, G, B)
            if colors is None:
                colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]

            # Ensure all channels have the same shape
            shape = channels[0].shape
            for i, ch in enumerate(channels):
                if ch.shape != shape:
                    raise ValueError(f"Channel {i} shape {ch.shape} doesn't match expected {shape}")

            # ``zip`` below stops at the shorter sequence; make the
            # otherwise silent truncation visible.
            if len(channels) > len(colors):
                logger.warning(
                    "Only %d colors given for %d channels; extra channels are ignored",
                    len(colors), len(channels),
                )

            composite = np.zeros((*shape, 3), dtype=np.float32)
            for channel, color in zip(channels, colors):
                norm_channel = self._normalize_image(channel)
                # Broadcast the channel against its RGB weights so all
                # three color planes are updated in one step.
                rgb = np.asarray(color[:3], dtype=np.float32)
                composite += norm_channel[..., np.newaxis] * rgb

            # Clip to valid range
            return np.clip(composite, 0, 1)
        except Exception as e:
            logger.error(f"Failed to create RGB composite: {e}")
            raise