# NOTE: non-code page residue (Spaces header, file size, commit hash, and a
# line-number gutter) was captured with this file and has been removed so the
# module parses as valid Python.
from vlm_fo1.model.multimodal_encoder.base_encoder import AbsVisionTower
from vlm_fo1.model.multimodal_encoder.davit.configuration_davit import DavitConfig
from vlm_fo1.model.multimodal_encoder.davit.configs import model_configs, img_cfg
from vlm_fo1.model.multimodal_encoder.davit.modeling_davit import DaViT
from vlm_fo1.model.multimodal_encoder.davit.image_processing_clip import CLIPImageProcessor
class DavitVisionTower(AbsVisionTower):
    """Auxiliary DaViT-based vision tower for the multi-modal encoder framework.

    Manages configuration resolution, image preprocessing, and (optionally
    lazy) instantiation of a DaViT backbone.
    """

    def __init__(self, vision_tower_name, args, delay_load=False, image_size=768, aspect_ratio='squash'):
        """
        Args:
            vision_tower_name: Identifier string for the model variant (usually
                a file name or config-section key; a trailing '.pth' and any
                path prefix are stripped before the `model_configs` lookup).
            args: Parent MM model/global config (currently unused).
            delay_load: If True, only the configuration is prepared (exposed
                via ``self.config``); weights/modules are built later by an
                explicit ``load_model`` call.
            image_size: Target size images are resized to (ignored by the
                processor when aspect_ratio == 'dynamic').
            aspect_ratio: How the input aspect ratio is handled ('squash',
                'dynamic', ...).
        """
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower_name
        self.aspect_ratio = aspect_ratio
        self.image_size = image_size
        # Training flag is currently fixed: the tower always starts from
        # pretrained weights.
        is_train = False
        if not delay_load:
            self.load_model(is_train, self.image_size, self.aspect_ratio)
        else:
            # Lazy path: only resolve the configuration (no weights/modules),
            # so the `config` property works before `load_model` is called.
            self.cfg_only = self._build_config(self.image_size)

    def _build_config(self, image_size):
        # Resolve the DaViT config for this tower variant and pin the image size.
        cfg_key = self.vision_tower_name.split('/')[-1].replace('.pth', '')
        vision_cfg = DavitConfig.from_dict(model_configs[cfg_key])
        vision_cfg.image_size = image_size
        return vision_cfg

    def load_model(self, is_train=False, image_size=768, aspect_ratio='squash'):
        """
        Actually load the DaViT model (with weights) and its image processor,
        setting up resizing/aspect handling as needed.
        """
        vision_cfg = self._build_config(image_size)
        self.image_tower = DaViT.from_config(config=vision_cfg, enable_checkpoint=True)
        self.image_tower.config = vision_cfg
        # Copy the shared template so repeated loads / multiple towers do not
        # mutate the module-level `img_cfg` dict in place.
        proc_cfg = dict(img_cfg)
        proc_cfg['resize_mode'] = aspect_ratio
        if aspect_ratio == 'dynamic':
            # Dynamic aspect ratio: keep the original image size and disable
            # resizing; the `image_size` parameter is not used by the processor.
            proc_cfg['do_resize'] = False
        self.image_processor = CLIPImageProcessor(**proc_cfg)
        self.is_loaded = True

    def forward(self, images):
        """
        Run the auxiliary DaViT encoder.

        Args:
            images: List of image tensors to encode (a bare batched tensor is
                not supported yet).

        Returns:
            List of per-image feature outputs (4-stage outputs each; the
            caller handles downstream stage selection/merging).

        Raises:
            NotImplementedError: If ``images`` is not a list.
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                # Each forward returns 4-stage outputs; inputs are moved to the
                # tower's device/dtype first.
                image_features.append(
                    self.image_tower.forward(image.to(device=self.device, dtype=self.dtype)))
            return image_features
        # Single batched tensor input is deliberately unsupported for now.
        raise NotImplementedError

    @property
    def dtype(self):
        # dtype of the first conv projection weight — callers use it to cast
        # inputs to match the tower.
        return self.image_tower.convs[0].proj.weight.dtype

    @property
    def device(self):
        # Device of the first conv projection weight — callers use it to place
        # inputs on the tower's device.
        return self.image_tower.convs[0].proj.weight.device

    @property
    def config(self):
        # Full model config when loaded; otherwise the config-only stub
        # prepared by a delayed load.
        if self.is_loaded:
            return self.image_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        # Hidden size: sum of embedding dims across all multi-stage outputs.
        return sum(self.image_tower.embed_dims)