File size: 4,881 Bytes
56ef371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from vlm_fo1.model.multimodal_encoder.base_encoder import AbsVisionTower
from vlm_fo1.model.multimodal_encoder.davit.configuration_davit import DavitConfig
from vlm_fo1.model.multimodal_encoder.davit.configs import model_configs, img_cfg
from vlm_fo1.model.multimodal_encoder.davit.modeling_davit import DaViT
from vlm_fo1.model.multimodal_encoder.davit.image_processing_clip import CLIPImageProcessor

# Auxiliary DaViT-based vision tower for multi-modal encoder framework.
# This class manages configuration, processing, and dynamic instantiation of DaViT models.
class DavitVisionTower(AbsVisionTower):
    def __init__(self, vision_tower_name, args, delay_load=False, image_size=768, aspect_ratio='squash'):
        """
        Args:
            vision_tower_name: Identifier string for the model variant; its
                basename (minus '.pth') keys into ``model_configs``.
            args: Parent MM model/global config (currently ignored).
            delay_load: Kept for interface compatibility; currently ignored —
                the model and weights are always loaded eagerly.
            image_size: Target size to which images are resized (unless
                aspect_ratio == 'dynamic').
            aspect_ratio: Controls how input aspect ratio is handled
                ('squash', 'dynamic', etc.).
        """
        super().__init__()
        self.is_loaded = False
        self.vision_tower_name = vision_tower_name
        self.aspect_ratio = aspect_ratio
        self.image_size = image_size

        # In this implementation the training flag is ignored (always uses
        # pretrained weights) and loading is always eager, regardless of
        # delay_load.
        self.load_model(is_train=False, image_size=self.image_size,
                        aspect_ratio=self.aspect_ratio)

    def load_model(self, is_train=False, image_size=768, aspect_ratio='squash'):
        """
        Actually loads the DaViT model (with weights) and its image processor.

        Args:
            is_train: Accepted for interface compatibility; not used here.
            image_size: Image size written into the DaViT config.
            aspect_ratio: Resize mode passed to the image processor; 'dynamic'
                disables resizing entirely.
        """
        # Config is selected by the basename of vision_tower_name without
        # the '.pth' extension.
        cfg_dict = model_configs[self.vision_tower_name.split('/')[-1].replace('.pth', '')]
        vision_cfg = DavitConfig.from_dict(cfg_dict)
        vision_cfg.image_size = image_size
        self.image_tower = DaViT.from_config(config=vision_cfg, enable_checkpoint=True)
        # Attach the config so the `config` property can expose it after load.
        self.image_tower.config = vision_cfg
        img_cfg['resize_mode'] = aspect_ratio
        # 'dynamic' aspect ratio means no resizing: the original image size is
        # kept and the image_size parameter is not used by the processor.
        if aspect_ratio == 'dynamic':
            img_cfg['do_resize'] = False
        self.image_processor = CLIPImageProcessor(**img_cfg)

        self.is_loaded = True

    def forward(self, images):
        """
        Runs the auxiliary DaViT encoder.

        Args:
            images: Torch tensor, or list of tensors, of images to encode.
        Returns:
            List of image feature outputs — one entry per input image when a
            list is given, or a length-1 list for a single (batched) tensor.
            Each entry is the DaViT 4-stage output; the caller handles
            downstream selection/merging.
        """
        if isinstance(images, list):
            # Encode each image separately; each forward returns 4-stage output.
            return [
                self.image_tower.forward(img.to(device=self.device, dtype=self.dtype))
                for img in images
            ]
        # Single (batched) tensor: wrap in a length-1 list so the return type
        # is consistent with the list branch.
        return [self.image_tower.forward(images.to(device=self.device, dtype=self.dtype))]

    @property
    def dtype(self):
        # Expose main tensor dtype so callers can move inputs to the right dtype.
        return self.image_tower.convs[0].proj.weight.dtype

    @property
    def device(self):
        # Expose main parameter device so inputs and dependent modules match it.
        return self.image_tower.convs[0].proj.weight.device

    @property
    def config(self):
        # Return the live model config when loaded, else the config-only stub.
        # NOTE(review): with eager loading in __init__, `cfg_only` is never set
        # on this path; the fallback exists for a possible future delayed-load.
        if self.is_loaded:
            return self.image_tower.config
        else:
            return self.cfg_only

    @property
    def hidden_size(self):
        # Hidden size: sum of embedding dims across all multi-stage outputs.
        return sum(self.image_tower.embed_dims)