# VectorSynth-COSA / cosa/compute_embeddings.py
# Uploaded by dcher95 using huggingface_hub (commit 9cabbed, verified).
import os
import argparse
import torch
from tqdm import tqdm
from transformers import (
AutoTokenizer, AutoModel,
BertTokenizer, BertModel,
CLIPTokenizer, CLIPTextModel,
T5Tokenizer, T5EncoderModel
)
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "osm_clip")))
from model import OSMBind
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool hidden states over the sequence axis, ignoring padding.

    Args:
        last_hidden_states: Tensor indexed as (batch, seq_len, hidden); the
            mask is broadcast against it via ``attention_mask[..., None]``.
        attention_mask: Tensor indexed as (batch, seq_len); zero entries mark
            positions excluded from the average.

    Returns:
        Tensor of shape (batch, hidden): the sum of unmasked hidden states
        divided by the number of unmasked positions. A row whose mask is all
        zeros yields a zero vector instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    summed = last_hidden_states.masked_fill(~keep, 0.0).sum(dim=1)
    counts = attention_mask.sum(dim=1)[..., None]
    # Only exactly-zero counts are replaced: an all-padding row would
    # otherwise compute 0 / 0 = NaN.
    counts = counts.masked_fill(counts == 0, 1)
    return summed / counts
def get_tokenizer_and_model(encoder_type='bert', checkpoint_path=None,
                            taglist_path=None, tagvocab_path=None):
    """Load a text encoder plus the function that turns its outputs into embeddings.

    Args:
        encoder_type: ``'bert'``, ``'clip'``, ``'e5'``, ``'t5'``, or any string
            containing ``'osm'`` (e.g. ``'osm-clip'``). For OSM types, the
            second ``'-'``-separated field is passed to ``OSMBind`` as
            ``text_backbone`` (``'clip'`` when there is no ``'-'``).
        checkpoint_path: Checkpoint whose ``'state_dict'`` is loaded into
            ``OSMBind``. Required for OSM types, ignored otherwise.
        taglist_path: Forwarded to ``OSMBind`` (OSM types only).
        tagvocab_path: Forwarded to ``OSMBind`` (OSM types only).

    Returns:
        ``(tokenizer, model, embedding_fn)`` with the model in eval mode.
        ``tokenizer`` is ``None`` for OSM types. ``embedding_fn(outputs,
        batch_dict)`` maps a forward pass's outputs and the batch dict that
        produced them to an embedding tensor. For OSM types ``outputs`` is
        ignored and ``batch_dict['sentences']`` (list of str) is encoded.

    Raises:
        ValueError: If ``encoder_type`` is unsupported, or an OSM type is
            requested without ``checkpoint_path``.
    """
    if encoder_type == 'bert':
        model_name = 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertModel.from_pretrained(model_name)

        def embedding_fn(outputs, batch_dict):
            # BertModel's pooler_output; squeeze() drops size-1 axes, so a
            # single-sentence batch comes back 1-D.
            return outputs.pooler_output.squeeze()

    elif encoder_type == 'clip':
        model_name = 'openai/clip-vit-large-patch14'
        tokenizer = CLIPTokenizer.from_pretrained(model_name)
        model = CLIPTextModel.from_pretrained(model_name)

        def embedding_fn(outputs, batch_dict):
            # Take each row's hidden state at its FIRST EOS token. Locating it
            # per row (argmax returns the first maximal index) stays correct
            # when a row holds several EOS ids, e.g. padding that reuses the
            # EOS id; a flat nonzero() list indexed by row number does not.
            input_ids = batch_dict['input_ids']
            is_eos = input_ids == tokenizer.eos_token_id
            has_eos = is_eos.any(dim=1)
            eos_pos = is_eos.int().argmax(dim=1)
            if not bool(has_eos.all()):
                # Rows without an EOS fall back to their last non-pad token.
                last_real = (input_ids != tokenizer.pad_token_id).sum(dim=1) - 1
                eos_pos = torch.where(has_eos, eos_pos, last_real)
            rows = torch.arange(input_ids.size(0), device=input_ids.device)
            return outputs.last_hidden_state[rows, eos_pos]

    elif encoder_type == 'e5':
        model_name = 'intfloat/e5-base-v2'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)

        def embedding_fn(outputs, batch_dict):
            return average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    elif encoder_type == 't5':
        model_name = 't5-base'
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5EncoderModel.from_pretrained(model_name)

        def embedding_fn(outputs, batch_dict):
            return average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    elif 'osm' in encoder_type:
        # Validate before building OSMBind so a missing checkpoint fails with a
        # clear message instead of inside torch.load(None).
        if checkpoint_path is None:
            raise ValueError(
                f"encoder_type {encoder_type!r} requires a checkpoint_path")
        text_backbone = encoder_type.split('-')[1] if '-' in encoder_type else 'clip'
        model = OSMBind(taglist_path=taglist_path, tagvocab_path=tagvocab_path,
                        text_backbone=text_backbone)
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'], strict=False)
        # Use CUDA when present, otherwise stay on CPU rather than crashing.
        # NOTE(review): assumes OSMBind's text encoder follows the module's
        # device; confirm encode_batch does not hard-code .cuda().
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model.eval().to(device)
        tokenizer = None

        def embedding_fn(outputs, batch_dict):
            return model.text_encoder.encode_batch(batch_dict['sentences'])

    else:
        raise ValueError(f"Unsupported encoder_type: {encoder_type}")

    model.eval()
    return tokenizer, model, embedding_fn
