| | """ |
| | Custom DINOv2 model that includes pad_to_square preprocessing in the forward pass. |
| | This allows inference endpoints to automatically apply correct preprocessing. |
| | """ |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from transformers import Dinov2ForImageClassification |
| |
|
| |
|
class FontClassifierWithPreprocessing(Dinov2ForImageClassification):
    """
    DINOv2 classifier that applies pad_to_square preprocessing in forward().

    This model can be deployed to Inference Endpoints and will automatically
    handle preprocessing (scaling, padding, resizing, normalization) in the
    forward pass, so clients can send raw images.
    """

    def __init__(self, config):
        super().__init__(config)

        # ImageNet normalization constants, shaped (1, 3, 1, 1) to broadcast
        # over (B, C, H, W) batches.
        # persistent=False keeps these out of state_dict, so checkpoints stay
        # key-compatible with plain Dinov2ForImageClassification weights
        # (a strict load_state_dict would otherwise fail on these extra keys).
        self.register_buffer(
            'image_mean',
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            'image_std',
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )
        # Side length expected by the DINOv2 backbone.
        self.target_size = 224

    def pad_to_square_tensor(self, images):
        """
        Pad a batch of images to square, preserving aspect ratio.

        Padding is split as evenly as possible between the two sides of the
        shorter dimension and filled with zeros (black).

        Args:
            images: Tensor of shape (B, C, H, W)
        Returns:
            Tensor of shape (B, C, max(H, W), max(H, W))
        """
        B, C, H, W = images.shape
        max_size = max(H, W)

        # Already square — nothing to do.
        if H == W == max_size:
            return images

        pad_h = max_size - H
        pad_w = max_size - W
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        # F.pad ordering for 4D input: (left, right, top, bottom).
        padded = F.pad(images, (pad_left, pad_right, pad_top, pad_bottom), value=0)

        return padded

    def preprocess_images(self, pixel_values):
        """
        Apply the full preprocessing pipeline to raw or partially processed images.

        Steps: add batch dim if needed, cast to float32, rescale from [0, 255]
        to [0, 1] when values exceed 1, pad to square, resize to target_size,
        then apply ImageNet mean/std normalization.

        Args:
            pixel_values: Tensor of shape (B, C, H, W) or (C, H, W)
        Returns:
            Preprocessed float tensor of shape (B, C, target_size, target_size)
        """
        # Accept a single unbatched image.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)

        if pixel_values.dtype != torch.float32:
            pixel_values = pixel_values.float()

        # Heuristic: values above 1.0 indicate a raw [0, 255] image.
        # NOTE(review): already-normalized inputs (mean/std applied) could also
        # exceed 1.0 — callers should send raw or [0, 1] images.
        if pixel_values.max() > 1.0:
            pixel_values = pixel_values / 255.0

        pixel_values = self.pad_to_square_tensor(pixel_values)

        # Resize only when needed to avoid an unnecessary interpolation pass.
        if pixel_values.shape[-1] != self.target_size or pixel_values.shape[-2] != self.target_size:
            pixel_values = F.interpolate(
                pixel_values,
                size=(self.target_size, self.target_size),
                mode='bilinear',
                align_corners=False
            )

        # ImageNet normalization (broadcast over batch and spatial dims).
        pixel_values = (pixel_values - self.image_mean) / self.image_std

        return pixel_values

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """
        Forward pass with automatic preprocessing.

        Args:
            pixel_values: Raw or preprocessed images
            labels: Optional labels for training
        Raises:
            ValueError: if pixel_values is None.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        processed_pixel_values = self.preprocess_images(pixel_values)

        return super().forward(pixel_values=processed_pixel_values, labels=labels, **kwargs)
| |
|
| | |
def convert_to_preprocessing_model(original_model_path, output_path):
    """
    Convert an existing DINOv2 model to include preprocessing.

    Loads the base classifier weights into a FontClassifierWithPreprocessing
    wrapper, then saves the model (safetensors) and its image processor to
    output_path.

    Args:
        original_model_path: Path or hub id of the original model
        output_path: Directory to save the converted model
    Returns:
        The converted FontClassifierWithPreprocessing instance.
    Raises:
        RuntimeError: if the source checkpoint contains keys the wrapper
            does not recognize (likely a wrong architecture).
    """
    print(f"Converting {original_model_path} to include preprocessing...")

    original_model = Dinov2ForImageClassification.from_pretrained(original_model_path)

    preprocessing_model = FontClassifierWithPreprocessing(original_model.config)

    # strict=False: the wrapper may register extra preprocessing buffers
    # (image_mean / image_std) that the base checkpoint does not contain;
    # a strict load would raise on those missing keys. Unexpected keys in
    # the source weights still indicate a real mismatch, so fail loudly.
    missing, unexpected = preprocessing_model.load_state_dict(
        original_model.state_dict(), strict=False
    )
    if unexpected:
        raise RuntimeError(f"Unexpected keys in source checkpoint: {unexpected}")

    preprocessing_model.save_pretrained(output_path, safe_serialization=True)

    # Ship the original image processor config alongside the model so
    # clients that still preprocess externally keep working.
    from transformers import AutoImageProcessor
    processor = AutoImageProcessor.from_pretrained(original_model_path)
    processor.save_pretrained(output_path)

    print(f"✅ Converted model saved to {output_path}")

    return preprocessing_model
| |
|
if __name__ == "__main__":
    # Convert the published font classifier into the preprocessing-aware variant.
    source_repo = "dchen0/font-classifier-v4"
    destination = "./font-classifier-with-preprocessing"
    convert_to_preprocessing_model(source_repo, destination)