# LandCoverClassification / data_processor.py
# Origin: Hugging Face upload by mahmoudalrefaey ("Upload 10 files", commit c3da629, verified)
import os
import hashlib
from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import StratifiedShuffleSplit
import pandas as pd
import io
class DataProcessor:
    """Utilities for the EuroSAT land-cover pipeline.

    Provides single-image preprocessing for model inference plus dataset
    sanity checks (image dimensions, duplicate detection) and Keras data
    generators for training/evaluation.
    """

    def __init__(self):
        # Default input shape expected by the EuroSAT model: (height, width, channels).
        self.input_shape = (64, 64, 3)

    def preprocess_for_inference(self, image):
        """Preprocess raw image bytes into a model-ready tensor.

        Args:
            image: Raw image file content as ``bytes``.

        Returns:
            A ``tf.Tensor`` of shape ``(1, H, W, 3)`` with float32 values
            scaled to [0, 1], resized to ``self.input_shape``.

        Raises:
            ValueError: If ``image`` is not ``bytes``.
        """
        if not isinstance(image, bytes):
            raise ValueError("Input must be bytes (image file content)")
        pil_img = Image.open(io.BytesIO(image))
        # Normalize ANY mode (RGBA, grayscale 'L', palette 'P', 'LA', ...)
        # to RGB so the array always has exactly 3 channels. The previous
        # code only handled RGBA; a grayscale input produced a 2-D array
        # that broke resizing and the 3-channel model input.
        if pil_img.mode != 'RGB':
            pil_img = pil_img.convert('RGB')
        # To float32 in [0, 1].
        arr = np.asarray(pil_img, dtype=np.float32) / 255.0
        # Resize to the model's spatial dimensions.
        arr = tf.image.resize(arr, (self.input_shape[0], self.input_shape[1]))
        # Add the batch dimension expected by Keras models.
        return tf.expand_dims(arr, 0)

    @staticmethod
    def check_image_size(image_path):
        """Return the (width, height) of the image at ``image_path``."""
        with Image.open(image_path) as img:
            return img.size

    @staticmethod
    def check_image_dimensions(dataset_path):
        """Collect the set of distinct (width, height) pairs in a dataset.

        Expects ``dataset_path`` to contain one sub-directory per class,
        each holding image files.

        Returns:
            A ``set`` of ``(width, height)`` tuples found across all images.
        """
        all_dimensions = set()
        # sorted() makes traversal order deterministic across filesystems.
        for folder in sorted(os.listdir(dataset_path)):
            class_path = os.path.join(dataset_path, folder)
            if not os.path.isdir(class_path):
                continue
            for image_name in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_name)
                all_dimensions.add(DataProcessor.check_image_size(image_path))
        return all_dimensions

    @staticmethod
    def get_data_generators():
        """Return (train, test) Keras ``ImageDataGenerator`` instances.

        The training generator applies heavy geometric augmentation; the
        test generator only rescales pixel values to [0, 1].
        """
        train_gen = ImageDataGenerator(
            rescale=1./255,
            rotation_range=60,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            vertical_flip=True
        )
        test_gen = ImageDataGenerator(rescale=1./255)
        return train_gen, test_gen

    @staticmethod
    def get_image_hash(image_path):
        """Return the MD5 hex digest of a file's content.

        MD5 is used purely for duplicate detection, not security.
        Reading in chunks avoids loading large files fully into memory
        while producing the same digest as a single full read.
        """
        md5 = hashlib.md5()
        with open(image_path, "rb") as f:
            for chunk in iter(lambda: f.read(65536), b""):
                md5.update(chunk)
        return md5.hexdigest()

    @staticmethod
    def check_duplicates(dataset_path):
        """Find duplicate images (by content hash) in a class-per-folder dataset.

        Returns:
            A list of file paths whose content duplicates an earlier file.
            Traversal is sorted, so the first occurrence (alphabetically)
            is kept and later copies are reported.
        """
        seen_hashes = set()
        duplicates = []
        for folder in sorted(os.listdir(dataset_path)):
            class_path = os.path.join(dataset_path, folder)
            if not os.path.isdir(class_path):
                continue
            for image_name in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_name)
                img_hash = DataProcessor.get_image_hash(image_path)
                if img_hash in seen_hashes:
                    duplicates.append(image_path)
                else:
                    seen_hashes.add(img_hash)
        return duplicates