Leonardo6 commited on
Commit
44a7352
·
verified ·
1 Parent(s): 2dd0287

Upload connectors.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. connectors.py +169 -0
connectors.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, override
4
+
5
+ import torch.nn as nn
6
+ from torch import Tensor
7
+
8
+
9
+ class Connector(nn.Module, ABC):
10
+ """
11
+ Abstract base class for all connector modules.
12
+ Connectors are responsible for projecting visual features to a space
13
+ compatible with text features.
14
+ """
15
+
16
+ def __init__(
17
+ self,
18
+ config: Any,
19
+ image_hidden_size: int,
20
+ text_hidden_size: int,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.config: Any = config
24
+ self.name: str = self.config.name
25
+ self.image_hidden_size: int = image_hidden_size
26
+ self.text_hidden_size: int = text_hidden_size
27
+ self.projection_layer: nn.Module = self._build_projection_layer()
28
+
29
+ @abstractmethod
30
+ def _build_projection_layer(self) -> nn.Module:
31
+ pass
32
+
33
+ @override
34
+ def forward(self, visual_features: Tensor) -> Tensor:
35
+ return self.projection_layer(visual_features)
36
+
37
+
38
class IdentityConnector(Connector):
    """
    Pass-through connector: visual features are returned unchanged.

    Only valid when the visual and text hidden sizes already match,
    since no projection is performed.
    """

    def __init__(
        self,
        config: Any,
        image_hidden_size: int,
        text_hidden_size: int,
    ) -> None:
        # Validate before super().__init__, which builds the projection layer.
        # The original message ("Features will pass through unchanged.") read
        # like a warning even though construction aborts; state the actual
        # requirement instead.
        if image_hidden_size != text_hidden_size:
            raise ValueError(
                f"IdentityConnector requires image_hidden_size ({image_hidden_size}) "
                f"== text_hidden_size ({text_hidden_size}); it performs no projection. "
                f"Use a projecting connector (e.g. 'linear' or 'mlp') instead."
            )
        super().__init__(config, image_hidden_size, text_hidden_size)

    @override
    def _build_projection_layer(self) -> nn.Module:
        """Return an identity module (no-op projection)."""
        return nn.Identity()
55
+
56
+
57
class LinearConnector(Connector):
    """
    Single linear projection from the visual hidden size to the text
    hidden size.

    No custom ``__init__`` is needed: the base class constructor already
    has the identical signature, so the redundant override was removed.
    """

    @override
    def _build_projection_layer(self) -> nn.Module:
        """Return a Linear layer mapping image_hidden_size -> text_hidden_size."""
        return nn.Linear(
            self.image_hidden_size,
            self.text_hidden_size,
        )
72
+
73
+
74
class MLPConnector(Connector):
    """
    Multi-layer perceptron connector.

    Depth and activation are parsed from the connector's config name,
    which must match the pattern ``mlp_NUMLAYERS_ACTIVATION``
    (e.g. ``mlp_2_gelu``, ``mlp_3_relu``).
    """

    ACTIVATION_MAP: dict[str, type[nn.Module]] = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,  # Swish/SiLU
        "tanh": nn.Tanh,
        "sigmoid": nn.Sigmoid,
    }

    @override
    def __init__(
        self,
        config: Any,
        image_hidden_size: int,
        text_hidden_size: int,
    ) -> None:
        # Defaults; overwritten by _parse_config_name when the name is valid.
        self.num_layers: int = 2
        self.activation_name: str = "gelu"

        # Parse num_layers and activation_name from the connector's name string.
        # Must run before super().__init__, which calls _build_projection_layer().
        self._parse_config_name(config.name)

        super().__init__(config, image_hidden_size, text_hidden_size)

    def _parse_config_name(self, name: str) -> None:
        """
        Set ``self.num_layers`` and ``self.activation_name`` from ``name``.

        Raises:
            ValueError: if ``name`` does not match ``mlp_NUMLAYERS_ACTIVATION``
                or names an activation not in ``ACTIVATION_MAP``.
        """
        match = re.match(r"mlp_(\d+)_(\w+)", name)  # e.g., mlp_2_gelu, mlp_3_relu
        if not match:
            # The original message claimed defaults would be used, but the
            # method raises; the message now matches the behavior.
            raise ValueError(
                f"MLPConnector name '{name}' does not match pattern "
                f"'mlp_NUMLAYERS_ACTIVATION' (e.g. 'mlp_2_gelu')."
            )

        # The regex group is \d+, so int() cannot raise here; the original
        # try/except ValueError around this line was dead code that also
        # swallowed the unrelated activation-validation error below.
        self.num_layers = int(match.group(1))

        activation_name = match.group(2).lower()
        if activation_name not in self.ACTIVATION_MAP:
            # The original f-string referenced the nonexistent class attribute
            # MLPConnector.activation_name, raising AttributeError instead of
            # this intended ValueError.
            raise ValueError(
                f"MLPConnector: Activation '{activation_name}' from name '{name}' "
                f"is not recognized. "
                f"Supported: {list(self.ACTIVATION_MAP.keys())}"
            )
        self.activation_name = activation_name

    @override
    def _build_projection_layer(self) -> nn.Module:
        """
        Build the MLP: Linear layers with activations between them
        (no activation after the final layer).

        Raises:
            ValueError: if ``num_layers`` < 1 or ``activation_name`` is
                unsupported (safeguard; normally validated during parsing).
        """
        if self.num_layers < 1:
            raise ValueError(
                f"MLPConnector: Number of layers must be at least 1, got {self.num_layers}"
            )

        activation_class = self.ACTIVATION_MAP.get(self.activation_name)
        if activation_class is None:
            # Safeguard: _parse_config_name validates the activation, so this
            # should be unreachable. The original appended "Defaulting to GELU."
            # and had an unreachable fallback assignment after the raise.
            raise ValueError(
                f"MLPConnector: Unsupported activation function '{self.activation_name}'. "
                f"Supported activations: {list(self.ACTIVATION_MAP.keys())}."
            )

        layers: list[nn.Module] = []
        for i in range(self.num_layers):
            # First layer maps image_hidden_size -> text_hidden_size;
            # all subsequent layers stay at text_hidden_size.
            input_dim = self.image_hidden_size if i == 0 else self.text_hidden_size
            layers.append(nn.Linear(input_dim, self.text_hidden_size))

            # Activation between layers, but not after the last one.
            if i < self.num_layers - 1:
                layers.append(activation_class())

        return nn.Sequential(*layers)
159
+
160
+
161
# --- Connector Mapping and Exports ---

# Registry mapping a connector type string to its class. The keys
# ('identity', 'linear', 'mlp') are matched against `connector_config.type`
# by the factory code that instantiates connectors (e.g. _build_connector).
connector_map: dict[str, type[Connector]] = {
    "identity": IdentityConnector,
    "linear": LinearConnector,
    "mlp": MLPConnector,
}