Spaces:
Sleeping
Sleeping
File size: 6,764 Bytes
d2a2955 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 |
"""
Utility functions for image preprocessing.
Handles several input formats: raw image bytes, base64 strings (optionally data URIs), NumPy arrays, and stroke-coordinate lists.
"""
import io
import base64
import numpy as np
from PIL import Image
import logging
# Module-level logger; each preprocessing function below logs its failure
# here before re-raising it as a ValueError.
logger = logging.getLogger(__name__)
def preprocess_image_from_bytes(image_bytes: bytes) -> np.ndarray:
    """
    Preprocess image from raw bytes.

    The image is decoded with PIL, converted to 8-bit grayscale ('L' mode),
    resized to 28x28 with Lanczos resampling and scaled from [0, 255] to [0, 1].

    Args:
        image_bytes: Raw encoded image bytes in any format PIL can open (PNG, JPG, etc.)

    Returns:
        Preprocessed float32 numpy array of shape (1, 28, 28, 1) normalized to [0, 1]

    Raises:
        ValueError: if the bytes cannot be decoded or processed; the original
            exception is attached as ``__cause__``.
    """
    try:
        # The context manager releases the decoder once the grayscale copy
        # exists (convert() loads the pixel data before returning).
        with Image.open(io.BytesIO(image_bytes)) as image:
            # NOTE(review): convert('L') ignores any alpha channel, so fully
            # transparent pixels keep their stored RGB value (frequently black).
            # Confirm that clients send images with an opaque background.
            grayscale = image.convert('L')
        resized = grayscale.resize((28, 28), Image.Resampling.LANCZOS)
        image_array = np.array(resized, dtype=np.float32) / 255.0
        # Add batch and channel axes: (batch, height, width, channels).
        return image_array.reshape(1, 28, 28, 1)
    except Exception as e:
        logger.error("Error preprocessing image from bytes: %s", e)
        # Chain the cause so the original traceback is not lost.
        raise ValueError(f"Failed to process image: {str(e)}") from e
def preprocess_image_from_base64(base64_string: str) -> np.ndarray:
    """
    Preprocess image from base64 encoded string.

    Args:
        base64_string: Base64 encoded image string, optionally wrapped in a data
            URI (e.g. "data:image/png;base64,..."). Surrounding whitespace is
            ignored and the "data:" scheme is matched case-insensitively.

    Returns:
        Preprocessed numpy array of shape (1, 28, 28, 1) normalized to [0, 1]

    Raises:
        ValueError: if the string cannot be decoded ("Failed to process base64
            image: ..."), or, raised by preprocess_image_from_bytes, if the
            decoded bytes are not a readable image.
    """
    try:
        payload = base64_string.strip()
        # Remove a data URI prefix if present. URI schemes are case-insensitive,
        # and the prefix must be removed: b64decode silently drops ':', ';' and
        # ',' but keeps '/', so an unstripped prefix corrupts the decoded bytes.
        if payload[:5].lower() == 'data:' and ',' in payload:
            payload = payload.split(',', 1)[1]
        image_bytes = base64.b64decode(payload)
    except (AttributeError, TypeError, ValueError) as e:
        # binascii.Error (e.g. bad padding) is a subclass of ValueError;
        # AttributeError/TypeError cover non-string input such as None.
        logger.error("Error preprocessing image from base64: %s", e)
        raise ValueError(f"Failed to process base64 image: {str(e)}") from e

    # Called outside the try: preprocess_image_from_bytes already logs and
    # raises its own ValueError, so wrapping it again would log the failure
    # twice and double-prefix the message.
    return preprocess_image_from_bytes(image_bytes)
