dcher95 committed on
Commit 9cabbed · verified · 1 Parent(s): a8b5f22

Upload folder using huggingface_hub

Files changed (4)
  1. cosa/compute_embeddings.py +148 -0
  2. cosa/cosa.ckpt +3 -0
  3. cosa/model.py +290 -0
  4. cosa/text_encoder.py +374 -0
cosa/compute_embeddings.py ADDED
@@ -0,0 +1,148 @@
import os
import argparse
import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer, AutoModel,
    BertTokenizer, BertModel,
    CLIPTokenizer, CLIPTextModel,
    T5Tokenizer, T5EncoderModel
)

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "osm_clip")))
from model import OSMBind


def average_pool(last_hidden_states, attention_mask):
    """Computes average pooling of hidden states, masking padding tokens."""
    masked_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return masked_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def get_tokenizer_and_model(encoder_type='bert', checkpoint_path=None, taglist_path=None, tagvocab_path=None):
    if encoder_type == 'bert':
        model_name = 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertModel.from_pretrained(model_name)
        embedding_fn = lambda outputs, batch_dict: outputs.pooler_output.squeeze()

    elif encoder_type == 'clip':
        model_name = 'openai/clip-vit-large-patch14'
        tokenizer = CLIPTokenizer.from_pretrained(model_name)
        model = CLIPTextModel.from_pretrained(model_name)

        def clip_embedding_fn(outputs, batch_dict):
            input_ids = batch_dict['input_ids']
            eos_token_id = tokenizer.eos_token_id
            seq_lengths = (input_ids == eos_token_id).nonzero(as_tuple=True)[1]

            embeddings = []
            for i in range(input_ids.size(0)):
                eos_pos = seq_lengths[i] if i < len(seq_lengths) else (input_ids[i] != tokenizer.pad_token_id).sum() - 1
                embeddings.append(outputs.last_hidden_state[i, eos_pos, :])
            return torch.stack(embeddings)

        embedding_fn = clip_embedding_fn

    elif encoder_type == 'e5':
        model_name = 'intfloat/e5-base-v2'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        embedding_fn = lambda outputs, batch_dict: average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    elif encoder_type == 't5':
        model_name = 't5-base'
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5EncoderModel.from_pretrained(model_name)
        embedding_fn = lambda outputs, batch_dict: average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    elif 'osm' in encoder_type:
        text_backbone = encoder_type.split('-')[1] if '-' in encoder_type else 'clip'
        model = OSMBind(taglist_path=taglist_path, tagvocab_path=tagvocab_path, text_backbone=text_backbone)
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'], strict=False)
        model.eval().cuda()
        tokenizer = None

        def osm_embedding_fn(outputs, batch_dict):
            return model.text_encoder.encode_batch(batch_dict['sentences'])

        embedding_fn = osm_embedding_fn

    else:
        raise ValueError(f"Unsupported encoder_type: {encoder_type}")

    model.eval()
    return tokenizer, model, embedding_fn


def generate_embeddings(taglist_path, tag_vocab_path, output_path,
                        encoder_type='bert', checkpoint_path=None):
    # Load taglist and vocab
    taglist = torch.load(taglist_path, weights_only=True)  # list of tuples of tag indices
    tag_vocab = torch.load(tag_vocab_path, weights_only=True)
    tag_index = {v: k for k, v in tag_vocab.items()}  # index -> tag string

    # Convert taglist tuples to "sentences" of tag strings
    sentences = []
    for tl in taglist:
        words = [tag_index[idx] for idx in tl]
        sentences.append(" ".join(words))

    # Optional prompt formatting
    if encoder_type == 'e5':
        sentences = [f"query: {s}" for s in sentences]
    elif encoder_type == 't5':
        sentences = [f"embedding: {s}" for s in sentences]

    # Load model
    tokenizer, model, embedding_fn = get_tokenizer_and_model(encoder_type, checkpoint_path, taglist_path=taglist_path, tagvocab_path=tag_vocab_path)
    device = next(model.parameters()).device if hasattr(model, 'parameters') else torch.device('cpu')

    # Generate embeddings
    embeddings = []
    print("Encoding taglists...")
    for sentence in tqdm(sentences):
        # Run tokenization, the forward pass, and pooling without gradient tracking
        with torch.inference_mode():
            if 'osm' in encoder_type:
                batch_dict = {'sentences': [sentence]}
                outputs = None
            else:
                inputs = tokenizer([sentence], return_tensors='pt', padding=True, truncation=True)
                batch_dict = {k: v.to(device) for k, v in inputs.items()}
                outputs = model(**batch_dict)

            emb = embedding_fn(outputs, batch_dict)
            if emb.ndim == 1:
                emb = emb.unsqueeze(0)
        embeddings.append(emb.cpu())

    embeddings = torch.cat(embeddings, dim=0)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    torch.save(embeddings, output_path)
    print(f"Saved {len(sentences)} taglist embeddings to {output_path}")


# ========================
# Command Line Interface
# ========================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate embeddings for taglists")
    parser.add_argument("--taglist_path", type=str, required=True, help="Path to taglist_vocab.pt")
    parser.add_argument("--tag_vocab_path", type=str, required=True, help="Path to tag_vocab.pt")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save embeddings tensor")
    parser.add_argument("--encoder_type", type=str,
                        choices=["bert", "clip", "e5", "t5", "osm-clip", "osm-e5", "osm-bert"],
                        default="bert")
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Optional checkpoint for OSMBind")

    args = parser.parse_args()

    generate_embeddings(
        taglist_path=args.taglist_path,
        tag_vocab_path=args.tag_vocab_path,
        output_path=args.output_path,
        encoder_type=args.encoder_type,
        checkpoint_path=args.checkpoint_path
    )
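For orientation, here is a minimal sketch of driving the script above programmatically rather than through the CLI. The .pt paths and output location are placeholders, the snippet assumes it is run from inside cosa/ so that the module imports resolve (including the sibling osm_clip package imported at the top of the file), and the osm-* encoder types additionally need --checkpoint_path.

# Hypothetical usage sketch; paths below are placeholders, not files in this commit.
from compute_embeddings import generate_embeddings

generate_embeddings(
    taglist_path="data/taglist_vocab.pt",   # list of tuples of tag indices
    tag_vocab_path="data/tag_vocab.pt",     # dict mapping tag string -> index
    output_path="embeddings/taglist_e5.pt",
    encoder_type="e5",                      # or "bert", "clip", "t5", "osm-clip", ...
)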
cosa/cosa.ckpt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:981a8ec6c089d019dbe54afd34693d3617db8b28837cf5adf013702563b6f73a
size 2365975368
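The file above is a Git LFS pointer, so the actual ~2.4 GB checkpoint lives out-of-band on the Hub. A hedged sketch of fetching it with huggingface_hub; the repo_id is a placeholder for whatever repository this commit belongs to.

# Hypothetical download sketch; replace repo_id with the actual Hub repository.
from huggingface_hub import hf_hub_download

ckpt_path = hf_hub_download(repo_id="<user>/<repo>", filename="cosa/cosa.ckpt")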
cosa/model.py ADDED
@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
import os
import numpy as np
import torch.nn.functional as F
import pytorch_lightning as pl
from datasets import OSMDataset
from torch.utils.data import DataLoader
import random
from typing import Optional, List, Tuple, Literal
from image_encoder import SatlasPretrainEncoder
from text_encoder import TextEncoder
from orthogonal_adamw import OrthogonalAdamW
from configs.config_e5 import config
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import WandbLogger
from utils import generate_tag_poly_pairs
import matplotlib.pyplot as plt
import io
import wandb
from PIL import Image


# This performs a typical InfoNCE loss
def contrastive_loss(image_feats: torch.Tensor, text_feats: torch.Tensor, logit_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    logits = torch.matmul(image_feats, text_feats.t()) * logit_scale
    labels = torch.arange(logits.size(0), device=logits.device)

    return F.cross_entropy(logits, labels), logits


class OSMBind(pl.LightningModule):
    def __init__(self, train_dataset=None, val_dataset=None, **kwargs):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        self.image_encoder = SatlasPretrainEncoder(fpn=True, model_name="Aerial_SwinB_SI",
                                                   out_dim=768, num_extra_fpn_layers=4)
        taglist_vocab = torch.load(kwargs.get("taglist_path"), weights_only=True)
        tag_vocab_inverted = torch.load(kwargs.get("tagvocab_path"), weights_only=True)  # str -> int
        tag_vocab = {v: k for k, v in tag_vocab_inverted.items()}  # int -> str
        self.text_encoder = TextEncoder(taglist_vocab, tag_vocab,
                                        model_name=kwargs.get("text_backbone"))
        # for param in self.text_encoder.parameters():
        #     param.requires_grad = False

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))  # softer scale for misaligned encoders

        self.batch_size = kwargs.get("batch_size")
        self.num_workers = kwargs.get("num_workers")
        self.lr = kwargs.get("lr", 1e-4)
        self.num_samples = kwargs.get("num_samples")  # number of OSM classes sampled
        self.ort_grad = kwargs.get("ort_grad")

    def forward(self, sat_img: torch.Tensor, pixel_tensor: torch.Tensor):
        full_image_feats = self.image_encoder(sat_img)  # [B, D, H', W']
        sampled_tag_tensor, image_poly_feats = generate_tag_poly_pairs(pixel_tensor, full_image_feats, K=self.num_samples)  # [K], [K, D]
        text_sampled_feats = self.text_encoder(sampled_tag_tensor)  # [K, D]

        return image_poly_feats, text_sampled_feats  # [K, D], [K, D]

    def shared_step(self, batch):
        sat_img, pixel_tensor = batch
        image_poly_feats, text_sampled_feats = self(sat_img, pixel_tensor)  # [K, D], [K, D]

        # contrastive loss for whole batch
        image_feats_norm = F.normalize(image_poly_feats, dim=1)
        text_feats_norm = F.normalize(text_sampled_feats, dim=1)
        logit_scale = self.logit_scale.exp()
        loss, logits = contrastive_loss(image_feats_norm, text_feats_norm,
                                        logit_scale=logit_scale)
        return loss, logits

    def log_similarity_matrix(self, logits):
        mat = logits.detach().cpu().numpy()
        fig, ax = plt.subplots(figsize=(6, 6))
        cax = ax.matshow(mat, cmap="viridis")
        fig.colorbar(cax)
        ax.set_xlabel("Text samples")
        ax.set_ylabel("Image samples")
        ax.set_title("Similarity Matrix")

        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        plt.close(fig)

        # Convert the PNG buffer to a PIL Image for wandb logging
        image = Image.open(buf)

        if isinstance(self.logger, WandbLogger):
            self.logger.experiment.log({
                "similarity_matrix": wandb.Image(image),
                "global_step": self.global_step
            })

    def training_step(self, batch, batch_idx):
        loss, logits = self.shared_step(batch)
        self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
        self.log('temperature', self.logit_scale.exp().item(), prog_bar=True, on_epoch=True)
        if self.global_step % 500 == 0:
            self.log_similarity_matrix(logits)
        # Log histogram of similarity scores every step
        if self.logger and hasattr(self.logger.experiment, "log"):
            self.logger.experiment.log({"logits_hist": wandb.Histogram(logits.detach().cpu().numpy())})

        # Optionally log mean and max of logits for monitoring
        self.log("logits_mean", logits.mean(), on_step=True, on_epoch=False, prog_bar=True)
        self.log("logits_max", logits.max(), on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        min_log_scale = np.log(1 / 1.0)
        max_log_scale = np.log(1 / 0.01)
        self.logit_scale.data.clamp_(min_log_scale, max_log_scale)

    def on_after_backward(self):
        if self.global_rank == 0 and self.current_epoch == 0:
            for name, param in self.named_parameters():
                if param.requires_grad and param.grad is None:
                    print(f"⚠️ Unused parameter: {name}")

    def validation_step(self, batch, batch_idx):
        loss, _ = self.shared_step(batch)
        self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
        return loss

    def train_dataloader(self):
        if self.train_dataset is None:
            raise ValueError("This model was initialized without a training dataset.")
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True,
                          persistent_workers=False)

    def val_dataloader(self):
        if self.val_dataset is None:
            raise ValueError("This model was initialized without a validation dataset.")
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          persistent_workers=False)

    def configure_optimizers(self):
        params = self.parameters()
        if self.ort_grad:
            self.optim = OrthogonalAdamW(
                params,
                lr=self.lr,
                betas=(0.9, 0.98),
                beta_ort=0.9,
                eps=1e-6,
                weight_decay=0.01
            )
        else:
            self.optim = torch.optim.AdamW(
                params,
                lr=self.lr,
                betas=(0.9, 0.98),
                eps=1e-6,
                weight_decay=0.01
            )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=self.optim,
            T_0=20
        )

        return [self.optim], [self.scheduler]

    def sim_map_inf(self, sat_image: torch.Tensor, raw_text: str) -> torch.Tensor:
        """
        Args:
            sat_image: [1, 3, 512, 512] tensor (already normalized)
            raw_text: str, e.g., "building"

        Returns:
            sim_map: [512, 512] similarity map between image and text embedding
        """
        assert sat_image.dim() == 4 and sat_image.size(0) == 1, "Expected input of shape [1, 3, H, W]"

        # Step 1: Extract spatial features
        with torch.no_grad():
            # image features
            feat_map = self.image_encoder(sat_image)  # [1, D, H', W']
            feat_map = feat_map.squeeze(0)  # [D, H', W']
            feat_map_upsampled = F.interpolate(feat_map.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False).squeeze(0)  # [D, 512, 512]
            feat_map_upsampled = F.normalize(feat_map_upsampled, dim=0)  # [D, 512, 512]

            # text features
            text_feat = self.text_encoder.encode_raw_text(raw_text)

            # cosine sim
            text_feat = F.normalize(text_feat, dim=0)
            feat_map_upsampled = F.normalize(feat_map_upsampled, dim=0)
            sim_map = torch.einsum('chw,c->hw', feat_map_upsampled, text_feat)  # [512, 512]

        return sim_map

    def encode_text(self, text: str) -> torch.Tensor:
        with torch.no_grad():
            return self.text_encoder.encode_raw_text(text)

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.image_encoder(image)


def seed_everything(seed=42):
    """
    seed: int
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)


if __name__ == '__main__':
    import warnings
    warnings.filterwarnings("ignore")
    torch.set_warn_always(False)

    seed_everything()
    train_dataset = OSMDataset(metadata_path=config.train_csv,
                               image_dir=config.sat_img_dir,
                               pixel_tensor_dir=config.pixel_tensors_dir,
                               mode='train')
    val_dataset = OSMDataset(metadata_path=config.val_csv,
                             image_dir=config.sat_img_dir,
                             pixel_tensor_dir=config.pixel_tensors_dir,
                             mode='val')

    # from torch.utils.data import Subset
    # train_dataset = Subset(train_dataset, range(1000))
    # val_dataset = Subset(val_dataset, range(200))

    kwargs = {
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'num_samples': config.num_contrastive_samples,
        'ort_grad': config.ort_grad,
        'lr': config.lr,
        'taglist_path': config.taglist_vocab_path,
        'tagvocab_path': config.tag_vocab_path,
        'text_backbone': config.text_backbone
    }

    model = OSMBind(train_dataset, val_dataset, **kwargs)
    torch.cuda.empty_cache()

    checkpoint_path = '/data/b.j.wei/rendersynth/osm_clip/checkpoints/osmclip_e5/osmclip_config_e5-epoch=39-val_loss=3.23.ckpt'
    if checkpoint_path:
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'])

    checkpoint = ModelCheckpoint(
        monitor='val_loss',
        dirpath=config.save_dir,
        filename=config.filename,
        mode='min',
        save_top_k=1,
        every_n_epochs=1
    )

    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min'
    )

    logger = WandbLogger(project="osmclip",
                         name=f"{config.experiment_name}")

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=config.devices,
        strategy='ddp',
        max_epochs=config.max_epochs,
        num_nodes=1,
        callbacks=[checkpoint, early_stop_callback],
        accumulate_grad_batches=config.accumulate_grad_batches,
        log_every_n_steps=5,
        logger=logger  # wandb logger
    )

    trainer.fit(model)
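As a rough illustration of the inference path defined above (sim_map_inf), here is a hedged sketch of loading the uploaded checkpoint into OSMBind and querying a per-pixel similarity map. The vocab paths are placeholders, and the sibling modules this file imports (image_encoder, text_encoder, datasets, configs, utils, orthogonal_adamw) must be importable for the import to succeed.

# Hypothetical inference sketch; vocab paths are placeholders.
import torch
from model import OSMBind

model = OSMBind(taglist_path="data/taglist_vocab.pt",
                tagvocab_path="data/tag_vocab.pt",
                text_backbone="e5")
ckpt = torch.load("cosa/cosa.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"], strict=False)
model.eval()

sat_image = torch.randn(1, 3, 512, 512)             # stand-in for a normalized satellite tile
sim_map = model.sim_map_inf(sat_image, "building")  # [512, 512] cosine-similarity map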
cosa/text_encoder.py ADDED
@@ -0,0 +1,374 @@
import torch
from transformers import (
    AutoTokenizer, AutoModel,
    BertTokenizer, BertModel,
    CLIPTokenizer, CLIPTextModel
)
import torch.nn as nn
import pytorch_lightning as pl
from typing import List
from abc import ABC, abstractmethod
import random

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def taglist_index_to_sentence(taglist_vocab, tag_vocab, taglist_indices, subsample: bool = True):
    """
    Convert a tensor or list of taglist indices to a list of tag sentences.
    Optionally, randomly shuffle and sample a subset of tags for each sentence.

    Args:
        taglist_vocab: List of tuples of tag IDs.
        tag_vocab: Dictionary mapping tag ID to tag string.
        taglist_indices: Tensor or list of indices into taglist_vocab.
        subsample: If True, randomly subsample tags in each sentence.

    Returns:
        tag_sentences: List of strings (tag sentences).
    """
    if isinstance(taglist_indices, torch.Tensor):
        taglist_indices = taglist_indices.view(-1).tolist()

    tag_sentences = []

    for idx in taglist_indices:
        tag_ids = taglist_vocab[idx]
        tags = [tag_vocab[tid].lower().replace('=', ' ') for tid in tag_ids]

        if subsample and len(tags) > 1:
            n_sample = random.randint(1, len(tags))  # Choose how many tags to keep
            tags = random.sample(tags, n_sample)  # Sample without replacement

        random.shuffle(tags)  # Randomize order
        sentence = ' '.join(tags)
        tag_sentences.append(sentence)

    return tag_sentences


def average_pool(last_hidden_states, attention_mask):
    masked_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return masked_hidden.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)


class BaseTextEncoder(nn.Module, ABC):
    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.embedding_dim = None

    @abstractmethod
    def encode(self, sentences: List[str], device: str = 'cpu') -> torch.Tensor:
        """
        Encode a list of sentences into a tensor of embeddings.
        Must be implemented by subclasses.
        """
        pass


class BertTextEncoder(BaseTextEncoder):
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        self.model.to(device)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        return self.model(**inputs).pooler_output


class CLIPTextEncoder(BaseTextEncoder):
    def __init__(self, model_name='openai/clip-vit-large-patch14', local_tokenizer_path=None):
        super().__init__(model_name)
        # NOTE: hard-coded local snapshot path; this overrides the local_tokenizer_path argument above
        local_tokenizer_path = "/u/cherd/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41"

        if local_tokenizer_path is not None:
            self.tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
            self.model = CLIPTextModel.from_pretrained(local_tokenizer_path)
        else:
            self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
            self.model = CLIPTextModel.from_pretrained(model_name, from_flax=True)
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        self.model.to(device)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        input_ids = inputs['input_ids']
        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id

        outputs = self.model(**inputs)
        last_hidden = outputs.last_hidden_state  # [B, T, D]

        batch_size = input_ids.size(0)
        embeddings = []

        for i in range(batch_size):
            input_seq = input_ids[i]
            eos_positions = (input_seq == eos_token_id).nonzero(as_tuple=True)[0]

            if len(eos_positions) > 0:
                eos_idx = eos_positions[-1]  # take last EOS (safe for duplicates)
            else:
                eos_idx = (input_seq != pad_token_id).sum() - 1  # fallback to last non-padding token

            embeddings.append(last_hidden[i, eos_idx, :])

        return torch.stack(embeddings)


class E5TextEncoder(BaseTextEncoder):
    def __init__(self, model_name='intfloat/e5-base'):
        super().__init__(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.pooler = None
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        self.model.to(device)
        sentences = [f"query: {s}" for s in sentences]  # official prompt for e5 (for features as per documentation)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = self.model(**inputs)
        return average_pool(outputs.last_hidden_state, inputs['attention_mask'])


class GritLMTextEncoder(BaseTextEncoder):
    def __init__(self, model_name='nomic-ai/nomic-bert-base-punc'):
        super().__init__(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.embedding_dim = self.model.config.hidden_size
        self.proj_head = nn.Linear(self.embedding_dim, 768)  # to match other encoders

    def encode(self, sentences, device='cpu'):
        self.model.to(device)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = self.model(**inputs)
        pooled = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
        return self.proj_head(pooled)


class TextEncoder(pl.LightningModule):
    def __init__(self, taglist_vocab: List[tuple], tag_vocab: dict, model_name='bert'):
        super().__init__()
        self.taglist_vocab = taglist_vocab
        self.tag_vocab = tag_vocab

        model_name = model_name.lower()
        encoder_map = {
            'bert': lambda: BertTextEncoder('bert-base-uncased'),
            'clip': lambda: CLIPTextEncoder('openai/clip-vit-large-patch14'),
            'e5': lambda: E5TextEncoder('intfloat/e5-base'),
            'gritlm': lambda: GritLMTextEncoder('nomic-ai/nomic-bert-base-punc')
        }

        if model_name not in encoder_map:
            raise ValueError(f"Unsupported model_name: {model_name}. Choose from {list(encoder_map.keys())}")
        print(f"Text backbone: {model_name}")
        self.encoder = encoder_map[model_name]()  # Instantiate the selected encoder
        # self.embedding_dim = 768

    def forward(self, taglist_tensor: torch.Tensor) -> torch.Tensor:
        tag_indices = taglist_tensor.tolist()
        tag_sentences = taglist_index_to_sentence(self.taglist_vocab, self.tag_vocab, tag_indices, subsample=True)  # randomize subsampling tags
        embeddings = self.encoder.encode(tag_sentences, device=self.device)
        return embeddings

    def encode_raw_text(self, raw_text: str) -> torch.Tensor:
        """
        Encode a single raw string into an embedding for queries
        """
        return self.encoder.encode([raw_text], device=self.device)[0]

    def encode_batch(self, raw_texts: List[str]) -> torch.Tensor:
        """
        Encode a batch of raw strings into embeddings for queries
        """
        return self.encoder.encode(raw_texts, device=self.device)
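To make the encoder selection above concrete, a hedged sketch of standalone usage of TextEncoder; the tiny vocabularies below are made up purely for illustration and the chosen backbone is downloaded from the Hub on first use.

# Hypothetical standalone usage; vocabularies are illustrative only.
from text_encoder import TextEncoder

tag_vocab = {0: "building=yes", 1: "highway=residential", 2: "natural=water"}  # tag ID -> tag string
taglist_vocab = [(0,), (0, 1), (2,)]                                           # tuples of tag IDs

encoder = TextEncoder(taglist_vocab, tag_vocab, model_name="bert")
query_emb = encoder.encode_raw_text("building")            # [768] embedding of a single query
batch_emb = encoder.encode_batch(["building", "water"])    # [2, 768] batch of query embeddings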