# Source: font_classifier_v4 / font_classifier_with_preprocessing.py (author: dchen0)
# Commit ecb5b6d (verified): "Add model with built-in server-side preprocessing"
"""
Custom DINOv2 model that includes pad_to_square preprocessing in the forward pass.
This allows inference endpoints to automatically apply correct preprocessing.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Dinov2ForImageClassification
class FontClassifierWithPreprocessing(Dinov2ForImageClassification):
    """
    DINOv2 classifier that runs pad-to-square preprocessing inside forward().

    Every input is mapped to the same backbone input: bring pixels into
    [0, 1] -> zero-pad to a square -> bilinear resize to
    (target_size, target_size) -> ImageNet mean/std normalization.

    Accepted inputs (decided per sample, see ``preprocess_images``):
      * integer tensors (e.g. uint8) on the 0-255 scale;
      * float tensors on the 0-255 or 0-1 scale;
      * float tensors that were already ImageNet-normalized (recognized by
        negative values, which raw pixel data never contains). These are
        de-normalized first so they are not normalized twice.
    """

    def __init__(self, config):
        super().__init__(config)
        # ImageNet channel statistics, shaped (1, 3, 1, 1) to broadcast over
        # (B, C, H, W). Non-persistent: they follow .to()/.half() like the
        # weights but are NOT part of state_dict(), so a plain
        # Dinov2ForImageClassification state dict loads with strict=True.
        self.register_buffer(
            "image_mean",
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )
        # Side length (pixels) of the square image passed to the backbone.
        self.target_size = 224

    def pad_to_square_tensor(self, images):
        """
        Zero-pad images to a square canvas, keeping the content centered.

        Args:
            images: Tensor of shape (B, C, H, W) (any leading dims are allowed;
                only the last two are padded).

        Returns:
            Tensor of shape (..., S, S) with S = max(H, W). When the padding
            cannot be split evenly, the extra row/column goes to the
            bottom/right. The input itself is returned if it is already square.
        """
        height, width = images.shape[-2:]
        if height == width:
            return images  # Already square
        side = max(height, width)
        pad_h = side - height
        pad_w = side - width
        # F.pad lists the last dim first: (left, right, top, bottom).
        # value=0 is black in the [0, 1] pixel space this is called in.
        return F.pad(
            images,
            (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
            value=0,
        )

    def preprocess_images(self, pixel_values):
        """
        Map raw or already-processed images to the tensor DINOv2 expects.

        Each sample is classified by dtype and value range:
          * integer dtype (not bool): 0-255 scale, divided by 255;
          * float sample containing a negative value: assumed to be
            ImageNet-normalized already, and de-normalized back to [0, 1];
          * float sample whose max exceeds 1.0: 0-255 scale, divided by 255;
          * otherwise: already in [0, 1].
        Raw float input should therefore be non-negative (e.g. clamp the
        output of an unclamped bicubic resize before passing it in).

        Args:
            pixel_values: Tensor of shape (B, C, H, W) or (C, H, W).

        Returns:
            Tensor of shape (B, C, target_size, target_size), normalized with
            image_mean / image_std.

        Raises:
            ValueError: if pixel_values is not 3-D or 4-D.
        """
        # Ensure we have a batch dimension
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)
        if pixel_values.dim() != 4:
            raise ValueError(
                "pixel_values must have shape (B, C, H, W) or (C, H, W), "
                f"got {tuple(pixel_values.shape)}"
            )

        # Integer pixel data (e.g. decoded uint8 images) is always on the
        # 0-255 scale, regardless of its max. Bool keeps its 0/1 meaning.
        is_integer_input = (
            not pixel_values.is_floating_point() and pixel_values.dtype != torch.bool
        )
        if pixel_values.dtype != torch.float32:
            pixel_values = pixel_values.float()

        if is_integer_input:
            pixel_values = pixel_values / 255.0
        else:
            # Decide per sample, so one 0-255 image does not rescale the
            # other images in the batch.
            flat = pixel_values.flatten(1)
            already_normalized = (flat.amin(dim=1) < 0).view(-1, 1, 1, 1)
            on_255_scale = (flat.amax(dim=1) > 1.0).view(-1, 1, 1, 1)
            denormalized = pixel_values * self.image_std + self.image_mean
            rescaled = torch.where(on_255_scale, pixel_values / 255.0, pixel_values)
            pixel_values = torch.where(already_normalized, denormalized, rescaled)

        # Pad in [0, 1] space so padding looks identical for every input kind.
        pixel_values = self.pad_to_square_tensor(pixel_values)

        size = self.target_size
        if pixel_values.shape[-2:] != (size, size):
            # NOTE(review): F.interpolate does not antialias by default, while
            # PIL-based resizing does; confirm this matches training-time
            # preprocessing for large downscales.
            pixel_values = F.interpolate(
                pixel_values,
                size=(size, size),
                mode="bilinear",
                align_corners=False,
            )

        # Apply ImageNet normalization
        return (pixel_values - self.image_mean) / self.image_std

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """
        Preprocess pixel_values, then run the standard DINOv2 classifier.

        Args:
            pixel_values: Raw or preprocessed images; see preprocess_images.
            labels: Optional labels, passed through for loss computation.
            **kwargs: Passed through to Dinov2ForImageClassification.forward.

        Returns:
            Whatever Dinov2ForImageClassification.forward returns.

        Raises:
            ValueError: if pixel_values is None or has an unsupported rank.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")
        processed_pixel_values = self.preprocess_images(pixel_values)
        return super().forward(pixel_values=processed_pixel_values, labels=labels, **kwargs)