def preprocess_image_from_array(image_array: np.ndarray) -> np.ndarray:
    """
    Preprocess image from numpy array.

    Accepted layouts:
        (H, W)                     grayscale
        (H, W, C), C in {1, 3, 4}  grayscale / RGB / RGBA (alpha is ignored)
        (N, H, W, C)               batch; only the first image is used
    If the array's maximum exceeds 1.0 it is treated as 0-255 intensities and
    divided by 255; otherwise it is assumed to already be in [0, 1]. Intensities
    are assumed to be non-negative and at most 255.

    Args:
        image_array: Numpy array (or nested sequence) representing an image.
            It is copied and never modified in place.

    Returns:
        Preprocessed float32 numpy array of shape (1, 28, 28, 1) normalized to [0, 1]

    Raises:
        ValueError: if the array has an unsupported shape or cannot be processed.
    """
    try:
        # np.array (not asarray) always copies, so the caller's data is never
        # touched; it also accepts plain nested lists.
        image = np.array(image_array, dtype=np.float32)

        if image.ndim == 4:  # (batch, height, width, channels): keep the first image
            image = image[0]

        if image.ndim == 3:  # (height, width, channels)
            channels = image.shape[2]
            if channels in (3, 4):
                # ITU-R 601-2 luma weights, as used by PIL's convert('L');
                # a fourth (alpha) channel is ignored, matching that conversion.
                image = (0.299 * image[:, :, 0]
                         + 0.587 * image[:, :, 1]
                         + 0.114 * image[:, :, 2])
            elif channels == 1:
                image = image.squeeze(-1)

        # Now the image must be 2D (height, width)
        if image.ndim != 2:
            raise ValueError(f"Cannot process image with shape {image.shape}")

        # Decide the input scale on the full-resolution data, before any
        # resampling could blur the maximum across the 1.0 threshold.
        if image.max() > 1.0:
            image = image / 255.0

        if image.shape != (28, 28):
            # PIL 'L' mode needs uint8 in [0, 255]. Rescale first (a [0, 1]
            # input cast straight to uint8 would collapse to 0/1) and clip,
            # because float -> uint8 casts of out-of-range values wrap.
            pixels = np.clip(np.rint(image * 255.0), 0, 255).astype(np.uint8)
            resized = Image.fromarray(pixels).resize((28, 28), Image.Resampling.LANCZOS)
            image = np.array(resized, dtype=np.float32) / 255.0

        # Add batch and channel axes: (batch, height, width, channels).
        return image.reshape(1, 28, 28, 1)
    except Exception as e:
        logger.error("Error preprocessing image from array: %s", e)
        raise ValueError(f"Failed to process image array: {str(e)}") from e
def preprocess_stroke_data(strokes: list, canvas_size: int = 256) -> np.ndarray:
    """
    Convert stroke data (list of coordinates) to a 28x28 image.

    Useful if VR application sends raw drawing coordinates. Strokes are drawn
    as 1-pixel white (255) lines on a black canvas_size x canvas_size canvas,
    which is then downscaled to 28x28 with Lanczos resampling.

    Args:
        strokes: List of strokes, where each stroke is a list of (x, y) coordinates
            in canvas pixel units.
            Example: [[(x1, y1), (x2, y2), ...], [(x3, y3), ...]]
            Strokes with fewer than two points are ignored; samples that fall
            outside the canvas are not drawn.
        canvas_size: Size of the virtual canvas (default: 256x256)

    Returns:
        Preprocessed numpy array of shape (1, 28, 28, 1) normalized to [0, 1]

    Raises:
        ValueError: if the strokes cannot be rasterized.
    """
    try:
        # Create a blank (black) canvas
        canvas = np.zeros((canvas_size, canvas_size), dtype=np.uint8)
        # Upper bound on samples per segment, so absurd coordinates cannot
        # trigger a huge loop. A segment lying inside the canvas never needs
        # more than canvas_size samples.
        max_steps = 2 * canvas_size
        for stroke in strokes:
            if len(stroke) < 2:
                continue
            # Draw a line between each pair of consecutive points
            for (x1, y1), (x2, y2) in zip(stroke, stroke[1:]):
                # Take at least one sample per pixel along the dominant axis so
                # consecutive samples land in the same or an adjacent pixel; a
                # fixed 10 samples left gaps on segments longer than ~10 px.
                span = max(abs(x2 - x1), abs(y2 - y1))
                steps = 10
                if np.isfinite(span):
                    steps = max(10, min(int(np.ceil(span)), max_steps))
                # Non-finite coordinates keep the default count; their samples
                # are non-finite and fail the bounds check, so nothing is drawn.
                for x, y in _interpolate_points(x1, y1, x2, y2, num_points=steps):
                    if 0 <= x < canvas_size and 0 <= y < canvas_size:
                        canvas[int(y), int(x)] = 255
        # NOTE(review): a 1-px line on a large canvas becomes very faint after
        # downscaling to 28x28; consider thicker strokes if recognition suffers.
        image = Image.fromarray(canvas).resize((28, 28), Image.Resampling.LANCZOS)
        image_array = np.array(image, dtype=np.float32) / 255.0
        # Add batch and channel axes: (batch, height, width, channels).
        return image_array.reshape(1, 28, 28, 1)
    except Exception as e:
        logger.error("Error preprocessing stroke data: %s", e)
        raise ValueError(f"Failed to process stroke data: {str(e)}") from e
def _interpolate_points(x1: float, y1: float, x2: float, y2: float, num_points: int = 10) -> list:
"""
Interpolate points between two coordinates for smooth line drawing.
Args:
x1, y1: Start coordinates
x2, y2: End coordinates
num_points: Number of points to interpolate
Returns:
List of (x, y) coordinate tuples
"""
points = []
for i in range(num_points + 1):
t = i / num_points
x = x1 + t * (x2 - x1)
y = y1 + t * (y2 - y1)
points.append((x, y))
return points
|