|
|
""" |
|
|
Custom DINOv2 model that includes pad_to_square preprocessing in the forward pass. |
|
|
This allows inference endpoints to automatically apply correct preprocessing. |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from transformers import Dinov2ForImageClassification |
|
|
|
|
|
|
|
|
class FontClassifierWithPreprocessing(Dinov2ForImageClassification):
    """
    DINOv2 classifier that applies pad-to-square preprocessing in ``forward``.

    This model can be deployed to Inference Endpoints and will automatically
    handle preprocessing in the forward pass, so clients can send raw images
    (uint8 0-255 or float 0-1, any H x W) instead of pre-normalized tensors.
    """

    def __init__(self, config):
        super().__init__(config)

        # ImageNet normalization constants, shaped (1, C, 1, 1) so they
        # broadcast over a (B, C, H, W) batch.
        #
        # Registered as NON-persistent buffers: they follow .to(device) /
        # .half() like parameters do, but are NOT written to the state_dict.
        # This keeps checkpoints byte-compatible with plain
        # Dinov2ForImageClassification weights, so state dicts can be loaded
        # in either direction without missing/unexpected-key errors.
        self.register_buffer(
            'image_mean',
            torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            'image_std',
            torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1),
            persistent=False,
        )
        # Side length images are resized to before entering the backbone.
        self.target_size = 224

    def pad_to_square_tensor(self, images):
        """
        Pad a batch of images to square, preserving aspect ratio.

        Args:
            images: Tensor of shape (B, C, H, W).

        Returns:
            Tensor of shape (B, C, S, S) with S = max(H, W). Zero (black)
            padding is split as evenly as possible between the two sides of
            the shorter dimension; already-square batches are returned as-is.
        """
        _, _, H, W = images.shape
        if H == W:
            # Already square -- nothing to pad.
            return images

        max_size = max(H, W)
        pad_h = max_size - H
        pad_w = max_size - W
        pad_top = pad_h // 2
        pad_bottom = pad_h - pad_top
        pad_left = pad_w // 2
        pad_right = pad_w - pad_left

        # F.pad's tuple is (left, right, top, bottom) for the last two dims.
        return F.pad(images, (pad_left, pad_right, pad_top, pad_bottom), value=0)

    def preprocess_images(self, pixel_values):
        """
        Apply the full preprocessing pipeline to raw or partially processed images.

        Pipeline: add batch dim if missing -> cast to float32 -> rescale
        0-255 inputs to 0-1 -> pad to square -> resize to
        (target_size, target_size) -> ImageNet mean/std normalization.

        Args:
            pixel_values: Tensor of shape (B, C, H, W) or (C, H, W).

        Returns:
            Float32 tensor of shape (B, C, target_size, target_size),
            normalized and ready for the DINOv2 backbone.
        """
        # Accept a single unbatched image for convenience.
        if pixel_values.dim() == 3:
            pixel_values = pixel_values.unsqueeze(0)

        if pixel_values.dtype != torch.float32:
            pixel_values = pixel_values.float()

        # Heuristic: values above 1.0 are assumed to be on the 0-255 scale.
        # NOTE(review): an all-dark uint8 image whose max is <= 1 would skip
        # this rescale -- acceptable for this use case, but worth knowing.
        if pixel_values.max() > 1.0:
            pixel_values = pixel_values / 255.0

        # Pad BEFORE resizing so aspect ratio is preserved; padding is black
        # in 0-1 space (normalization below maps it to -mean/std).
        pixel_values = self.pad_to_square_tensor(pixel_values)

        if pixel_values.shape[-1] != self.target_size or pixel_values.shape[-2] != self.target_size:
            pixel_values = F.interpolate(
                pixel_values,
                size=(self.target_size, self.target_size),
                mode='bilinear',
                align_corners=False,
            )

        # Standard ImageNet channel-wise normalization.
        pixel_values = (pixel_values - self.image_mean) / self.image_std

        return pixel_values

    def forward(self, pixel_values=None, labels=None, **kwargs):
        """
        Forward pass with automatic preprocessing.

        Args:
            pixel_values: Raw or preprocessed images, (B, C, H, W) or (C, H, W).
            labels: Optional labels for training (passed through to the base
                model, which computes the loss).

        Raises:
            ValueError: If ``pixel_values`` is not provided.
        """
        if pixel_values is None:
            raise ValueError("pixel_values must be provided")

        processed_pixel_values = self.preprocess_images(pixel_values)

        return super().forward(pixel_values=processed_pixel_values, labels=labels, **kwargs)
|
|
|
|
|
|
|
|
def convert_to_preprocessing_model(original_model_path, output_path):
    """
    Convert an existing DINOv2 checkpoint to include preprocessing.

    Loads the plain classifier, transfers its weights into a
    FontClassifierWithPreprocessing, and saves the result (plus the original
    image processor config) to ``output_path``.

    Args:
        original_model_path: Path or hub id of the original model.
        output_path: Directory to save the converted model into.

    Returns:
        The converted FontClassifierWithPreprocessing instance.

    Raises:
        RuntimeError: If checkpoint weights fail to transfer cleanly.
    """
    print(f"Converting {original_model_path} to include preprocessing...")

    original_model = Dinov2ForImageClassification.from_pretrained(original_model_path)

    preprocessing_model = FontClassifierWithPreprocessing(original_model.config)

    # strict=False: the wrapper may register extra normalization buffers
    # (image_mean / image_std) that a plain checkpoint does not contain, so
    # missing keys are expected here. Unexpected keys, however, would mean
    # real weights failed to map -- treat that as a hard error.
    load_result = preprocessing_model.load_state_dict(
        original_model.state_dict(), strict=False
    )
    if load_result.unexpected_keys:
        raise RuntimeError(
            f"Unexpected keys while transferring weights: {load_result.unexpected_keys}"
        )

    preprocessing_model.save_pretrained(output_path, safe_serialization=True)

    # Ship the image processor config alongside the weights so endpoints
    # (and local loaders) can still instantiate an AutoImageProcessor.
    from transformers import AutoImageProcessor
    processor = AutoImageProcessor.from_pretrained(original_model_path)
    processor.save_pretrained(output_path)

    print(f"✅ Converted model saved to {output_path}")

    return preprocessing_model
|
|
|
|
|
if __name__ == "__main__":
    # Convert the published v4 font classifier into a self-preprocessing
    # model saved next to this script.
    source_repo = "dchen0/font-classifier-v4"
    destination = "./font-classifier-with-preprocessing"
    convert_to_preprocessing_model(source_repo, destination)