# PalmOil-Classification / utils/preprocessing.py
# Last commit: ed89f9e (unverified) by Mawube — "Use only image urls"
import numpy as np
from torchvision.transforms import functional as TF
from torchvision.transforms import transforms
from PIL import Image
import cv2
import logging
def to_hsv(image):
    """Convert an image to the HSV color space.

    Args:
        image: A PIL image in any mode, or a numpy array of shape (H, W, 3)
            holding RGB data (uint8 assumed — TODO confirm with callers).
            PIL images are converted to RGB first, so RGBA, palette ("P")
            and grayscale ("L") images are accepted as well.

    Returns:
        A PIL image whose three bands contain H, S and V. ``TF.to_pil_image``
        infers the mode from the array, so the result is not tagged as
        "HSV" (for uint8 data it presumably reports "RGB"); callers must
        treat the bands as H, S, V themselves. For 8-bit input OpenCV stores
        hue as 0-179 (degrees / 2), not 0-255.
    """
    if isinstance(image, Image.Image):
        # COLOR_RGB2HSV requires exactly three channels; np.array() of an
        # RGBA/L/P image would give 4 or 1 channel(s) and make cv2 raise.
        image = image.convert("RGB")
    hsv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2HSV)
    return TF.to_pil_image(hsv)
# Preprocessing pipeline for a single image.
def _load_rgb_image(image):
    """Return ``image`` as an RGB PIL image.

    Accepts an already-loaded PIL image, a numpy array shaped (H, W),
    (H, W, 1) or (H, W, 3), or anything ``PIL.Image.open`` accepts
    (a filesystem path or a binary file-like object).

    Raises:
        ValueError: if a 3-D numpy array has a channel count other than 1 or 3.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    if isinstance(image, np.ndarray):
        if image.ndim == 3 and image.shape[-1] not in (1, 3):
            logging.error("Input image must be in RGB or Grayscale format")
            raise ValueError("Input image must be in RGB or Grayscale format")
        # to_pil_image maps (H, W) / (H, W, 1) arrays to a single-band image
        # and (H, W, 3) uint8 to RGB; convert("RGB") then replicates a gray
        # band into three channels (replacing the old, shape-breaking np.stack).
        return TF.to_pil_image(image).convert("RGB")
    # Path or file-like object. The context manager releases any file handle
    # PIL opened itself; convert() loads the pixels before the handle closes,
    # so the returned image does not depend on the file.
    with Image.open(image) as src:
        return src.convert("RGB")


def preprocess_image(image):
    """Load an image and turn it into a batched, channels-last tensor.

    Args:
        image: A filesystem path or binary file-like object readable by
            ``PIL.Image.open``; an already-loaded PIL image; or a numpy array
            shaped (H, W), (H, W, 1) or (H, W, 3) (uint8 assumed).

    Returns:
        A float tensor of shape [1, 224, 224, 3] with values scaled to
        [0, 1] by ``ToTensor``.

    Raises:
        ValueError: if a numpy array input has neither 1 nor 3 channels.
        Errors from ``PIL.Image.open`` (e.g. ``FileNotFoundError``,
        ``PIL.UnidentifiedImageError``) propagate unchanged.
    """
    logging.info("Preprocessing image")
    img = _load_rgb_image(image)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # exact size; aspect ratio is not kept
        transforms.ToTensor(),          # HWC image -> CHW float tensor in [0, 1]
    ])
    # [3, 224, 224] -> [1, 224, 224, 3]: add a batch axis and move channels
    # last (presumably the layout the downstream model expects — confirm).
    return transform(img).unsqueeze(0).permute(0, 2, 3, 1)