# Spaces: Sleeping
# File size: 3,908 Bytes
# 42a7d1b
"""
Image processing utility for FoodViT
Handles image preprocessing and transformation for model inference
"""
import os

import albumentations as A
import cv2
import numpy as np
import torch
from PIL import Image

from config import IMAGE_CONFIG
class ImageProcessor:
    """Preprocess images into tensors for model inference.

    The target size and normalization statistics are read from
    ``IMAGE_CONFIG``; every image goes through the same albumentations
    pipeline (resize, center crop, normalize).
    """

    def __init__(self):
        """Read ``IMAGE_CONFIG`` and build the inference-time transform."""
        self.target_size = IMAGE_CONFIG["target_size"]
        self.normalize_mean = IMAGE_CONFIG["normalize_mean"]
        self.normalize_std = IMAGE_CONFIG["normalize_std"]

        # Initialize transformations
        self.normalize = A.Normalize(
            mean=self.normalize_mean,
            std=self.normalize_std,
        )
        # The crop uses the same two dimensions as the resize that precedes it.
        self.val_transform = A.Compose([
            A.Resize(self.target_size[0], self.target_size[1]),
            A.CenterCrop(self.target_size[0], self.target_size[1]),
            self.normalize,
        ])

    def preprocess_image(self, image_path):
        """Load an image and convert it to a batched, normalized tensor.

        Args:
            image_path: Path to an image file (``str`` or any
                ``os.PathLike`` such as ``pathlib.Path``), or a PIL
                ``Image.Image`` object.

        Returns:
            torch.Tensor: ``float32`` tensor laid out as
            ``(1, channels, height, width)``, or ``None`` if the image
            could not be loaded or processed (the error is printed).
        """
        # Best-effort by design: any failure is reported and turned into
        # ``None`` so callers only have to check the return value.
        try:
            if isinstance(image_path, (str, os.PathLike)):
                # cv2.imread returns None instead of raising when the file
                # is missing or unreadable, so that case is checked by hand.
                image = cv2.imread(os.fspath(image_path))
                if image is None:
                    raise ValueError(f"Could not load image from {image_path}")
                # OpenCV decodes to BGR channel order; the pipeline uses RGB.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            elif isinstance(image_path, Image.Image):
                # PIL data is never BGR. Let PIL itself produce a 3-channel
                # RGB image so grayscale, palette and RGBA inputs are handled
                # instead of being sent through a BGR->RGB swap.
                image = np.array(image_path.convert("RGB"))
            else:
                raise ValueError("Unsupported image format")

            # Apply transformations
            transformed = self.val_transform(image=image)
            processed_image = transformed['image']

            tensor_image = torch.tensor(processed_image, dtype=torch.float32)
            tensor_image = tensor_image.permute(2, 0, 1)  # HWC to CHW
            # Add batch dimension
            return tensor_image.unsqueeze(0)
        except Exception as e:
            print(f"Error preprocessing image: {e}")
            return None

    def preprocess_pil_image(self, pil_image):
        """Preprocess a PIL image; thin wrapper around ``preprocess_image``.

        Args:
            pil_image: PIL ``Image.Image`` object.

        Returns:
            torch.Tensor: Preprocessed image tensor, or ``None`` on failure.
        """
        return self.preprocess_image(pil_image)

    def get_image_info(self, image_path):
        """Return basic shape and dtype information about an image file.

        Args:
            image_path: Path to the image file (``str`` or ``os.PathLike``).

        Returns:
            dict: Keys ``"height"``, ``"width"``, ``"channels"`` and
            ``"dtype"`` describing the array returned by ``cv2.imread``,
            or ``None`` if the file could not be read.
        """
        try:
            image = cv2.imread(os.fspath(image_path))
            if image is None:
                return None
            return {
                "height": image.shape[0],
                "width": image.shape[1],
                # A 2-D array has no channel axis; report it as one channel.
                "channels": image.shape[2] if len(image.shape) == 3 else 1,
                "dtype": str(image.dtype),
            }
        except Exception as e:
            print(f"Error getting image info: {e}")
            return None
# Shared module-level ImageProcessor instance, created at import time.
image_processor = ImageProcessor()