|
|
import os |
|
|
import argparse |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
from transformers import ( |
|
|
AutoTokenizer, AutoModel, |
|
|
BertTokenizer, BertModel, |
|
|
CLIPTokenizer, CLIPTextModel, |
|
|
T5Tokenizer, T5EncoderModel |
|
|
) |
|
|
|
|
|
import sys |
|
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "osm_clip"))) |
|
|
from model import OSMBind |
|
|
|
|
|
|
|
|
def average_pool(last_hidden_states, attention_mask):
    """Mean-pool hidden states over the sequence axis, ignoring padding tokens.

    Args:
        last_hidden_states: Token hidden states, expected shape (batch, seq_len, hidden).
        attention_mask: Mask of shape (batch, seq_len); nonzero entries mark real tokens.

    Returns:
        Tensor of shape (batch, hidden) holding the mean of the unmasked token states.
        A row whose mask is entirely zero yields zeros instead of NaN.
    """
    keep = attention_mask[..., None].bool()
    masked_hidden = last_hidden_states.masked_fill(~keep, 0.0)
    # Clamp so a fully-masked row divides by 1 (giving 0) rather than 0 (giving NaN).
    token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
    return masked_hidden.sum(dim=1) / token_counts
|
|
|
|
|
|
|
|
def get_tokenizer_and_model(encoder_type='bert', checkpoint_path=None, taglist_path=None, tagvocab_path=None):
    """Build the tokenizer, text model and pooling function for ``encoder_type``.

    Args:
        encoder_type: 'bert', 'clip', 'e5', 't5', or any name containing 'osm'
            (e.g. 'osm-clip'). For OSM names, the part after the first '-' selects
            the OSMBind text backbone; without a '-' the backbone is 'clip'.
        checkpoint_path: Checkpoint file containing a 'state_dict' entry. Required
            for OSM encoders; ignored otherwise.
        taglist_path: Forwarded to the OSMBind constructor.
        tagvocab_path: Forwarded to the OSMBind constructor.

    Returns:
        ``(tokenizer, model, embedding_fn)``. ``model`` is put in eval mode.
        ``embedding_fn(outputs, batch_dict)`` turns model outputs plus the tokenized
        batch into embeddings. For OSM encoders ``tokenizer`` is None and
        ``embedding_fn`` ignores ``outputs``, encoding ``batch_dict['sentences']`` itself.

    Raises:
        ValueError: If ``encoder_type`` is unsupported, or an OSM encoder is
            requested without ``checkpoint_path``.
    """
    if encoder_type == 'bert':
        model_name = 'bert-base-uncased'
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = BertModel.from_pretrained(model_name)
        # squeeze() turns a batch of one into a 1-D vector; callers re-add the batch dim.
        embedding_fn = lambda outputs, batch_dict: outputs.pooler_output.squeeze()

    elif encoder_type == 'clip':
        model_name = 'openai/clip-vit-large-patch14'
        tokenizer = CLIPTokenizer.from_pretrained(model_name)
        model = CLIPTextModel.from_pretrained(model_name)

        def clip_embedding_fn(outputs, batch_dict):
            # Pick the hidden state at each row's first EOS token. Positions are found
            # per row: a flattened nonzero() over the whole batch misaligns rows when a
            # row contains several EOS ids (e.g. if padding reuses the EOS id).
            input_ids = batch_dict['input_ids']
            is_eos = input_ids == tokenizer.eos_token_id
            # argmax returns the index of the first maximal value, i.e. the first EOS.
            eos_pos = is_eos.int().argmax(dim=-1)
            has_eos = is_eos.any(dim=-1)
            if not bool(has_eos.all()):
                # Rows without an EOS fall back to their last non-padding token.
                last_real = ((input_ids != tokenizer.pad_token_id).sum(dim=-1) - 1).clamp(min=0)
                eos_pos = torch.where(has_eos, eos_pos, last_real)
            rows = torch.arange(input_ids.size(0), device=input_ids.device)
            return outputs.last_hidden_state[rows, eos_pos]

        embedding_fn = clip_embedding_fn

    elif encoder_type == 'e5':
        model_name = 'intfloat/e5-base-v2'
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModel.from_pretrained(model_name)
        embedding_fn = lambda outputs, batch_dict: average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    elif encoder_type == 't5':
        model_name = 't5-base'
        tokenizer = T5Tokenizer.from_pretrained(model_name)
        model = T5EncoderModel.from_pretrained(model_name)
        embedding_fn = lambda outputs, batch_dict: average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

    elif 'osm' in encoder_type:
        # Fail early with a clear message instead of an opaque error from torch.load(None).
        if checkpoint_path is None:
            raise ValueError(f"encoder_type '{encoder_type}' requires a checkpoint_path")
        text_backbone = encoder_type.split('-')[1] if '-' in encoder_type else 'clip'
        model = OSMBind(taglist_path=taglist_path, tagvocab_path=tagvocab_path, text_backbone=text_backbone)
        # NOTE(review): torch.load without weights_only=True can unpickle arbitrary
        # objects; only load checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        # strict=False: keys absent from / extra to the checkpoint are silently ignored.
        model.load_state_dict(ckpt['state_dict'], strict=False)
        # Use the GPU when present; fall back to CPU instead of crashing on .cuda().
        osm_device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model.eval().to(osm_device)
        tokenizer = None

        def osm_embedding_fn(outputs, batch_dict):
            # OSMBind tokenizes and encodes raw strings itself; ``outputs`` is unused.
            return model.text_encoder.encode_batch(batch_dict['sentences'])

        embedding_fn = osm_embedding_fn

    else:
        raise ValueError(f"Unsupported encoder_type: {encoder_type}")

    model.eval()
    return tokenizer, model, embedding_fn