# Function to convert existing model
def convert_to_preprocessing_model(original_model_path, output_path):
    """
    Re-save a Dinov2ForImageClassification checkpoint as FontClassifierWithPreprocessing.

    All weights from the source model are copied unchanged. The source image
    processor config is saved next to the converted model.

    Args:
        original_model_path: Hub repo id or local directory of the source model.
        output_path: Directory to write the converted model and processor to.

    Returns:
        The converted FontClassifierWithPreprocessing instance, in eval mode.

    Raises:
        RuntimeError: if the source weights do not match the new model (any key
            other than the preprocessing buffers is missing or unexpected).
    """
    from transformers import AutoImageProcessor

    print(f"Converting {original_model_path} to include preprocessing...")
    # Load original model
    original_model = Dinov2ForImageClassification.from_pretrained(original_model_path)
    # Create new model with same config
    preprocessing_model = FontClassifierWithPreprocessing(original_model.config)

    # image_mean / image_std are constants created by the wrapper's __init__,
    # not weights of the source model, so they may be reported as missing.
    # Every other key must match exactly.
    preprocessing_buffers = {"image_mean", "image_std"}
    load_result = preprocessing_model.load_state_dict(
        original_model.state_dict(), strict=False
    )
    missing = sorted(set(load_result.missing_keys) - preprocessing_buffers)
    unexpected = sorted(load_result.unexpected_keys)
    if missing or unexpected:
        raise RuntimeError(
            f"Weight mismatch while converting {original_model_path}: "
            f"missing={missing}, unexpected={unexpected}"
        )
    preprocessing_model.eval()

    # NOTE(review): nothing here registers this class for AutoModel loading
    # (no register_for_auto_class / auto_map), so generic loaders may build a
    # plain Dinov2ForImageClassification and skip the preprocessing. Confirm
    # how the endpoint loads the model.
    preprocessing_model.save_pretrained(output_path, safe_serialization=True)

    # Copy processor config (unchanged)
    processor = AutoImageProcessor.from_pretrained(original_model_path)
    processor.save_pretrained(output_path)

    print(f"✅ Converted model saved to {output_path}")
    return preprocessing_model
if __name__ == "__main__":
    # Example run: turn the published checkpoint into a local model directory
    # whose forward() performs the pad-to-square preprocessing itself.
    source_checkpoint = "dchen0/font-classifier-v4"
    destination_dir = "./font-classifier-with-preprocessing"
    convert_to_preprocessing_model(source_checkpoint, destination_dir)