def generate_embeddings(taglist_path, tag_vocab_path, output_path,
                        encoder_type='bert', checkpoint_path=None):
    """Encode every taglist with the chosen encoder and save the embeddings.

    Each taglist is turned into a space-separated "sentence" of tag strings by
    inverting ``tag_vocab``. For ``'e5'`` the sentence is prefixed with
    ``"query: "`` and for ``'t5'`` with ``"embedding: "``. Sentences are then
    encoded one at a time.

    Args:
        taglist_path: File loaded with ``torch.load(weights_only=True)``; per
            the inline note, a list of tuples of tag indices.
        tag_vocab_path: File holding a dict mapping tag string -> index.
        output_path: Where the concatenated embeddings tensor is saved. Its
            parent directory is created if needed.
        encoder_type: Passed to ``get_tokenizer_and_model``.
        checkpoint_path: Passed to ``get_tokenizer_and_model`` (OSM types).

    Raises:
        ValueError: If the taglist file contains no taglists. This is checked
            before any model is loaded.
    """
    taglist = torch.load(taglist_path, weights_only=True)  # list of tuples of tag indices
    tag_vocab = torch.load(tag_vocab_path, weights_only=True)
    tag_index = {idx: tag for tag, idx in tag_vocab.items()}  # index -> tag string

    # Convert taglist tuples to "sentences" of tag strings.
    sentences = [" ".join(tag_index[idx] for idx in tl) for tl in taglist]
    if not sentences:
        # torch.cat([]) would fail only after a model had been loaded.
        raise ValueError(f"No taglists found in {taglist_path}")

    # Optional prompt formatting.
    if encoder_type == 'e5':
        sentences = [f"query: {s}" for s in sentences]
    elif encoder_type == 't5':
        sentences = [f"embedding: {s}" for s in sentences]

    tokenizer, model, embedding_fn = get_tokenizer_and_model(
        encoder_type, checkpoint_path,
        taglist_path=taglist_path, tagvocab_path=tag_vocab_path)
    device = next(model.parameters()).device if hasattr(model, 'parameters') else torch.device('cpu')
    is_osm = 'osm' in encoder_type

    embeddings = []
    print("Encoding taglists...")
    # Cover the forward pass as well as pooling, so no autograd graph is
    # recorded for any model call.
    with torch.inference_mode():
        for sentence in tqdm(sentences):
            if is_osm:
                batch_dict = {'sentences': [sentence]}
                outputs = None
            else:
                inputs = tokenizer([sentence], return_tensors='pt', padding=True, truncation=True)
                batch_dict = {k: v.to(device) for k, v in inputs.items()}
                outputs = model(**batch_dict)
            emb = embedding_fn(outputs, batch_dict)
            if emb.ndim == 1:
                emb = emb.unsqueeze(0)  # keep a leading row axis for torch.cat
            embeddings.append(emb.cpu())
    embeddings = torch.cat(embeddings, dim=0)

    out_dir = os.path.dirname(output_path)
    if out_dir:  # a bare filename has no directory, and os.makedirs('') raises
        os.makedirs(out_dir, exist_ok=True)
    torch.save(embeddings, output_path)
    print(f"Saved {len(sentences)} taglist embeddings to {output_path}")
# ========================
# Command Line Interface
# ========================
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate embeddings for taglists")
    parser.add_argument("--taglist_path", type=str, required=True, help="Path to taglist_vocab.pt")
    parser.add_argument("--tag_vocab_path", type=str, required=True, help="Path to tag_vocab.pt")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save embeddings tensor")
    parser.add_argument("--encoder_type", type=str,
                        choices=["bert", "clip", "e5", "t5", "osm-clip", "osm-e5", "osm-bert"],
                        default="bert")
    parser.add_argument("--checkpoint_path", type=str, default=None, help="Optional checkpoint for OSMBind")
    args = parser.parse_args()

    # OSM encoders load their weights from --checkpoint_path. Fail fast with a
    # usage message before any data or model is loaded.
    if args.encoder_type.startswith("osm") and args.checkpoint_path is None:
        parser.error(f"--checkpoint_path is required for --encoder_type {args.encoder_type}")

    generate_embeddings(
        taglist_path=args.taglist_path,
        tag_vocab_path=args.tag_vocab_path,
        output_path=args.output_path,
        encoder_type=args.encoder_type,
        checkpoint_path=args.checkpoint_path,
    )