yermandy committed
Commit fb28b8b · verified · 1 Parent(s): fee6528

Upload folder using huggingface_hub

Files changed (1)
  1. modeling_gend.py +182 -0
modeling_gend.py ADDED
@@ -0,0 +1,182 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from PIL import Image
+ from transformers import PretrainedConfig, PreTrainedModel
+
+
+ class LinearProbe(nn.Module):
+     def __init__(self, input_dim, num_classes, normalize_inputs=False, detach_classifier_inputs=False):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, num_classes)
+         self.normalize_inputs = normalize_inputs
+         self.detach_classifier_inputs = detach_classifier_inputs
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         if self.detach_classifier_inputs:
+             # stop gradients from flowing back into the feature extractor
+             x = x.detach()
+         if self.normalize_inputs:
+             # L2-normalize features before the linear classifier
+             x = F.normalize(x, dim=-1)
+         return self.linear(x)
+
+
+ class CLIPEncoder(nn.Module):
+     def __init__(self, model_name="openai/clip-vit-large-patch14"):
+         super().__init__()
+
+         from transformers import CLIPModel, CLIPProcessor
+
+         try:
+             self._preprocess = CLIPProcessor.from_pretrained(model_name)
+         except Exception:
+             # fall back to the base-patch16 processor if the requested one is unavailable
+             self._preprocess = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+
+         clip: CLIPModel = CLIPModel.from_pretrained(model_name)
+
+         # take vision model from CLIP, maps image to vision_embed_dim
+         self.vision_model = clip.vision_model
+
+         self.model_name = model_name
+
+         self.features_dim = self.vision_model.config.hidden_size
+
+         # take visual_projection, maps vision_embed_dim to projection_dim
+         self.visual_projection = clip.visual_projection
+
+     def preprocess(self, image: Image.Image) -> torch.Tensor:
+         return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0]
+
+     def forward(self, preprocessed_images: torch.Tensor) -> torch.Tensor:
+         # pooled embedding of the vision tower (before visual_projection)
+         return self.vision_model(preprocessed_images).pooler_output
+
+     def get_features_dim(self):
+         return self.features_dim
+
+
+ class DINOEncoder(nn.Module):
+     def __init__(self, model_name="facebook/dinov2-with-registers-base"):
+         """
+         See models in src/config.py
+         """
+
+         super().__init__()
+
+         from transformers import AutoImageProcessor, AutoModel, Dinov2Model, Dinov2WithRegistersModel
+
+         self._preprocess = AutoImageProcessor.from_pretrained(model_name)
+         self.backbone: Dinov2Model | Dinov2WithRegistersModel = AutoModel.from_pretrained(model_name)
+
+         self.features_dim = self.backbone.config.hidden_size
+
+     def preprocess(self, image: Image.Image) -> torch.Tensor:
+         return self._preprocess(images=image, return_tensors="pt")["pixel_values"][0]
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         # use the CLS token embedding as the image feature
+         return self.backbone(inputs).last_hidden_state[:, 0]
+
+     def get_features_dim(self) -> int:
+         return self.features_dim
+
+
+ class PerceptionEncoder(nn.Module):
+     def __init__(
+         self,
+         model_name="vit_pe_core_large_patch14_336",
+         img_size: None | int = None,
+     ):
+         super().__init__()
+
+         # enable dynamic image size only when a custom size is requested
+         dynamic_img_size = img_size is not None
+
+         import timm
+         from timm.models.eva import Eva
+
+         self.backbone: Eva = timm.create_model(
+             model_name,
+             pretrained=True,
+             dynamic_img_size=dynamic_img_size,
+         )
+
+         # Get model specific transforms (normalization, resize)
+         data_config = timm.data.resolve_model_data_config(self.backbone)
+
+         if img_size is not None:
+             data_config["input_size"] = (3, img_size, img_size)
+
+         self._preprocess = timm.data.create_transform(**data_config, is_training=False)
+
+         # Remove head
+         self.backbone.head = nn.Identity()
+
+         self.features_dim = self.backbone.num_features
+
+     def preprocess(self, image: Image.Image) -> torch.Tensor:
+         return self._preprocess(image)
+
+     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+         return self.backbone(inputs)
+
+     def get_features_dim(self) -> int:
+         return self.features_dim
+
+
+ class GenDConfig(PretrainedConfig):
+     model_type = "GenD"
+
+     def __init__(self, backbone: str = "openai/clip-vit-large-patch14", head: str = "linear", **kwargs):
+         super().__init__(**kwargs)
+         self.backbone = backbone
+         self.head = head
+
+
+ class GenD(PreTrainedModel):
+     config_class = GenDConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.head = config.head
+         self.backbone = config.backbone
+         self.config = config
+
+         self._init_feature_extractor()
+         self._init_head()
+
+     def _init_feature_extractor(self):
+         backbone = self.backbone
+         backbone_lowercase = backbone.lower()
+
+         if "clip" in backbone_lowercase:
+             self.feature_extractor = CLIPEncoder(backbone)
+
+         elif "vit_pe" in backbone_lowercase:
+             # encoder from the training repository layout (requires the `src` package)
+             from src.encoders.perception_encoder import PerceptionEncoder
+
+             self.feature_extractor = PerceptionEncoder(backbone, self.config.backbone_args.img_size)
+
+         elif "dino" in backbone_lowercase:
+             # encoder from the training repository layout (requires the `src` package)
+             from src.encoders.dino_encoder import DINOEncoder
+
+             if self.config.backbone_args is not None:
+                 merge_cls_token_with_patches = self.config.backbone_args.merge_cls_token_with_patches
+             else:
+                 merge_cls_token_with_patches = None
+
+             self.feature_extractor = DINOEncoder(backbone, merge_cls_token_with_patches)
+
+         else:
+             raise ValueError(f"Unknown backbone: {backbone}")
+
+     def _init_head(self):
+         features_dim = self.feature_extractor.get_features_dim()
+
+         match self.head:
+             case "linear":
+                 self.model = LinearProbe(features_dim, 2)
+
+             case "LinearNorm":
+                 self.model = LinearProbe(features_dim, 2, True)
+
+             case _:
+                 raise ValueError(f"Unknown head: {self.head}")
+
+     def forward(self, inputs: torch.Tensor):
+         # extract image features with the backbone, then classify them with the head
+         features = self.feature_extractor(inputs)
+         outputs = self.model(features)
+         return outputs
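
For orientation, a minimal usage sketch of the classes added in this file: it builds the default CLIP backbone with a linear head and runs one image through preprocessing and the forward pass. The image path is a placeholder, and the head weights are randomly initialized here; in practice the trained checkpoint would be loaded with GenD.from_pretrained (the repository id is not part of this commit).

    # minimal sketch, assuming modeling_gend.py is importable and CLIP weights can be downloaded
    import torch
    from PIL import Image
    from modeling_gend import GenD, GenDConfig

    config = GenDConfig(backbone="openai/clip-vit-large-patch14", head="linear")
    model = GenD(config).eval()

    image = Image.open("example.jpg")                                 # placeholder input image
    pixels = model.feature_extractor.preprocess(image).unsqueeze(0)   # (1, 3, H, W)

    with torch.no_grad():
        logits = model(pixels)            # (1, 2) class logits
        probs = logits.softmax(dim=-1)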