# Spaces: (page-scrape residue — the hosted Space reported "Build error")
# Build error
# Build error
import hashlib
import io
import os

import numpy as np
import pandas as pd
import tensorflow as tf
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
from tensorflow.keras.preprocessing.image import ImageDataGenerator
class DataProcessor:
    """Utilities for preparing EuroSAT-style image datasets.

    Covers single-image preprocessing for inference, dataset sanity checks
    (dimension survey, duplicate detection), and Keras data generators.
    """

    def __init__(self, input_shape=(64, 64, 3)):
        """Initialize the processor.

        Args:
            input_shape: (height, width, channels) expected by the model.
                Defaults to (64, 64, 3), the EuroSAT RGB patch size.
        """
        self.input_shape = input_shape

    def preprocess_for_inference(self, image):
        """Preprocess raw image file bytes into a model-ready tensor.

        Args:
            image: Raw image file content as ``bytes``.

        Returns:
            A float32 tensor of shape (1, H, W, 3) with values in [0, 1],
            resized to ``self.input_shape`` and with a batch dimension added.

        Raises:
            ValueError: If ``image`` is not ``bytes``.
        """
        # Guard clause keeps the happy path unindented.
        if not isinstance(image, bytes):
            raise ValueError("Input must be bytes (image file content)")
        pil_image = Image.open(io.BytesIO(image))
        # Normalize ANY non-RGB mode (RGBA, palette 'P', grayscale 'L', CMYK)
        # to 3-channel RGB so the array matches the model's channel count.
        # (The original handled only RGBA.)
        if pil_image.mode != 'RGB':
            pil_image = pil_image.convert('RGB')
        # Scale pixel values to [0, 1] as float32.
        array = np.asarray(pil_image, dtype=np.float32) / 255.0
        resized = tf.image.resize(array, (self.input_shape[0], self.input_shape[1]))
        # Add the batch dimension expected by Keras models.
        return tf.expand_dims(resized, 0)

    @staticmethod
    def check_image_size(image_path):
        """Return the (width, height) of a single image file."""
        # Context manager ensures the file handle is released promptly.
        with Image.open(image_path) as img:
            return img.size

    @staticmethod
    def check_image_dimensions(dataset_path):
        """Return the set of distinct (width, height) pairs across the dataset.

        Expects ``dataset_path`` to contain one subdirectory per class, each
        holding image files (non-directory entries at the top level are skipped).
        """
        all_dimensions = set()
        for folder in os.listdir(dataset_path):
            class_path = os.path.join(dataset_path, folder)
            if not os.path.isdir(class_path):
                continue
            for image_name in os.listdir(class_path):
                image_path = os.path.join(class_path, image_name)
                width, height = DataProcessor.check_image_size(image_path)
                all_dimensions.add((width, height))
        return all_dimensions

    @staticmethod
    def get_data_generators():
        """Return (train_gen, test_gen) Keras ImageDataGenerators.

        The training generator applies aggressive geometric augmentation
        (rotation, shifts, shear, zoom, both flips); the test generator only
        rescales pixel values to [0, 1].
        """
        train_gen = ImageDataGenerator(
            rescale=1. / 255,
            rotation_range=60,
            width_shift_range=0.2,
            height_shift_range=0.2,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True,
            vertical_flip=True,
        )
        test_gen = ImageDataGenerator(rescale=1. / 255)
        return train_gen, test_gen

    @staticmethod
    def get_image_hash(image_path):
        """Return the hex MD5 digest of a file's raw bytes.

        MD5 is used only as a content fingerprint for duplicate detection,
        not for any security purpose.
        """
        with open(image_path, "rb") as f:
            return hashlib.md5(f.read()).hexdigest()

    @staticmethod
    def check_duplicates(dataset_path):
        """Return paths of images whose content duplicates an earlier file.

        Walks one level of class subdirectories under ``dataset_path`` and
        compares files by MD5 content hash; the first occurrence of each hash
        is kept, later occurrences are reported as duplicates.
        """
        seen_hashes = set()
        duplicates = []
        for folder in os.listdir(dataset_path):
            class_path = os.path.join(dataset_path, folder)
            if not os.path.isdir(class_path):
                continue
            for image_name in os.listdir(class_path):
                image_path = os.path.join(class_path, image_name)
                img_hash = DataProcessor.get_image_hash(image_path)
                if img_hash in seen_hashes:
                    duplicates.append(image_path)
                else:
                    seen_hashes.add(img_hash)
        return duplicates