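# NOTE(review): the three lines below were bare text at the top of the file
# (apparently page-header residue: avatar caption, commit message, commit
# hash). They are kept verbatim as comments so the module parses.
# pskeshu's picture
# minimal example
# b69e9e7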
"""Image loading and preprocessing utilities for Anton's pipeline."""
from pathlib import Path
from typing import Union, Tuple, Optional, List
import numpy as np
from PIL import Image
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class ImageLoader:
    """Handles image loading and preprocessing for microscopy analysis.

    The loader remembers the most recently loaded image so helpers such as
    ``get_image_stats`` can fall back to it when no array is passed in.

    Attributes:
        current_image: Array returned by the last successful ``load`` call,
            or ``None`` before the first load.
        current_image_path: ``Path`` of that image, or ``None``.
        metadata: Dict with ``shape``, ``dtype``, ``path``, ``format`` and
            ``mode`` of the last loaded image; empty before the first load.
    """

    def __init__(self):
        """Initialize ImageLoader with no image loaded."""
        self.current_image: Optional[np.ndarray] = None
        self.current_image_path: Optional[Path] = None
        self.metadata: dict = {}

    def load(self, image_path: Union[str, Path]) -> np.ndarray:
        """Load image from path and remember it on the instance.

        Args:
            image_path: Path to the image file

        Returns:
            numpy array of the loaded image

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            Exception: Any other failure while opening or decoding the file
                is logged and re-raised unchanged.
        """
        try:
            image_path = Path(image_path)
            if not image_path.exists():
                raise FileNotFoundError(f"Image not found: {image_path}")

            # Image.open keeps the underlying file open; the context manager
            # closes it as soon as the pixels are copied into the array.
            with Image.open(image_path) as pil_image:
                image_array = np.array(pil_image)
                image_format = pil_image.format
                image_mode = pil_image.mode

            # Store for later use
            self.current_image = image_array
            self.current_image_path = image_path

            # Extract basic metadata
            self.metadata = {
                'shape': image_array.shape,
                'dtype': str(image_array.dtype),
                'path': str(image_path),
                'format': image_format,
                'mode': image_mode,
            }

            logger.info(f"Loaded image: {image_path}, shape: {image_array.shape}")
            return image_array
        except Exception as e:
            logger.error(f"Failed to load image {image_path}: {e}")
            raise

    def preprocess(self, image: np.ndarray, normalize: bool = True,
                   channels: Optional[List[int]] = None) -> np.ndarray:
        """Preprocess image for analysis.

        The input array is never modified; the result is a new array.

        Args:
            image: Input image array
            normalize: Whether to normalize intensity values
            channels: Specific channels to extract (for multi-channel images).
                Channels are taken from the last axis; the argument is
                ignored for images with two or fewer dimensions.

        Returns:
            Preprocessed image array

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            processed = image.copy()

            # Extract specific channels if requested. Indexing the last axis
            # with ``...`` covers 3-D, 4-D and higher-rank stacks alike.
            if channels is not None and processed.ndim > 2:
                processed = processed[..., channels]

            # Normalize if requested
            if normalize:
                processed = self._normalize_image(processed)

            return processed
        except Exception as e:
            logger.error(f"Failed to preprocess image: {e}")
            raise

    def _normalize_image(self, image: np.ndarray) -> np.ndarray:
        """Normalize image intensity values to 0-1 range.

        ``uint8`` and ``uint16`` images are divided by their full-scale value
        (255 / 65535) and returned as ``float32``. Every other dtype is
        min-max scaled so its smallest value maps to 0 and its largest to 1.

        Args:
            image: Input image array

        Returns:
            Floating-point array with the same shape as ``image``
        """
        if image.dtype == np.uint8:
            return image.astype(np.float32) / 255.0
        if image.dtype == np.uint16:
            return image.astype(np.float32) / 65535.0

        # Promote integer/boolean data before doing arithmetic: subtracting
        # the minimum in a narrow signed dtype can overflow, and boolean
        # arrays do not support ``-`` at all.
        if image.dtype.kind in "biu":
            image = image.astype(np.float64)

        # Nothing to scale, and min()/max() would raise on an empty array.
        if image.size == 0:
            return image

        min_val = image.min()
        max_val = image.max()
        if max_val > min_val:
            return (image - min_val) / (max_val - min_val)

        # Constant image: there is no range to stretch. Values already inside
        # 0-1 are kept as they are; anything else is clipped into the range.
        return np.clip(image, 0.0, 1.0)

    def extract_channel(self, image: np.ndarray, channel: int) -> np.ndarray:
        """Extract a specific channel from multi-channel image.

        Args:
            image: Multi-channel image array (channels on the last axis) or a
                2-D grayscale image
            channel: Channel index to extract; negative indices count from
                the last channel

        Returns:
            Single-channel image array. A 2-D input is returned as is; for a
            3-D input the result is a slice of ``image``, not a copy.

        Raises:
            ValueError: If the channel index is out of range or the image is
                neither 2-D nor 3-D.
        """
        try:
            if image.ndim == 2:
                # Grayscale image
                return image
            if image.ndim != 3:
                raise ValueError(f"Unsupported image shape: {image.shape}")

            n_channels = image.shape[2]
            # Check both bounds so an out-of-range negative index raises the
            # documented ValueError instead of an IndexError.
            if not -n_channels <= channel < n_channels:
                raise ValueError(f"Channel {channel} not available in image with {n_channels} channels")
            return image[:, :, channel]
        except Exception as e:
            logger.error(f"Failed to extract channel {channel}: {e}")
            raise

    def convert_to_8bit(self, image: np.ndarray) -> np.ndarray:
        """Convert image to 8-bit for display/export.

        Args:
            image: Input image array

        Returns:
            8-bit image array. ``uint8`` input is returned unchanged; other
            input is normalized to 0-1 and rescaled to 0-255.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            if image.dtype == np.uint8:
                return image

            # Normalize to 0-1 range first
            normalized = self._normalize_image(image)

            # Round to the nearest level; a plain cast truncates, which
            # biases values downward.
            return np.rint(normalized * 255).astype(np.uint8)
        except Exception as e:
            logger.error(f"Failed to convert to 8-bit: {e}")
            raise

    def save_image(self, image: np.ndarray, output_path: Union[str, Path],
                   format: str = 'PNG') -> None:
        """Save image to file.

        Missing parent directories of ``output_path`` are created. Images
        that are not ``uint8`` are converted to 8-bit before saving.

        Args:
            image: Image array to save
            output_path: Output file path
            format: Image format (PNG, TIFF, etc.)

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            output_path = Path(output_path)
            output_path.parent.mkdir(parents=True, exist_ok=True)

            # Convert to 8-bit if needed
            if image.dtype != np.uint8:
                image = self.convert_to_8bit(image)

            # Create PIL image and save
            pil_image = Image.fromarray(image)
            pil_image.save(output_path, format=format)

            logger.info(f"Saved image to: {output_path}")
        except Exception as e:
            logger.error(f"Failed to save image to {output_path}: {e}")
            raise

    def get_image_stats(self, image: Optional[np.ndarray] = None) -> dict:
        """Get basic statistics about the image.

        Args:
            image: Image array (uses current_image if None)

        Returns:
            Dictionary with ``shape``, ``dtype``, ``min``, ``max``, ``mean``
            and ``std``, plus ``channels`` (size of the last axis) for images
            with more than two dimensions. An empty dict is returned when no
            image is available or the statistics cannot be computed.
        """
        if image is None:
            image = self.current_image

        if image is None:
            return {}

        try:
            stats = {
                'shape': image.shape,
                'dtype': str(image.dtype),
                'min': float(image.min()),
                'max': float(image.max()),
                'mean': float(image.mean()),
                'std': float(image.std()),
            }

            if image.ndim > 2:
                stats['channels'] = image.shape[-1]

            return stats
        except Exception as e:
            # Best effort: report the problem but do not fail the caller.
            logger.error(f"Failed to compute image statistics: {e}")
            return {}

    def create_rgb_composite(self, channels: List[np.ndarray],
                             colors: Optional[List[Tuple[float, float, float]]] = None) -> np.ndarray:
        """Create RGB composite from multiple channels.

        Each channel is normalized to 0-1, multiplied by its RGB color and
        summed into the composite, which is finally clipped to 0-1.

        Args:
            channels: List of single-channel images, all of the same shape
            colors: List of RGB colors for each channel (default: R, G, B).
                Channels beyond the number of colors are left out of the
                composite and a warning is logged.

        Returns:
            RGB composite image as ``float32`` with a trailing axis of size 3

        Raises:
            ValueError: If no channels are given or their shapes differ.
        """
        try:
            if channels is None or len(channels) == 0:
                raise ValueError("No channels provided")

            # Default colors (R, G, B)
            if colors is None:
                colors = [(1, 0, 0), (0, 1, 0), (0, 0, 1)]

            # zip() below stops at the shorter sequence, so say so instead of
            # dropping the extra channels without a trace.
            if len(channels) > len(colors):
                logger.warning(
                    f"Only {len(colors)} colors given for {len(channels)} channels; "
                    "extra channels are left out of the composite"
                )

            # Ensure all channels have the same shape
            shape = channels[0].shape
            for i, ch in enumerate(channels):
                if ch.shape != shape:
                    raise ValueError(f"Channel {i} shape {ch.shape} doesn't match expected {shape}")

            # Create RGB composite
            composite = np.zeros((*shape, 3), dtype=np.float32)

            for channel, color in zip(channels, colors):
                norm_channel = self._normalize_image(channel)
                # Add this channel's contribution to each of R, G and B.
                for c in range(3):
                    composite[..., c] += norm_channel * color[c]

            # Clip to valid range
            return np.clip(composite, 0, 1)
        except Exception as e:
            logger.error(f"Failed to create RGB composite: {e}")
            raise