|
|
|
|
|
|
|
|
def generate_embeddings(taglist_path, tag_vocab_path, output_path,
                        encoder_type='bert', checkpoint_path=None):
    """Encode every taglist as a sentence and save the stacked embeddings.

    ``tag_vocab`` is inverted (value -> key) so that each entry of a taglist,
    presumably an integer tag id, maps back to its tag string. The tags of one
    taglist are joined with spaces into a sentence, optionally prefixed for the
    e5/t5 encoders, and encoded one sentence at a time.

    Args:
        taglist_path: File loadable with ``torch.load(..., weights_only=True)``
            holding an iterable of taglists.
        tag_vocab_path: File holding a dict whose values are the entries used in
            the taglists and whose keys are the tag strings.
        output_path: Destination for the (num_taglists, dim) embedding tensor;
            parent directories are created when needed.
        encoder_type: Encoder name accepted by ``get_tokenizer_and_model``.
        checkpoint_path: Checkpoint for OSM encoders.
    """
    taglist = torch.load(taglist_path, weights_only=True)
    tag_vocab = torch.load(tag_vocab_path, weights_only=True)
    # Invert the vocabulary so taglist entries can be mapped back to tag strings.
    tag_index = {v: k for k, v in tag_vocab.items()}
    sentences = [" ".join(tag_index[idx] for idx in tl) for tl in taglist]

    # Model-specific input prefixes.
    if encoder_type == 'e5':
        sentences = [f"query: {s}" for s in sentences]
    elif encoder_type == 't5':
        sentences = [f"embedding: {s}" for s in sentences]

    tokenizer, model, embedding_fn = get_tokenizer_and_model(
        encoder_type, checkpoint_path, taglist_path=taglist_path, tagvocab_path=tag_vocab_path)
    device = next(model.parameters()).device if hasattr(model, 'parameters') else torch.device('cpu')
    is_osm = 'osm' in encoder_type

    embeddings = []
    print("Encoding taglists...")
    # The model forward must also run under inference_mode; otherwise autograd records
    # every forward pass and saves activations for a backward pass that never happens.
    with torch.inference_mode():
        for sentence in tqdm(sentences):
            if is_osm:
                # OSM encoders take raw strings; their embedding_fn runs the model.
                batch_dict = {'sentences': [sentence]}
                outputs = None
            else:
                inputs = tokenizer([sentence], return_tensors='pt', padding=True, truncation=True)
                batch_dict = {k: v.to(device) for k, v in inputs.items()}
                outputs = model(**batch_dict)
            emb = embedding_fn(outputs, batch_dict)
            if emb.ndim == 1:
                emb = emb.unsqueeze(0)
            embeddings.append(emb.cpu())

    embeddings = torch.cat(embeddings, dim=0)
    out_dir = os.path.dirname(output_path)
    # A bare filename has no directory part, and os.makedirs('') would raise.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(embeddings, output_path)
    print(f"Saved {len(sentences)} taglist embeddings to {output_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate embeddings for taglists")
    parser.add_argument("--taglist_path", type=str, required=True, help="Path to taglist_vocab.pt")
    parser.add_argument("--tag_vocab_path", type=str, required=True, help="Path to tag_vocab.pt")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save embeddings tensor")
    parser.add_argument("--encoder_type", type=str,
                        choices=["bert", "clip", "e5", "t5", "osm-clip", "osm-e5", "osm-bert"],
                        default="bert",
                        help="Text encoder used to embed each taglist")
    parser.add_argument("--checkpoint_path", type=str, default=None,
                        help="Checkpoint for OSMBind (required for osm-* encoders)")

    args = parser.parse_args()

    # The osm-* encoders always load their weights from this checkpoint, so report a
    # usage error up front rather than failing after argument parsing.
    if args.encoder_type.startswith("osm") and args.checkpoint_path is None:
        parser.error(f"--checkpoint_path is required when --encoder_type is {args.encoder_type}")

    generate_embeddings(
        taglist_path=args.taglist_path,
        tag_vocab_path=args.tag_vocab_path,
        output_path=args.output_path,
        encoder_type=args.encoder_type,
        checkpoint_path=args.checkpoint_path,
    )