Spaces:
Sleeping
Sleeping
| """ | |
| Utility functions for image preprocessing. | |
| Handles various input formats: bytes, base64, PIL images, etc. | |
| """ | |
| import io | |
| import base64 | |
| import numpy as np | |
| from PIL import Image | |
| import logging | |
| logger = logging.getLogger(__name__) | |
def preprocess_image_from_bytes(image_bytes: bytes) -> np.ndarray:
    """
    Preprocess image from raw bytes.

    The image is decoded with Pillow, converted to 8-bit grayscale, resized
    to 28x28 with Lanczos resampling and scaled from [0, 255] to [0, 1].

    Args:
        image_bytes: Raw encoded image bytes (PNG, JPG, etc.)

    Returns:
        Preprocessed float32 numpy array of shape (1, 28, 28, 1) normalized to [0, 1]

    Raises:
        ValueError: If the bytes cannot be decoded or processed as an image.
            The underlying exception is attached as ``__cause__``.
    """
    try:
        # The context manager releases the decoder's resources as soon as the
        # pixel data has been copied into the grayscale image.
        with Image.open(io.BytesIO(image_bytes)) as image:
            # Convert to single-channel 8-bit grayscale
            grayscale = image.convert('L')

        # Resize to the 28x28 model input size
        resized = grayscale.resize((28, 28), Image.Resampling.LANCZOS)

        # Convert to float32 and normalize to [0, 1]
        image_array = np.array(resized, dtype=np.float32) / 255.0

        # Add batch and channel axes: (1, 28, 28, 1)
        return image_array.reshape(1, 28, 28, 1)
    except Exception as e:
        logger.error("Error preprocessing image from bytes: %s", e)
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ValueError(f"Failed to process image: {str(e)}") from e
def preprocess_image_from_base64(base64_string: str) -> np.ndarray:
    """
    Preprocess image from base64 encoded string.

    Args:
        base64_string: Base64 encoded image string (with or without data URI prefix)

    Returns:
        Preprocessed numpy array of shape (1, 28, 28, 1) normalized to [0, 1]

    Raises:
        ValueError: If the string cannot be base64-decoded (raised here), or
            if the decoded bytes are not a usable image (raised by
            preprocess_image_from_bytes and propagated unchanged).
    """
    # Only the decoding step is guarded here. Image errors are already logged
    # and wrapped in ValueError by preprocess_image_from_bytes, so guarding
    # that call as well would log and wrap the same failure twice.
    try:
        # Remove data URI prefix if present (e.g., "data:image/png;base64,")
        if ',' in base64_string and base64_string.startswith('data:'):
            base64_string = base64_string.split(',', 1)[1]

        # Decode base64 to bytes
        image_bytes = base64.b64decode(base64_string)
    except Exception as e:
        logger.error("Error preprocessing image from base64: %s", e)
        raise ValueError(f"Failed to process base64 image: {str(e)}") from e

    # Use the bytes preprocessing function
    return preprocess_image_from_bytes(image_bytes)
def preprocess_image_from_array(image_array: np.ndarray) -> np.ndarray:
    """
    Preprocess image from numpy array.

    Handles various input shapes and formats:
      * (batch, height, width, channels): only the first image is used
      * (height, width, 3): RGB, converted to grayscale
      * (height, width, 1): channel axis dropped
      * (height, width): used as-is

    Values are treated as already normalized when the maximum is <= 1.0,
    otherwise they are assumed to be on a 0-255 scale.

    Args:
        image_array: Numpy array (or nested sequence) representing an image

    Returns:
        Preprocessed float32 numpy array of shape (1, 28, 28, 1) normalized to [0, 1]

    Raises:
        ValueError: If the input cannot be reduced to a 2D image or any
            processing step fails. The underlying exception is chained.
    """
    try:
        # np.array always copies, so the caller's data is never aliased by
        # the returned array, and nested lists are accepted as well.
        pixels = np.array(image_array, dtype=np.float32)

        if pixels.ndim == 4:  # (batch, height, width, channels)
            # Take first image if batch
            pixels = pixels[0]

        if pixels.ndim == 3:  # (height, width, channels)
            channels = pixels.shape[2]
            if channels == 3:
                # Weighted RGB to grayscale conversion
                pixels = (0.299 * pixels[:, :, 0]
                          + 0.587 * pixels[:, :, 1]
                          + 0.114 * pixels[:, :, 2])
            elif channels == 1:
                pixels = pixels.squeeze(-1)

        # Now pixels should be 2D (height, width)
        if pixels.ndim != 2:
            raise ValueError(f"Cannot process image with shape {pixels.shape}")

        # Decide the value scale BEFORE resizing: the 8-bit conversion below
        # would otherwise truncate a [0, 1] image to zeros and ones.
        needs_scaling = pixels.max() > 1.0

        if pixels.shape != (28, 28):
            # Pillow resizes 8-bit data here, so bring the image onto the
            # 0-255 scale; round and clip so out-of-range values saturate
            # instead of wrapping around in uint8.
            eight_bit = pixels if needs_scaling else pixels * 255.0
            eight_bit = np.clip(np.rint(eight_bit), 0, 255).astype(np.uint8)
            resized = Image.fromarray(eight_bit).resize(
                (28, 28), Image.Resampling.LANCZOS
            )
            # The resized image is always 0-255, so always map back to [0, 1].
            pixels = np.array(resized, dtype=np.float32) / 255.0
        elif needs_scaling:
            # Normalize to [0, 1]
            pixels = pixels / 255.0

        # Reshape to (1, 28, 28, 1)
        return pixels.reshape(1, 28, 28, 1)
    except Exception as e:
        logger.error("Error preprocessing image from array: %s", e)
        raise ValueError(f"Failed to process image array: {str(e)}") from e
