# GOTOCR2-4bit-BNB / image_encoder.py
# Provenance: Hugging Face upload "Upload FastViTImageEncoder"
# (commit c88b91b, verified; uploader: Thatphum)
"""
HF-compatible wrapper that turns the FastViT backbone into a pure *image encoder*.
Output: a single (B, embed_dim) vector obtained with the built-in GlobalPool2D head.
"""
import torch
from transformers import PreTrainedModel, PretrainedConfig
from .mci import fastvithd, GlobalPool2D # imports your backbone factory
# ----------------------- Config -----------------------
class FastViTImageConfig(PretrainedConfig):
    """Configuration for the FastViT image encoder.

    Holds just enough metadata (input resolution, embedding width,
    patch size) for Hugging Face to serialize and restore the model.
    """

    model_type = "fastvit_image_encoder"

    def __init__(
        self,
        image_size: int = 1024,
        embed_dim: int = 3072,  # channel count after conv_exp
        patch_size: int = 16,
        **kwargs,
    ):
        # Record the vision-tower geometry first, then hand the remaining
        # kwargs to the base PretrainedConfig initializer.
        self.image_size = image_size
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        super().__init__(**kwargs)
# ----------------------- Model ------------------------
class FastViTImageEncoder(PreTrainedModel):
    """
    Wraps FastViT-HD as a standalone image encoder; no text tower, no CLIP
    logits.

    ``forward`` delegates to :meth:`forward_images` and returns per-patch
    token features of shape ``(B, H*W, C)`` (or a list of such tensors when
    given a list of images).  A pooled ``(B, embed_dim)`` embedding is
    available via :meth:`forward_pooled`.
    """

    config_class = FastViTImageConfig
    main_input_name = "pixel_values"

    def __init__(self, config: FastViTImageConfig):
        super().__init__(config)
        # Build the backbone without a classifier (num_classes=0), then
        # attach a GlobalPool2D head explicitly.
        # NOTE(review): in_dim/out_dim are hard-coded; in_dim matches the
        # config default embed_dim (3072), but out_dim=768 has no config
        # field — confirm against the pretrained checkpoint before wiring
        # these to `config`.
        self.backbone = fastvithd(num_classes=0)
        self.backbone.head = GlobalPool2D(in_dim=3072, out_dim=768)
        # HF helper: weight init, tying, dtype handling, etc.
        self.post_init()

    # ------------------------------------------
    def forward(self, images):
        """Encode ``images`` into per-patch token features.

        BUGFIX: the original file defined ``forward`` twice; Python keeps
        only the last definition, so the earlier pooled-embedding variant
        was unreachable dead code.  Runtime behavior is unchanged here —
        ``forward`` delegates to :meth:`forward_images` exactly as before;
        the pooled path is preserved as :meth:`forward_pooled`.
        """
        return self.forward_images(images)

    def forward_pooled(self, pixel_values, return_dict=True, **unused):
        """
        Pooled embedding via the backbone's GlobalPool2D head.

        Args:
            pixel_values: (B, 3, H, W) tensor (already resized/normalized).
            return_dict: when False, return a one-element tuple instead of
                a dict.
        Returns:
            Dict with a single key ``"embeddings"`` of shape (B, embed_dim)
            (or ``(embeddings,)`` when ``return_dict`` is False).
        """
        # The backbone's default forward ends in the GlobalPool2D head, so
        # it yields the pooled tensor directly.
        embeddings = self.backbone(pixel_values)  # (B, embed_dim)
        if not return_dict:
            return (embeddings,)
        return {"embeddings": embeddings}

    def feature_select(self, image_forward_outs):
        """Flatten the backbone's (B, C, H, W) feature map to (B, H*W, C)."""
        image_features = image_forward_outs["image_embeddings"]
        B, C, H, W = image_features.shape
        image_features = image_features.reshape(B, C, H * W)
        image_features = image_features.transpose(1, 2)
        return image_features

    def forward_images(self, images):
        """
        Run the backbone and return token features.

        Args:
            images: a (B, 3, H, W) tensor, or a list of (3, H, W) tensors
                (each encoded separately with an added batch dimension).
        Returns:
            (B, H*W, C) tensor, or a list of (1, H*W, C) tensors when the
            input is a list.  Features are cast back to the input's dtype.
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                batch = image.to(device=self.device, dtype=self.dtype).unsqueeze(0)
                out = self.backbone(batch, return_image_embeddings=True)
                image_features.append(self.feature_select(out).to(image.dtype))
        else:
            out = self.backbone(
                images.to(device=self.device, dtype=self.dtype),
                return_image_embeddings=True,
            )
            image_features = self.feature_select(out).to(images.dtype)
        return image_features