""" Custom DINOv2 model that includes pad_to_square preprocessing in the forward pass. This allows inference endpoints to automatically apply correct preprocessing. """ import torch import torch.nn as nn import torch.nn.functional as F from transformers import Dinov2ForImageClassification class FontClassifierWithPreprocessing(Dinov2ForImageClassification): """ DINOv2 model that automatically applies pad_to_square preprocessing. This model can be deployed to Inference Endpoints and will automatically handle preprocessing in the forward pass, so clients can send raw images. """ def __init__(self, config): super().__init__(config) # Store preprocessing parameters self.register_buffer('image_mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) self.register_buffer('image_std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) self.target_size = 224 def pad_to_square_tensor(self, images): """ Pad batch of images to square preserving aspect ratio. Args: images: Tensor of shape (B, C, H, W) Returns: Tensor of shape (B, C, max_size, max_size) """ B, C, H, W = images.shape max_size = max(H, W) if H == W == max_size: return images # Already square # Calculate padding pad_h = max_size - H pad_w = max_size - W pad_top = pad_h // 2 pad_bottom = pad_h - pad_top pad_left = pad_w // 2 pad_right = pad_w - pad_left # Apply padding (left, right, top, bottom) padded = F.pad(images, (pad_left, pad_right, pad_top, pad_bottom), value=0) return padded def preprocess_images(self, pixel_values): """ Apply full preprocessing pipeline to raw or partially processed images. Args: pixel_values: Tensor of shape (B, C, H, W) Returns: Preprocessed tensor ready for DINOv2 """ # Ensure we have a batch dimension if pixel_values.dim() == 3: pixel_values = pixel_values.unsqueeze(0) # Convert to float if needed if pixel_values.dtype != torch.float32: pixel_values = pixel_values.float() # Normalize to [0, 1] if values are in [0, 255] if pixel_values.max() > 1.0: pixel_values = pixel_values / 255.0 # Apply pad_to_square pixel_values = self.pad_to_square_tensor(pixel_values) # Resize to target size if pixel_values.shape[-1] != self.target_size or pixel_values.shape[-2] != self.target_size: pixel_values = F.interpolate( pixel_values, size=(self.target_size, self.target_size), mode='bilinear', align_corners=False ) # Apply ImageNet normalization pixel_values = (pixel_values - self.image_mean) / self.image_std return pixel_values def forward(self, pixel_values=None, labels=None, **kwargs): """ Forward pass with automatic preprocessing. Args: pixel_values: Raw or preprocessed images labels: Optional labels for training """ if pixel_values is None: raise ValueError("pixel_values must be provided") # Apply preprocessing automatically processed_pixel_values = self.preprocess_images(pixel_values) # Call parent forward with preprocessed images return super().forward(pixel_values=processed_pixel_values, labels=labels, **kwargs) # Function to convert existing model def convert_to_preprocessing_model(original_model_path, output_path): """ Convert an existing DINOv2 model to include preprocessing. Args: original_model_path: Path to original model output_path: Path to save converted model """ print(f"Converting {original_model_path} to include preprocessing...") # Load original model original_model = Dinov2ForImageClassification.from_pretrained(original_model_path) # Create new model with same config preprocessing_model = FontClassifierWithPreprocessing(original_model.config) # Copy all weights preprocessing_model.load_state_dict(original_model.state_dict()) # Save the new model preprocessing_model.save_pretrained(output_path, safe_serialization=True) # Copy processor config (unchanged) from transformers import AutoImageProcessor processor = AutoImageProcessor.from_pretrained(original_model_path) processor.save_pretrained(output_path) print(f"✅ Converted model saved to {output_path}") return preprocessing_model if __name__ == "__main__": # Example: Convert existing model convert_to_preprocessing_model( "dchen0/font-classifier-v4", "./font-classifier-with-preprocessing" )