AbstractPhil committed on
Commit 8b6e6b3 · verified · 1 Parent(s): cbc9a0f

Update modeling_caption_bert.py

Files changed (1)
  1. modeling_caption_bert.py +132 -44
modeling_caption_bert.py CHANGED
@@ -1,79 +1,167 @@
  # ============================================================================
- # CaptionEncoder: Standalone Consensus-Distilled Caption Embedding Model
- #
- # Produces 768-dim L2-normalized embeddings in geometric consensus space.
- # Trained via distillation from 5-BERT pentachoron consensus.
- # No expert models needed at inference.
  #
  # Usage:
- #     from caption_encoder import CaptionEncoder
- #     model = CaptionEncoder()
- #     model.load_state_dict(torch.load("best_model.pt"))
- #     # tokenize with bert-base-uncased tokenizer
- #     embedding = model(input_ids, attention_mask)  # (B, 768) L2-normalized
  # ============================================================================

  import torch
  import torch.nn as nn
  import torch.nn.functional as F


- class CaptionEncoder(nn.Module):
      """
-     Standalone transformer caption encoder.
-     No pretrained weights required. Trained via geometric consensus distillation.

-     The embedding space is the geometric intersection of 5 BERT-family models:
-     BERT-base, ModernBERT-base, RoBERTa-base, ALBERT-base-v2, DistilBERT-base.
-     Aligned via whitened Procrustes rotation. Regularized by pentachoron CV.

-     At inference: bert-base-uncased tokenizer + this model.
-     Output: (B, 768) L2-normalized embedding in consensus space.
      """
-     def __init__(self, vocab_size=30522, max_len=8192, d_model=384,
-                  n_heads=6, n_layers=6, d_ff=1536, output_dim=768,
-                  dropout=0.1, pad_token_id=0):
-         super().__init__()
-         self.pad_token_id = pad_token_id
-         self.d_model = d_model
-         self.max_len = max_len
-
-         self.token_emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
-         self.pos_emb = nn.Embedding(max_len, d_model)
-         self.emb_norm = nn.LayerNorm(d_model)
-         self.emb_drop = nn.Dropout(dropout)

          encoder_layer = nn.TransformerEncoderLayer(
-             d_model=d_model, nhead=n_heads, dim_feedforward=d_ff,
-             dropout=dropout, activation="gelu", batch_first=True,
-             norm_first=True)
-         self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

          self.output_proj = nn.Sequential(
-             nn.Linear(d_model, d_model),
              nn.GELU(),
-             nn.LayerNorm(d_model),
-             nn.Linear(d_model, output_dim),
          )

-     def forward(self, input_ids, attention_mask=None):
          B, L = input_ids.shape
-         positions = torch.arange(L, device=input_ids.device).unsqueeze(0)

          x = self.token_emb(input_ids) + self.pos_emb(positions)
          x = self.emb_drop(self.emb_norm(x))

          if attention_mask is not None:
-             kpm = ~attention_mask.bool()
          else:
-             kpm = (input_ids == self.pad_token_id)

-         x = self.encoder(x, src_key_padding_mask=kpm)

          if attention_mask is not None:
              mask = attention_mask.unsqueeze(-1).float()
          else:
-             mask = (~kpm).unsqueeze(-1).float()
          pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)

-         return F.normalize(self.output_proj(pooled), dim=-1)
  # ============================================================================
+ # CaptionBERT-8192: HuggingFace AutoModel-Compatible Implementation
  #
  # Usage:
+ #     from transformers import AutoModel, AutoTokenizer
+ #     model = AutoModel.from_pretrained("AbstractPhil/geolip-captionbert-8192",
+ #                                       trust_remote_code=True)
+ #     tokenizer = AutoTokenizer.from_pretrained("AbstractPhil/geolip-captionbert-8192")
+ #     inputs = tokenizer("A cat on a windowsill", return_tensors="pt",
+ #                        padding=True, truncation=True, max_length=512)
+ #     outputs = model(**inputs)
+ #     embedding = outputs.last_hidden_state  # (B, 768) L2-normalized
  # ============================================================================

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ from transformers import PretrainedConfig, PreTrainedModel
+
+
+ class CaptionBertConfig(PretrainedConfig):
+     model_type = "caption_bert"
+
+     def __init__(
+         self,
+         vocab_size=30522,
+         max_position_embeddings=8192,
+         hidden_size=384,
+         num_attention_heads=6,
+         num_hidden_layers=6,
+         intermediate_size=1536,
+         output_dim=768,
+         hidden_dropout_prob=0.1,
+         pad_token_id=0,
+         **kwargs,
+     ):
+         super().__init__(pad_token_id=pad_token_id, **kwargs)
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.num_attention_heads = num_attention_heads
+         self.num_hidden_layers = num_hidden_layers
+         self.intermediate_size = intermediate_size
+         self.output_dim = output_dim
+         self.hidden_dropout_prob = hidden_dropout_prob


+ class CaptionBertModel(PreTrainedModel):
      """
+     Consensus-distilled caption encoder.
+
+     Produces L2-normalized 768-dim embeddings in the geometric consensus
+     space of 5 BERT-family models (BERT, ModernBERT, RoBERTa, ALBERT, DistilBERT).
+
+     Output:
+         last_hidden_state: (B, output_dim) L2-normalized embedding
+         pooler_output:     (B, output_dim) same as last_hidden_state (for compatibility)
      """
+     config_class = CaptionBertConfig

+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+
+         # Embeddings
+         self.token_emb = nn.Embedding(
+             config.vocab_size, config.hidden_size,
+             padding_idx=config.pad_token_id)
+         self.pos_emb = nn.Embedding(
+             config.max_position_embeddings, config.hidden_size)
+         self.emb_norm = nn.LayerNorm(config.hidden_size)
+         self.emb_drop = nn.Dropout(config.hidden_dropout_prob)
+
+         # Transformer encoder
          encoder_layer = nn.TransformerEncoderLayer(
+             d_model=config.hidden_size,
+             nhead=config.num_attention_heads,
+             dim_feedforward=config.intermediate_size,
+             dropout=config.hidden_dropout_prob,
+             activation="gelu",
+             batch_first=True,
+             norm_first=True,
+         )
+         self.encoder = nn.TransformerEncoder(
+             encoder_layer, num_layers=config.num_hidden_layers)

+         # Output projection to consensus space
          self.output_proj = nn.Sequential(
+             nn.Linear(config.hidden_size, config.hidden_size),
              nn.GELU(),
+             nn.LayerNorm(config.hidden_size),
+             nn.Linear(config.hidden_size, config.output_dim),
          )

+         self.post_init()
+
+     def forward(self, input_ids=None, attention_mask=None, **kwargs):
          B, L = input_ids.shape
+         device = input_ids.device

+         # Embed
+         positions = torch.arange(L, device=device).unsqueeze(0)
          x = self.token_emb(input_ids) + self.pos_emb(positions)
          x = self.emb_drop(self.emb_norm(x))

+         # Transformer with padding mask
          if attention_mask is not None:
+             key_padding_mask = ~attention_mask.bool()
          else:
+             key_padding_mask = (input_ids == self.config.pad_token_id)

+         x = self.encoder(x, src_key_padding_mask=key_padding_mask)

+         # Mean pool over non-padding tokens
          if attention_mask is not None:
              mask = attention_mask.unsqueeze(-1).float()
          else:
+             mask = (~key_padding_mask).unsqueeze(-1).float()
          pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1)

+         # Project and normalize
+         embedding = F.normalize(self.output_proj(pooled), dim=-1)
+
+         # Return in HuggingFace-compatible format
+         return type('Output', (), {
+             'last_hidden_state': embedding,
+             'pooler_output': embedding,
+         })()
+
+     def encode(self, texts, tokenizer=None, max_length=512, batch_size=128,
+                device=None):
+         """
+         Convenience method: raw text → L2-normalized embeddings.
+
+         Args:
+             texts: str or list of str
+             tokenizer: AutoTokenizer instance (loads default if None)
+             max_length: max token length
+             batch_size: encoding batch size
+             device: torch device
+
+         Returns:
+             (N, 768) L2-normalized tensor
+         """
+         if isinstance(texts, str):
+             texts = [texts]
+
+         if tokenizer is None:
+             from transformers import AutoTokenizer
+             tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
+
+         if device is None:
+             device = next(self.parameters()).device
+
+         self.eval()
+         all_emb = []
+         with torch.no_grad():
+             for i in range(0, len(texts), batch_size):
+                 batch = texts[i:i+batch_size]
+                 inputs = tokenizer(
+                     batch, max_length=max_length, padding="max_length",
+                     truncation=True, return_tensors="pt"
+                 ).to(device)
+                 out = self(input_ids=inputs["input_ids"],
+                            attention_mask=inputs["attention_mask"])
+                 all_emb.append(out.last_hidden_state.cpu())
+         return torch.cat(all_emb)
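For reference, a minimal end-to-end sketch of the new AutoModel path follows. The repo id, trust_remote_code flag, and output fields come from the usage header in the file above; the two example captions and the similarity check are illustrative only, and the sketch assumes the repository's config.json registers these classes via auto_map so that AutoModel can resolve them.

# Minimal usage sketch (assumptions noted above).
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained(
    "AbstractPhil/geolip-captionbert-8192", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("AbstractPhil/geolip-captionbert-8192")
model.eval()

captions = ["A cat on a windowsill", "A dog asleep on a porch"]
inputs = tokenizer(captions, return_tensors="pt",
                   padding=True, truncation=True, max_length=512)
with torch.no_grad():
    emb = model(**inputs).last_hidden_state   # (2, 768), L2-normalized

# Embeddings are unit-norm, so cosine similarity reduces to a dot product.
print(float(emb[0] @ emb[1]))

# The same embeddings via the convenience helper defined in this file:
emb2 = model.encode(captions, tokenizer=tokenizer)   # (2, 768)

Note that forward() returns the pooled, normalized vector as last_hidden_state, so the output is (B, 768) rather than the usual (B, L, hidden_size); pooler_output carries the identical tensor for compatibility.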