""" Image processing utility for FoodViT Handles image preprocessing and transformation for model inference """ import cv2 import numpy as np import torch from PIL import Image import albumentations as A from config import IMAGE_CONFIG class ImageProcessor: """Class to handle image preprocessing and transformation""" def __init__(self): self.target_size = IMAGE_CONFIG["target_size"] self.normalize_mean = IMAGE_CONFIG["normalize_mean"] self.normalize_std = IMAGE_CONFIG["normalize_std"] # Initialize transformations self.normalize = A.Normalize( mean=self.normalize_mean, std=self.normalize_std ) self.val_transform = A.Compose([ A.Resize(self.target_size[0], self.target_size[1]), A.CenterCrop(self.target_size[0], self.target_size[1]), self.normalize ]) def preprocess_image(self, image_path): """ Preprocess image for model inference Args: image_path: Path to the image file or PIL Image object Returns: torch.Tensor: Preprocessed image tensor """ try: # Load image if isinstance(image_path, str): # Load from file path image = cv2.imread(image_path) if image is None: raise ValueError(f"Could not load image from {image_path}") image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) elif isinstance(image_path, Image.Image): # Convert PIL Image to numpy array image = np.array(image_path) if len(image.shape) == 3 and image.shape[2] == 3: # Already RGB pass else: # Convert to RGB if needed image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) else: raise ValueError("Unsupported image format") # Apply transformations transformed = self.val_transform(image=image) processed_image = transformed['image'] # Convert to tensor and change format tensor_image = torch.tensor(processed_image, dtype=torch.float32) tensor_image = tensor_image.permute(2, 0, 1) # HWC to CHW # Add batch dimension tensor_image = tensor_image.unsqueeze(0) return tensor_image except Exception as e: print(f"Error preprocessing image: {e}") return None def preprocess_pil_image(self, pil_image): """ Preprocess PIL Image for model inference Args: pil_image: PIL Image object Returns: torch.Tensor: Preprocessed image tensor """ return self.preprocess_image(pil_image) def get_image_info(self, image_path): """ Get basic information about an image Args: image_path: Path to the image file Returns: dict: Image information """ try: image = cv2.imread(image_path) if image is None: return None return { "height": image.shape[0], "width": image.shape[1], "channels": image.shape[2] if len(image.shape) == 3 else 1, "dtype": str(image.dtype) } except Exception as e: print(f"Error getting image info: {e}") return None # Global image processor instance image_processor = ImageProcessor()