Upload folder using huggingface_hub
- cosa/compute_embeddings.py +148 -0
- cosa/cosa.ckpt +3 -0
- cosa/model.py +290 -0
- cosa/text_encoder.py +374 -0
cosa/compute_embeddings.py
ADDED
@@ -0,0 +1,148 @@
import os
import argparse
import torch
from tqdm import tqdm
from transformers import (
    AutoTokenizer, AutoModel,
    BertTokenizer, BertModel,
    CLIPTokenizer, CLIPTextModel,
    T5Tokenizer, T5EncoderModel
)

import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "osm_clip")))
from model import OSMBind


def average_pool(last_hidden_states, attention_mask):
    """Computes average pooling of hidden states, masking padding tokens."""
    masked_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return masked_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


def get_tokenizer_and_model(encoder_type='bert', checkpoint_path=None, taglist_path=None, tagvocab_path=None):
    if encoder_type == 'bert':
        model_name = 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertModel.from_pretrained(model_name)
        embedding_fn = lambda outputs, batch_dict: outputs.pooler_output.squeeze()

    elif encoder_type == 'clip':
        model_name = 'openai/clip-vit-large-patch14'
        tokenizer = CLIPTokenizer.from_pretrained(model_name)
        model = CLIPTextModel.from_pretrained(model_name)

        def clip_embedding_fn(outputs, batch_dict):
            input_ids = batch_dict['input_ids']
            eos_token_id = tokenizer.eos_token_id
            seq_lengths = (input_ids == eos_token_id).nonzero(as_tuple=True)[1]

            embeddings = []
            for i in range(input_ids.size(0)):
                eos_pos = seq_lengths[i] if i < len(seq_lengths) else (input_ids[i] != tokenizer.pad_token_id).sum() - 1
                embeddings.append(outputs.last_hidden_state[i, eos_pos, :])
            return torch.stack(embeddings)

        embedding_fn = clip_embedding_fn

    elif encoder_type == 'e5':
        model_name = 'intfloat/e5-base-v2'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        embedding_fn = lambda outputs, batch_dict: average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    elif encoder_type == 't5':
        model_name = 't5-base'
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5EncoderModel.from_pretrained(model_name)
        embedding_fn = lambda outputs, batch_dict: average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    elif 'osm' in encoder_type:
        text_backbone = encoder_type.split('-')[1] if '-' in encoder_type else 'clip'
        model = OSMBind(taglist_path=taglist_path, tagvocab_path=tagvocab_path, text_backbone=text_backbone)
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'], strict=False)
        model.eval().cuda()
        tokenizer = None

        def osm_embedding_fn(outputs, batch_dict):
            return model.text_encoder.encode_batch(batch_dict['sentences'])

        embedding_fn = osm_embedding_fn

    else:
        raise ValueError(f"Unsupported encoder_type: {encoder_type}")

    model.eval()
    return tokenizer, model, embedding_fn


def generate_embeddings(taglist_path, tag_vocab_path, output_path,
                        encoder_type='bert', checkpoint_path=None):
    # Load taglist and vocab
    taglist = torch.load(taglist_path, weights_only=True)  # list of tuples of tag indices
    tag_vocab = torch.load(tag_vocab_path, weights_only=True)
    tag_index = {v: k for k, v in tag_vocab.items()}  # index -> tag string

    # Convert taglist tuples to "sentences" of tag strings
    sentences = []
    for tl in taglist:
        words = [tag_index[idx] for idx in tl]
        sentences.append(" ".join(words))

    # Optional prompt formatting
    if encoder_type == 'e5':
        sentences = [f"query: {s}" for s in sentences]
    elif encoder_type == 't5':
        sentences = [f"embedding: {s}" for s in sentences]

    # Load model
    tokenizer, model, embedding_fn = get_tokenizer_and_model(encoder_type, checkpoint_path, taglist_path=taglist_path, tagvocab_path=tag_vocab_path)
    device = next(model.parameters()).device if hasattr(model, 'parameters') else torch.device('cpu')

    # Generate embeddings
    embeddings = []
    print("Encoding taglists...")
    for sentence in tqdm(sentences):
        if 'osm' in encoder_type:
            batch_dict = {'sentences': [sentence]}
            outputs = None
        else:
            inputs = tokenizer([sentence], return_tensors='pt', padding=True, truncation=True)
            batch_dict = {k: v.to(device) for k, v in inputs.items()}
            outputs = model(**batch_dict)

        with torch.inference_mode():
            emb = embedding_fn(outputs, batch_dict)
            if emb.ndim == 1:
                emb = emb.unsqueeze(0)
            embeddings.append(emb.cpu())

    embeddings = torch.cat(embeddings, dim=0)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    torch.save(embeddings, output_path)
    print(f"Saved {len(sentences)} taglist embeddings to {output_path}")


# ========================
# Command Line Interface
# ========================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate embeddings for taglists")
    parser.add_argument("--taglist_path", type=str, required=True, help="Path to taglist_vocab.pt")
    parser.add_argument("--tag_vocab_path", type=str, required=True, help="Path to tag_vocab.pt")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save embeddings tensor")
    parser.add_argument("--encoder_type", type=str,
                        choices=["bert", "clip", "e5", "t5", "osm-clip", "osm-e5", "osm-bert"],
                        default="bert")
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Optional checkpoint for OSMBind")

    args = parser.parse_args()

    generate_embeddings(
        taglist_path=args.taglist_path,
        tag_vocab_path=args.tag_vocab_path,
        output_path=args.output_path,
        encoder_type=args.encoder_type,
        checkpoint_path=args.checkpoint_path
    )
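For context, a minimal usage sketch of the script above, assuming hypothetical paths for the vocab files and the output tensor (the real paths depend on how the dataset was exported and are not part of this commit):

    # Hypothetical invocation; the paths are placeholders, not files in this repo.
    from compute_embeddings import generate_embeddings

    generate_embeddings(
        taglist_path="data/taglist_vocab.pt",     # hypothetical
        tag_vocab_path="data/tag_vocab.pt",       # hypothetical
        output_path="embeddings/taglist_e5.pt",   # hypothetical
        encoder_type="e5",                        # the "query: " prefix is added automatically
    )

The same call is exposed on the command line via the argparse block at the bottom of the file.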
cosa/cosa.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:981a8ec6c089d019dbe54afd34693d3617db8b28837cf5adf013702563b6f73a
size 2365975368
cosa/model.py
ADDED
@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
import os
import numpy as np
import torch.nn.functional as F
import pytorch_lightning as pl
from datasets import OSMDataset
from torch.utils.data import DataLoader
import random
from typing import Optional, List, Tuple, Literal
from image_encoder import SatlasPretrainEncoder
from text_encoder import TextEncoder
from orthogonal_adamw import OrthogonalAdamW
from configs.config_e5 import config
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import WandbLogger
from utils import generate_tag_poly_pairs
import matplotlib.pyplot as plt
import io
import wandb
from PIL import Image


# This performs a typical InfoNCE loss
def contrastive_loss(image_feats: torch.Tensor, text_feats: torch.Tensor, logit_scale: torch.Tensor) -> torch.Tensor:
    logits = torch.matmul(image_feats, text_feats.t()) * logit_scale
    labels = torch.arange(logits.size(0), device=logits.device)

    return F.cross_entropy(logits, labels), logits


class OSMBind(pl.LightningModule):
    def __init__(self, train_dataset=None, val_dataset=None, **kwargs):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        self.image_encoder = SatlasPretrainEncoder(fpn=True, model_name="Aerial_SwinB_SI",
                                                   out_dim=768, num_extra_fpn_layers=4)
        taglist_vocab = torch.load(kwargs.get("taglist_path"), weights_only=True)
        tag_vocab_inverted = torch.load(kwargs.get("tagvocab_path"), weights_only=True)  # str -> int
        tag_vocab = {v: k for k, v in tag_vocab_inverted.items()}  # int -> str
        self.text_encoder = TextEncoder(taglist_vocab, tag_vocab,
                                        model_name=kwargs.get("text_backbone"))
        # for param in self.text_encoder.parameters():
        #     param.requires_grad = False

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))  # softer scale for misaligned encoders

        self.batch_size = kwargs.get("batch_size")
        self.num_workers = kwargs.get("num_workers")
        self.lr = kwargs.get("lr", 1e-4)
        self.num_samples = kwargs.get("num_samples")  # number of OSM classes sampled
        self.ort_grad = kwargs.get("ort_grad")

    def forward(self, sat_img: torch.Tensor, pixel_tensor: torch.Tensor):
        full_image_feats = self.image_encoder(sat_img)  # [B, D, H', W']
        sampled_tag_tensor, image_poly_feats = generate_tag_poly_pairs(pixel_tensor, full_image_feats, K=self.num_samples)  # [K], [K, D]
        text_sampled_feats = self.text_encoder(sampled_tag_tensor)  # [K, D]

        return image_poly_feats, text_sampled_feats  # [K, D], [K, D]

    def shared_step(self, batch):
        sat_img, pixel_tensor = batch
        image_poly_feats, text_sampled_feats = self(sat_img, pixel_tensor)  # [K, D], [K, D]

        # contrastive loss for whole batch
        image_feats_norm = F.normalize(image_poly_feats, dim=1)
        text_feats_norm = F.normalize(text_sampled_feats, dim=1)
        logit_scale = self.logit_scale.exp()
        loss, logits = contrastive_loss(image_feats_norm, text_feats_norm,
                                        logit_scale=logit_scale)
        return loss, logits

    def log_similarity_matrix(self, logits):
        mat = logits.detach().cpu().numpy()
        fig, ax = plt.subplots(figsize=(6, 6))
        cax = ax.matshow(mat, cmap="viridis")
        fig.colorbar(cax)
        ax.set_xlabel("Text samples")
        ax.set_ylabel("Image samples")
        ax.set_title("Similarity Matrix")

        buf = io.BytesIO()
        plt.savefig(buf, format='png')
        buf.seek(0)
        plt.close(fig)

        # ✅ Fix: Convert buffer to PIL Image
        image = Image.open(buf)

        if isinstance(self.logger, WandbLogger):
            self.logger.experiment.log({
                "similarity_matrix": wandb.Image(image),
                "global_step": self.global_step
            })

    def training_step(self, batch, batch_idx):
        loss, logits = self.shared_step(batch)
        self.log('train_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
        self.log('temperature', self.logit_scale.exp().item(), prog_bar=True, on_epoch=True)
        if self.global_step % 500 == 0:
            self.log_similarity_matrix(logits)
        # Log histogram of similarity scores every step
        if self.logger and hasattr(self.logger.experiment, "log"):
            self.logger.experiment.log({"logits_hist": wandb.Histogram(logits.detach().cpu().numpy())})

        # Optionally log mean and max of logits for monitoring
        self.log("logits_mean", logits.mean(), on_step=True, on_epoch=False, prog_bar=True)
        self.log("logits_max", logits.max(), on_step=True, on_epoch=False, prog_bar=True)
        return loss

    def on_train_batch_end(self, outputs, batch, batch_idx):
        min_log_scale = np.log(1 / 1.0)
        max_log_scale = np.log(1 / 0.01)
        self.logit_scale.data.clamp_(min_log_scale, max_log_scale)

    def on_after_backward(self):
        if self.global_rank == 0 and self.current_epoch == 0:
            for name, param in self.named_parameters():
                if param.requires_grad and param.grad is None:
                    print(f"⚠️ Unused parameter: {name}")

    def validation_step(self, batch, batch_idx):
        loss, _ = self.shared_step(batch)
        self.log('val_loss', loss, sync_dist=True, prog_bar=True, on_epoch=True, batch_size=self.batch_size)
        return loss

    def train_dataloader(self):
        if self.train_dataset is None:
            raise ValueError("This model was initialized without a training dataset.")
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=True,
                          persistent_workers=False)

    def val_dataloader(self):
        if self.val_dataset is None:
            raise ValueError("This model was initialized without a validation dataset.")
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          num_workers=self.num_workers,
                          shuffle=False,
                          persistent_workers=False)

    def configure_optimizers(self):
        params = self.parameters()
        if self.ort_grad:
            self.optim = OrthogonalAdamW(
                params,
                lr=self.lr,
                betas=(0.9, 0.98),
                beta_ort=0.9,
                eps=1e-6,
                weight_decay=0.01
            )
        else:
            self.optim = torch.optim.AdamW(
                params,
                lr=self.lr,
                betas=(0.9, 0.98),
                eps=1e-6,
                weight_decay=0.01
            )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer=self.optim,
            T_0=20
        )

        return [self.optim], [self.scheduler]

    def sim_map_inf(self, sat_image: torch.Tensor, raw_text: str) -> torch.Tensor:
        """
        Args:
            sat_image: [1, 3, 512, 512] tensor (already normalized)
            raw_text: str, e.g., "building"

        Returns:
            sim_map: [512, 512] similarity map between image and text embedding
        """
        assert sat_image.dim() == 4 and sat_image.size(0) == 1, "Expected input of shape [1, 3, H, W]"

        # Step 1: Extract spatial features
        with torch.no_grad():
            # image features
            feat_map = self.image_encoder(sat_image)  # [1, D, H', W']
            feat_map = feat_map.squeeze(0)  # [D, H', W']
            feat_map_upsampled = F.interpolate(feat_map.unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False).squeeze(0)  # [D, 512, 512]
            feat_map_upsampled = F.normalize(feat_map_upsampled, dim=0)  # [D, 512, 512]

            # text features
            text_feat = self.text_encoder.encode_raw_text(raw_text)

            # cosine sim
            text_feat = F.normalize(text_feat, dim=0)
            feat_map_upsampled = F.normalize(feat_map_upsampled, dim=0)
            sim_map = torch.einsum('chw,c->hw', feat_map_upsampled, text_feat)  # [512, 512]

        return sim_map

    def encode_text(self, text: str) -> torch.Tensor:
        with torch.no_grad():
            return self.text_encoder.encode_raw_text(text)

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        with torch.no_grad():
            return self.image_encoder(image)


def seed_everything(seed=42):
    """
    seed: int
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ["PYTHONHASHSEED"] = str(seed)


if __name__ == '__main__':
    import warnings
    warnings.filterwarnings("ignore")
    torch.set_warn_always(False)

    seed_everything()
    train_dataset = OSMDataset(metadata_path=config.train_csv,
                               image_dir=config.sat_img_dir,
                               pixel_tensor_dir=config.pixel_tensors_dir,
                               mode='train')
    val_dataset = OSMDataset(metadata_path=config.val_csv,
                             image_dir=config.sat_img_dir,
                             pixel_tensor_dir=config.pixel_tensors_dir,
                             mode='val')

    # from torch.utils.data import Subset
    # train_dataset = Subset(train_dataset, range(1000))
    # val_dataset = Subset(val_dataset, range(200))

    kwargs = {
        'batch_size': config.batch_size,
        'num_workers': config.num_workers,
        'num_samples': config.num_contrastive_samples,
        'ort_grad': config.ort_grad,
        'lr': config.lr,
        # keys renamed to match what __init__ reads via kwargs.get(...)
        'taglist_path': config.taglist_vocab_path,
        'tagvocab_path': config.tag_vocab_path,
        'text_backbone': config.text_backbone
    }

    model = OSMBind(train_dataset, val_dataset, **kwargs)
    torch.cuda.empty_cache()

    checkpoint_path = '/data/b.j.wei/rendersynth/osm_clip/checkpoints/osmclip_e5/osmclip_config_e5-epoch=39-val_loss=3.23.ckpt'
    if checkpoint_path:
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'])

    checkpoint = ModelCheckpoint(
        monitor='val_loss',
        dirpath=config.save_dir,
        filename=config.filename,
        mode='min',
        save_top_k=1,
        every_n_epochs=1
    )

    early_stop_callback = EarlyStopping(
        monitor='val_loss',
        patience=15,
        mode='min'
    )

    logger = WandbLogger(project="osmclip",
                         name=f"{config.experiment_name}")

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=config.devices,
        strategy='ddp',
        max_epochs=config.max_epochs,
        num_nodes=1,
        callbacks=[checkpoint, early_stop_callback],
        accumulate_grad_batches=config.accumulate_grad_batches,
        log_every_n_steps=5,
        logger=logger  # wandb logger
    )

    trainer.fit(model)
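The contrastive_loss above is a standard InfoNCE objective over image-to-text logits. A small self-contained sketch with random features (not the trained model) showing the shapes involved; K and D are illustrative values only:

    # Standalone illustration of the InfoNCE loss used in OSMBind.shared_step.
    import torch
    import torch.nn.functional as F

    K, D = 8, 768                                          # K sampled (polygon, taglist) pairs of D-dim features
    image_feats = F.normalize(torch.randn(K, D), dim=1)    # stand-in for image_poly_feats
    text_feats = F.normalize(torch.randn(K, D), dim=1)     # stand-in for text_sampled_feats
    logit_scale = torch.tensor(1 / 0.07)                   # exp of the initial logit_scale parameter

    logits = image_feats @ text_feats.t() * logit_scale    # [K, K] similarity matrix
    labels = torch.arange(K)                               # matching pairs lie on the diagonal
    loss = F.cross_entropy(logits, labels)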
cosa/text_encoder.py
ADDED
@@ -0,0 +1,374 @@
import torch
from transformers import (
    AutoTokenizer, AutoModel,
    BertTokenizer, BertModel,
    CLIPTokenizer, CLIPTextModel
)
import torch.nn as nn
import pytorch_lightning as pl
from typing import List
from abc import ABC, abstractmethod
import random

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def taglist_index_to_sentence(taglist_vocab, tag_vocab, taglist_indices, subsample: bool = True):
    """
    Convert a tensor or list of taglist indices to a list of tag sentences.
    Optionally, randomly shuffle and sample a subset of tags for each sentence.

    Args:
        taglist_vocab: List of tuples of tag IDs.
        tag_vocab: Dictionary mapping tag ID to tag string.
        taglist_indices: Tensor or list of indices into taglist_vocab.
        subsample: If True, randomly subsample tags in each sentence.

    Returns:
        tag_sentences: List of strings (tag sentences).
    """
    if isinstance(taglist_indices, torch.Tensor):
        taglist_indices = taglist_indices.view(-1).tolist()

    tag_sentences = []

    for idx in taglist_indices:
        tag_ids = taglist_vocab[idx]
        tags = [tag_vocab[tid].lower().replace('=', ' ') for tid in tag_ids]

        if subsample and len(tags) > 1:
            n_sample = random.randint(1, len(tags))  # Choose how many tags to keep
            tags = random.sample(tags, n_sample)     # Sample without replacement

        random.shuffle(tags)  # Randomize order
        sentence = ' '.join(tags)
        tag_sentences.append(sentence)

    return tag_sentences


def average_pool(last_hidden_states, attention_mask):
    masked_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return masked_hidden.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)


class BaseTextEncoder(nn.Module, ABC):
    def __init__(self, model_name: str):
        super().__init__()
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.embedding_dim = None

    @abstractmethod
    def encode(self, sentences: List[str], device: str = 'cpu') -> torch.Tensor:
        """
        Encode a list of sentences into a tensor of embeddings.
        Must be implemented by subclasses.
        """
        pass


class BertTextEncoder(BaseTextEncoder):
    def __init__(self, model_name='bert-base-uncased'):
        super().__init__(model_name)
        self.tokenizer = BertTokenizer.from_pretrained(model_name)
        self.model = BertModel.from_pretrained(model_name)
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        self.model.to(device)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        return self.model(**inputs).pooler_output


class CLIPTextEncoder(BaseTextEncoder):
    def __init__(self, model_name='openai/clip-vit-large-patch14', local_tokenizer_path=None):
        super().__init__(model_name)
        if local_tokenizer_path is None:
            # Default to a locally cached snapshot; previously this assignment unconditionally
            # overrode the argument, which made the fallback branch below unreachable.
            local_tokenizer_path = "/u/cherd/.cache/huggingface/hub/models--openai--clip-vit-large-patch14/snapshots/32bd64288804d66eefd0ccbe215aa642df71cc41"

        if local_tokenizer_path is not None:
            self.tokenizer = CLIPTokenizer.from_pretrained(local_tokenizer_path)
            self.model = CLIPTextModel.from_pretrained(local_tokenizer_path)
        else:
            self.tokenizer = CLIPTokenizer.from_pretrained(model_name)
            self.model = CLIPTextModel.from_pretrained(model_name, from_flax=True)
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        self.model.to(device)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        input_ids = inputs['input_ids']
        eos_token_id = self.tokenizer.eos_token_id
        pad_token_id = self.tokenizer.pad_token_id

        outputs = self.model(**inputs)
        last_hidden = outputs.last_hidden_state  # [B, T, D]

        batch_size = input_ids.size(0)
        embeddings = []

        for i in range(batch_size):
            input_seq = input_ids[i]
            eos_positions = (input_seq == eos_token_id).nonzero(as_tuple=True)[0]

            if len(eos_positions) > 0:
                eos_idx = eos_positions[-1]  # take last EOS (safe for duplicates)
            else:
                eos_idx = (input_seq != pad_token_id).sum() - 1  # fallback to last non-padding token

            embeddings.append(last_hidden[i, eos_idx, :])

        return torch.stack(embeddings)


class E5TextEncoder(BaseTextEncoder):
    def __init__(self, model_name='intfloat/e5-base'):
        super().__init__(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.pooler = None
        self.embedding_dim = self.model.config.hidden_size

    def encode(self, sentences, device='cpu'):
        self.model.to(device)
        sentences = [f"query: {s}" for s in sentences]  # official prompt for e5 (for features as per documentation)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = self.model(**inputs)
        return average_pool(outputs.last_hidden_state, inputs['attention_mask'])


class GritLMTextEncoder(BaseTextEncoder):
    def __init__(self, model_name='nomic-ai/nomic-bert-base-punc'):
        super().__init__(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.embedding_dim = self.model.config.hidden_size
        self.proj_head = nn.Linear(self.embedding_dim, 768)  # to match other encoders

    def encode(self, sentences, device='cpu'):
        self.model.to(device)
        inputs = self.tokenizer(sentences, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = self.model(**inputs)
        pooled = average_pool(outputs.last_hidden_state, inputs['attention_mask'])
        return self.proj_head(pooled)


class TextEncoder(pl.LightningModule):
    def __init__(self, taglist_vocab: List[tuple], tag_vocab: dict, model_name='bert'):
        super().__init__()
        self.taglist_vocab = taglist_vocab
        self.tag_vocab = tag_vocab

        model_name = model_name.lower()
        encoder_map = {
            'bert': lambda: BertTextEncoder('bert-base-uncased'),
            'clip': lambda: CLIPTextEncoder('openai/clip-vit-large-patch14'),
            'e5': lambda: E5TextEncoder('intfloat/e5-base'),
            'gritlm': lambda: GritLMTextEncoder('nomic-ai/nomic-bert-base-punc')
        }

        if model_name not in encoder_map:
            raise ValueError(f"Unsupported model_name: {model_name}. Choose from {list(encoder_map.keys())}")
        print(f"Text backbone: {model_name}")
        self.encoder = encoder_map[model_name]()  # Instantiate the selected encoder
        # self.embedding_dim = 768

    def forward(self, taglist_tensor: torch.Tensor) -> torch.Tensor:
        tag_indices = taglist_tensor.tolist()
        tag_sentences = taglist_index_to_sentence(self.taglist_vocab, self.tag_vocab, tag_indices, subsample=True)  # randomize subsampling tags
        embeddings = self.encoder.encode(tag_sentences, device=self.device)
        return embeddings

    def encode_raw_text(self, raw_text: str) -> torch.Tensor:
        """
        Encode a single raw string into an embedding for queries
        """
        return self.encoder.encode([raw_text], device=self.device)[0]

    def encode_batch(self, raw_texts: List[str]) -> torch.Tensor:
        """
        Encode a batch of raw strings into embeddings for queries
        """
        return self.encoder.encode(raw_texts, device=self.device)
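As a small illustration of taglist_index_to_sentence, here is a toy run with a made-up vocabulary (the real vocabularies come from taglist_vocab.pt and tag_vocab.pt, which are not shown in this commit):

    # Toy example with a hypothetical two-entry taglist vocabulary.
    import random
    from text_encoder import taglist_index_to_sentence

    random.seed(0)
    taglist_vocab = [(0, 1), (2,)]          # taglists referenced by index
    tag_vocab = {0: "building=yes", 1: "amenity=school", 2: "natural=water"}

    # '=' becomes a space and tags may be subsampled/shuffled, e.g.
    # ["amenity school building yes", "natural water"]
    print(taglist_index_to_sentence(taglist_vocab, tag_vocab, [0, 1], subsample=True))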