def preprocess_stroke_data(strokes: list, canvas_size: int = 256) -> np.ndarray:
    """
    Convert stroke data (list of coordinates) to a 28x28 image.
    Useful if VR application sends raw drawing coordinates.

    Each pair of consecutive points in a stroke is joined by a line of
    sampled pixels set to 255 on a square canvas, which is then resized to
    28x28 and scaled to [0, 1]. Strokes with fewer than two points are
    skipped, and samples falling outside the canvas are ignored.

    Args:
        strokes: List of strokes, where each stroke is a list of (x, y) coordinates
            Example: [[(x1, y1), (x2, y2), ...], [(x3, y3), ...]]
        canvas_size: Size of the virtual canvas (default: 256x256)

    Returns:
        Preprocessed numpy array of shape (1, 28, 28, 1) normalized to [0, 1]

    Raises:
        ValueError: If the stroke data cannot be rasterized (malformed
            points, invalid canvas size, ...). The underlying exception is
            chained.
    """
    try:
        # Create a blank canvas
        canvas = np.zeros((canvas_size, canvas_size), dtype=np.uint8)

        # Upper bound on samples per segment so that huge or non-finite
        # coordinates cannot force an unbounded allocation.
        max_steps = 4 * canvas_size

        # Draw strokes on canvas
        for stroke in strokes:
            if len(stroke) < 2:
                continue

            # Draw lines between consecutive points
            for (x1, y1), (x2, y2) in zip(stroke, stroke[1:]):
                # Sample at least once per pixel along the longer axis so
                # long segments have no gaps; never fewer than 10 samples.
                span = max(abs(x2 - x1), abs(y2 - y1))
                # The comparison is False for NaN/inf, which then fall back
                # to the cap instead of failing in int().
                steps = int(span) + 1 if span < max_steps else max_steps
                num_points = max(10, steps)

                for x, y in _interpolate_points(x1, y1, x2, y2, num_points):
                    if 0 <= x < canvas_size and 0 <= y < canvas_size:
                        canvas[int(y), int(x)] = 255

        # Convert canvas to PIL Image for resizing
        image = Image.fromarray(canvas)
        image = image.resize((28, 28), Image.Resampling.LANCZOS)

        # Convert to numpy array and normalize
        image_array = np.array(image, dtype=np.float32) / 255.0

        # Reshape to (1, 28, 28, 1)
        return image_array.reshape(1, 28, 28, 1)
    except Exception as e:
        logger.error("Error preprocessing stroke data: %s", e)
        raise ValueError(f"Failed to process stroke data: {str(e)}") from e
| def _interpolate_points(x1: float, y1: float, x2: float, y2: float, num_points: int = 10) -> list: | |
| """ | |
| Interpolate points between two coordinates for smooth line drawing. | |
| Args: | |
| x1, y1: Start coordinates | |
| x2, y2: End coordinates | |
| num_points: Number of points to interpolate | |
| Returns: | |
| List of (x, y) coordinate tuples | |
| """ | |
| points = [] | |
| for i in range(num_points + 1): | |
| t = i / num_points | |
| x = x1 + t * (x2 - x1) | |
| y = y1 + t * (y2 - y1) | |
| points.append((x, y)) | |
| return points | |