"""
Custom DINOv2 model that includes pad_to_square preprocessing in the forward pass.
This allows inference endpoints to automatically apply correct preprocessing.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Dinov2ForImageClassification
class FontClassifierWithPreprocessing(Dinov2ForImageClassification):
    """
    DINOv2 classifier that applies pad-to-square preprocessing in forward().

    This model can be deployed to Inference Endpoints and will handle
    preprocessing (padding, resizing, rescaling, ImageNet normalization)
    itself, so clients can send raw image tensors.
    """

    def __init__(self, config):
        super().__init__(config)
        # ImageNet mean/std shaped (1, 3, 1, 1) for broadcasting over (B, C, H, W).
        # persistent=False keeps these constants out of the state dict, so
        # weights exported from a plain Dinov2ForImageClassification load
        # cleanly with strict=True (no spurious "missing key" errors).
        self.register_buffer(
            'image_mean',
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            'image_std',
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )
        # Side length the backbone was fine-tuned at.
        self.target_size = 224

    def pad_to_square_tensor(self, images):
        """
        Zero-pad a batch of images to square, centering the original content.

        Args:
            images: Tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, max(H, W), max(H, W)).
        """
        B, C, H, W = images.shape
        max_size = max(H, W)
        if H == W == max_size:
            return images  # Already square
        # Split each axis's padding as evenly as possible (extra pixel on the
        # bottom/right when the difference is odd).
        pad_h = max_size - H
        pad_w = max_size - W
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left
        # F.pad order for 4D input is (left, right, top, bottom).
        return F.pad(images, (pad_left, pad_right, pad_top, pad_bottom), value=0)

    def preprocess_images(self, pixel_values):
        """
        Apply the full preprocessing pipeline to raw or partially processed images.

        Args:
            pixel_values: Tensor of shape (B, C, H, W) or (C, H, W); integer
                dtypes are assumed to be in [0, 255], floats in [0, 1] unless
                their max exceeds 1.

        Returns:
            Normalized float32 tensor of shape (B, C, target_size, target_size)
            ready for DINOv2.
        """
        # Ensure a batch dimension.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)
        # Record integer-ness BEFORE casting: a near-black uint8 image whose
        # max is <= 1 must still be rescaled, which the max() heuristic alone
        # would miss.
        was_integer = not pixel_values.is_floating_point()
        if pixel_values.dtype != torch.float32:
            pixel_values = pixel_values.float()
        # Rescale [0, 255] inputs to [0, 1]. numel() guard avoids max() raising
        # on an empty batch.
        if was_integer or (pixel_values.numel() > 0 and pixel_values.max() > 1.0):
            pixel_values = pixel_values / 255.0
        # Pad to square, preserving aspect ratio.
        pixel_values = self.pad_to_square_tensor(pixel_values)
        # Resize to the model's expected resolution.
        if pixel_values.shape[-2:] != (self.target_size, self.target_size):
            pixel_values = F.interpolate(
                pixel_values,
                size=(self.target_size, self.target_size),
                mode='bilinear',
                align_corners=False,
            )
        # ImageNet normalization (padding zeros become negative here; that
        # matches what a client-side processor would produce for black bars).
        return (pixel_values - self.image_mean) / self.image_std

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """
        Forward pass with automatic preprocessing.

        Args:
            pixel_values: Raw or preprocessed images.
            labels: Optional labels for training.

        Raises:
            ValueError: If pixel_values is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        # Preprocess, then defer to the stock DINOv2 classification head.
        processed_pixel_values = self.preprocess_images(pixel_values)
        return super().forward(pixel_values=processed_pixel_values, labels=labels, **kwargs)
# Function to convert existing model
def convert_to_preprocessing_model(original_model_path, output_path):
    """
    Convert an existing DINOv2 model to include preprocessing.

    Args:
        original_model_path: Hub repo id or local path of the original model.
        output_path: Directory where the converted model is saved.

    Returns:
        The converted FontClassifierWithPreprocessing instance.

    Raises:
        RuntimeError: If the checkpoint contains weights the wrapper model
            does not know about (i.e. the copy would silently drop weights).
    """
    print(f"Converting {original_model_path} to include preprocessing...")
    # Load original model
    original_model = Dinov2ForImageClassification.from_pretrained(original_model_path)
    # Create new model with same config
    preprocessing_model = FontClassifierWithPreprocessing(original_model.config)
    # strict=False: the wrapper may register extra normalization buffers
    # (image_mean / image_std) absent from the plain checkpoint; a strict
    # load would raise on those missing keys. We still verify that every
    # checkpoint weight was consumed.
    result = preprocessing_model.load_state_dict(original_model.state_dict(), strict=False)
    if result.unexpected_keys:
        raise RuntimeError(
            f"Checkpoint weights not accepted by wrapper model: {result.unexpected_keys}"
        )
    # Save the new model
    preprocessing_model.save_pretrained(output_path, safe_serialization=True)
    # Copy processor config unchanged, so clients that still preprocess
    # client-side keep working.
    from transformers import AutoImageProcessor
    processor = AutoImageProcessor.from_pretrained(original_model_path)
    processor.save_pretrained(output_path)
    print(f"✅ Converted model saved to {output_path}")
    return preprocessing_model
if __name__ == "__main__":
    # Demo: wrap the published font classifier with built-in preprocessing.
    source_repo = "dchen0/font-classifier-v4"
    destination = "./font-classifier-with-preprocessing"
    convert_to_preprocessing_model(source_repo, destination)