Junho330 committed on
Commit
b7103a0
·
1 Parent(s): 11a2bae
Files changed (3) hide show
  1. config.json +24 -0
  2. dino.safetensors +3 -0
  3. modeling_dinov2_dual.py +97 -0
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DualChannelDINOv2Model"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.0,
6
+ "drop_path_rate": 0.0,
7
+ "hidden_act": "gelu",
8
+ "hidden_dropout_prob": 0.0,
9
+ "hidden_size": 768,
10
+ "image_size": 518,
11
+ "initializer_range": 0.02,
12
+ "layer_norm_eps": 1e-06,
13
+ "layerscale_value": 1.0,
14
+ "mlp_ratio": 4,
15
+ "model_type": "dinov2",
16
+ "num_attention_heads": 12,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 12,
19
+ "patch_size": 14,
20
+ "qkv_bias": true,
21
+ "torch_dtype": "float32",
22
+ "transformers_version": "4.31.0.dev0",
23
+ "use_swiglu_ffn": false
24
+ }
dino.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7718d343aa0369b8e730bbbb0f3b68516668869f3c8fe79945934572268088a
3
+ size 229915824
modeling_dinov2_dual.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import copy
4
+
5
+ from transformers import PreTrainedModel, Dinov2Config, Dinov2Model
6
class DualChannelDINOv2Model(PreTrainedModel):
    """
    DINOv2 encoder in which every layer's self-attention is replaced by a
    DualChannelDINOv2Attention wrapper (a frozen base branch plus a
    trainable copy, fused by a learnable gate).

    NOTE(review): only the encoder is kept — the patch/position embeddings
    of the underlying Dinov2Model are discarded and no ``forward`` is
    defined here; presumably callers drive ``self.encoder`` directly.
    Confirm against the consuming code.
    """

    config_class = Dinov2Config

    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        # Build a standard DINOv2 model, keep its encoder, and swap each
        # layer's attention for the dual-channel wrapper (in place).
        backbone = Dinov2Model(config)
        self.encoder = add_dual_channel_attention_to_dino(backbone.encoder)

17
def add_dual_channel_attention_to_dino(dino_encoder: nn.Module):
    """
    Wrap each layer's self-attention of a DINOv2 encoder with
    DualChannelDINOv2Attention.

    The layer's original ``attention`` module becomes the (frozen) base
    branch of the wrapper; the wrapper adds a trainable deep copy and a
    learnable per-channel fusion gate.

    Args:
        dino_encoder: a DINOv2 encoder exposing ``.config`` and an
            iterable ``.layer`` of transformer layers (assumed — confirm
            against the installed transformers version).

    Returns:
        The same encoder instance, modified in place.
    """
    config = dino_encoder.config

    for layer in dino_encoder.layer:
        old_attn = layer.attention

        # Bug fix: the original code passed `layer_idx=idx`, but
        # DualChannelDINOv2Attention.__init__ accepts no such keyword,
        # so every construction raised TypeError. The index was unused.
        layer.attention = DualChannelDINOv2Attention(
            attention_base=old_attn,
            config=config,
        )

    return dino_encoder
38
+
39
+
40
class DualChannelDINOv2Attention(nn.Module):
    """
    Dual-branch wrapper around a DINOv2 attention module.

    Keeps two copies of the wrapped attention:
      1) ``attention_base`` — the original module, frozen (gradients off);
      2) ``attention_plus`` — a trainable deep copy.
    Their outputs are fused per hidden channel with a learnable gate
    ``alpha = sigmoid(alpha_param)``:
    ``fused = alpha * base + (1 - alpha) * plus``.
    """

    def __init__(self, attention_base: nn.Module, config, layer_idx: int = 0):
        # Bug fix: nn.Module.__init__ takes no arguments; the original
        # `super().__init__(config)` raised TypeError on construction.
        super().__init__()
        # Accepted (and stored) so callers that pass `layer_idx=...` work;
        # the index is informational only.
        self.layer_idx = layer_idx

        self.attention_base = attention_base
        self.attention_plus = copy.deepcopy(attention_base)

        # Freeze the base branch as the docstring promises; the original
        # code never actually disabled its gradients.
        for p in self.attention_base.parameters():
            p.requires_grad_(False)

        # Learnable per-channel fusion gate (vector of size [hidden_size]).
        # Zero init -> sigmoid = 0.5, i.e. an even blend of both branches.
        self.alpha_param = nn.Parameter(torch.zeros(config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: torch.Tensor = None,
        output_attentions: bool = False,
    ):
        """
        Run both attention branches on ``hidden_states`` and fuse results.

        Returns:
            ``(fused_out, base_attention_probs)`` when ``output_attentions``
            is True, else the Hugging Face-style 1-tuple ``(fused_out,)``.
        """
        # ---- base (frozen) branch ----
        base_out, base_attn = self._run_branch(
            self.attention_base, hidden_states, head_mask, output_attentions
        )
        # ---- plus (trainable) branch; its attention map is discarded ----
        plus_out, _ = self._run_branch(
            self.attention_plus, hidden_states, head_mask, output_attentions
        )

        # ---- fuse outputs: per-channel gate broadcast over (batch, seq, hidden) ----
        alpha = torch.sigmoid(self.alpha_param).view(1, 1, -1).to(dtype=base_out.dtype)
        fused_out = alpha * base_out + (1.0 - alpha) * plus_out

        if output_attentions:
            # Return the base branch's attention map as-is — the safest
            # choice for compatibility with downstream HF consumers.
            return fused_out, base_attn
        # Hugging Face convention: a length-1 tuple when attentions are
        # not requested.
        return (fused_out,)

    @staticmethod
    def _run_branch(attn, hidden_states, head_mask, output_attentions):
        """Call one attention branch; normalize its return to (output, attn-or-None)."""
        ret = attn(
            hidden_states,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if isinstance(ret, tuple):
            return ret[0], ret[1] if len(ret) > 1 else None
        return ret, None