"""Precompute OWLv2 text embeddings for every label in id_to_str.json.

Reads a {id: label_string} mapping, embeds each label (underscores replaced
by spaces) with the OWLv2 text encoder in batches, and saves a
{id: pooled_embedding} dict to embeds.pt.
"""
import json

import torch
import tqdm
from transformers import AutoTokenizer, Owlv2Processor, Owlv2TextModel

MODEL_NAME = "google/owlv2-base-patch16-ensemble"
BATCH_SIZE = 8

with open("id_to_str.json") as f:
    data = json.load(f)
keys = list(data.keys())

proc = Owlv2Processor.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = Owlv2TextModel.from_pretrained(MODEL_NAME)
model.eval()  # inference only — disable dropout etc.

# BUG FIX: the original used floor division (len(keys) // bsz), which silently
# dropped the final partial batch — up to BATCH_SIZE - 1 keys never got an
# embedding. Ceil-divide so every key is processed.
num_batches = (len(keys) + BATCH_SIZE - 1) // BATCH_SIZE

# Plain dict instead of nn.ParameterDict: these are frozen lookup embeddings,
# not trainable parameters, and ParameterDict rejects any key containing '.'.
# torch.save of this dict produces the same {key: tensor} file as the
# original's state_dict() call.
embed_dict = {}

with torch.no_grad():  # pure feature extraction — no autograd graph needed
    for i in tqdm.tqdm(range(num_batches)):
        batch_keys = keys[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        batch = [data[key].replace("_", " ") for key in batch_keys]

        # OWLv2's text tower truncates inputs to 16 tokens; the original
        # detected overlong labels but discarded the decoded text. Surface a
        # real warning instead so truncated labels are visible.
        for ids in tokenizer(batch)["input_ids"]:
            if len(ids) > 16:
                print(f"warning: label exceeds 16 tokens, will be truncated: "
                      f"{tokenizer.decode(ids)}")

        inputs = proc(text=batch, return_tensors="pt")
        output = model(**inputs)
        for k, key in enumerate(batch_keys):
            # clone() so the stored slice does not keep the whole batch
            # tensor alive as a view base.
            embed_dict[key] = output.pooler_output[k].detach().clone()

torch.save(embed_dict, "embeds.pt")