Junho330 committed on
Commit
bc1967c
·
verified ·
1 Parent(s): 92f540d

Delete mdoel_dinov2_dual.py

Browse files
Files changed (1) hide show
  1. mdoel_dinov2_dual.py +0 -89
mdoel_dinov2_dual.py DELETED
@@ -1,89 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import copy
4
-
5
- from transformers import Dinov2Model, Dinov2Config
6
class DualChannelDINOv2Model(Dinov2Model):
    """DINOv2 backbone whose self-attention layers are wrapped in
    DualChannelDINOv2Attention modules.

    The stock ``Dinov2Model`` is built first, then every encoder layer's
    attention is swapped for the dual-channel wrapper.
    """

    config_class = Dinov2Config

    def __init__(self, config: Dinov2Config):
        super().__init__(config)
        # Patch the freshly-built encoder in place and rebind the attribute.
        patched_encoder = add_dual_channel_attention_to_dino(self.encoder)
        self.encoder = patched_encoder
16
def add_dual_channel_attention_to_dino(dino_encoder: nn.Module) -> nn.Module:
    """Replace each self-attention module in ``dino_encoder`` with a
    DualChannelDINOv2Attention wrapper around the original attention.

    Args:
        dino_encoder: A DINOv2 encoder exposing ``.config`` and an iterable
            ``.layer`` of transformer layers, each with an ``.attention``
            attribute.

    Returns:
        The same encoder object, modified in place (returned for convenience
        so callers can write ``encoder = add_dual_channel_attention_to_dino(encoder)``).
    """
    config = dino_encoder.config

    # Plain iteration: the layer index was previously enumerated but unused.
    for layer in dino_encoder.layer:
        layer.attention = DualChannelDINOv2Attention(
            attention_base=layer.attention,
            config=config,
        )

    return dino_encoder
-
31
-
32
class DualChannelDINOv2Attention(nn.Module):
    """A wrapper that keeps two instances of DINOv2 attention:

    1) ``attention_base`` -- the original attention, frozen (no gradients).
    2) ``attention_plus`` -- a trainable deep copy of the original.

    Their outputs are fused with a learnable per-channel gate::

        fused = sigmoid(alpha) * base + (1 - sigmoid(alpha)) * plus
    """

    def __init__(self, attention_base: nn.Module, config):
        """
        Args:
            attention_base: The attention module to wrap; it is kept as the
                frozen branch, and a deep copy becomes the trainable branch.
            config: Any object with a ``hidden_size`` attribute (e.g. a
                ``Dinov2Config``) used to size the fusion gate.
        """
        super().__init__()
        self.attention_base = attention_base
        self.attention_plus = copy.deepcopy(attention_base)

        # Bug fix: the class contract says the base branch is frozen, but the
        # original code never disabled its gradients. Freeze it explicitly.
        for param in self.attention_base.parameters():
            param.requires_grad_(False)

        # Learnable per-channel fusion gate (vector of size [hidden_size]),
        # initialised to 0 so sigmoid(alpha) == 0.5 (equal mix) at start.
        self.alpha_param = nn.Parameter(torch.zeros(config.hidden_size))

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: torch.Tensor = None,
        output_attentions: bool = False,
    ):
        """Run both attention branches and fuse their results.

        Args:
            hidden_states: Input hidden states; last dim must equal
                ``config.hidden_size`` for the gate to broadcast.
            head_mask: Optional head mask forwarded to both branches.
            output_attentions: If True, also return the base branch's
                attention map.

        Returns:
            ``(fused_out,)`` when ``output_attentions`` is False, otherwise
            ``(fused_out, base_attentions)``.
        """
        # ---- base branch ----
        base_ret = self.attention_base(
            hidden_states,
            head_mask=head_mask,
            output_attentions=output_attentions,
        )
        if isinstance(base_ret, tuple):
            base_out = base_ret[0]
            base_attn = base_ret[1] if len(base_ret) > 1 else None
        else:
            base_out, base_attn = base_ret, None

        # ---- plus branch ----
        # The plus branch's attention map was never used by the original code,
        # so never request it: this removes a dead local and skips wasted work.
        plus_ret = self.attention_plus(
            hidden_states,
            head_mask=head_mask,
            output_attentions=False,
        )
        plus_out = plus_ret[0] if isinstance(plus_ret, tuple) else plus_ret

        # ---- fuse outputs ----
        # Gate broadcasts over (batch, seq, hidden); cast to the branch dtype
        # so mixed-precision runs do not upcast the fused output.
        alpha = torch.sigmoid(self.alpha_param).view(1, 1, -1).to(dtype=base_out.dtype)
        fused_out = alpha * base_out + (1.0 - alpha) * plus_out

        if output_attentions:
            # Return the base branch's attention map (safest, HF-compatible choice).
            return fused_out, base_attn
        # Hugging Face convention: a length-1 tuple when attentions are not requested.
        return (fused_out,)