| import os | |
| import collections | |
| import json | |
| import logging | |
| import argparse | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| from time import time | |
| from torch import optim | |
| from tqdm import tqdm | |
| from torch.utils.data import DataLoader | |
| from rq_llama import * | |
def parse_args(argv=None):
    """Parse the command-line options of the item-embedding export script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, i.e. the same
            behaviour as calling this with no arguments.

    Returns:
        argparse.Namespace with string attributes ``ckpt_path``,
        ``save_path`` and ``device_map``.
    """
    parser = argparse.ArgumentParser(description="Index")
    parser.add_argument(
        "--ckpt_path", type=str, default="",
        help="checkpoint directory to load the LlamaWithRQ model from",
    )
    parser.add_argument(
        "--save_path", type=str, default="",
        help="output file for the tab-separated (item id, embedding) rows",
    )
    # The value is later converted with int(), so only a CUDA device index
    # (e.g. "0") is valid here -- "cpu" is not accepted.
    parser.add_argument(
        "--device_map", type=str, default="1",
        help="CUDA device index to load the model onto, e.g. 0",
    )
    return parser.parse_args(argv)
args = parse_args()
print(args)

# Map the whole model (the '' key) onto a single device. The value goes
# through int(), so --device_map must be a CUDA device index; passing a
# string such as "cpu" raises ValueError here.
device_map = {'': int(args.device_map)}
MODEL = LlamaWithRQ.from_pretrained(
    args.ckpt_path,
    device_map=device_map,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
# Inference only: put every submodule into evaluation mode.
MODEL.eval()

# Handles the embedding loop needs. The tokenizer and item texts are read
# off the model object (attributes provided by rq_llama.LlamaWithRQ);
# item_texts is presumably a mapping item id -> {'title', 'description', ...}
# -- confirm against rq_llama.
device = MODEL.device
llama = MODEL.model.get_decoder()
tokenizer = MODEL.tokenizer
item_texts = MODEL.item_texts
def _mean_pool(hidden, mask):
    """Return the attention-mask-weighted mean of ``hidden`` over the sequence axis.

    Args:
        hidden: hidden states of shape (batch, seq, dim).
        mask: 0/1 attention mask of shape (batch, seq).

    Returns:
        float32 tensor of shape (batch, dim).

    Pooling runs in float32. The model is loaded in float16, and a float16
    sum over hundreds of tokens can overflow (max 65504). Even without
    overflow, it keeps only about 3 significant digits.
    """
    hidden = hidden.float()
    weights = mask.unsqueeze(-1).to(hidden.dtype)
    # clamp guards against division by zero for an all-padding row.
    return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)


# Parallel lists: item id and its pooled embedding, in item_texts order.
all_idx = []
all_embeddings = []
with torch.no_grad():
    for idx, text in tqdm(item_texts.items()):
        # Each item is represented by its title and description joined by a space.
        item_text = text['title'] + ' ' + text['description']
        item_ids = tokenizer(item_text, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = llama(input_ids=item_ids.input_ids, attention_mask=item_ids.attention_mask)
        item_emb = _mean_pool(outputs.last_hidden_state, item_ids.attention_mask)
        all_idx.append(idx)
        # One item per batch: flatten the (1, dim) result to a plain list of floats.
        all_embeddings.append(item_emb.cpu().numpy().flatten().tolist())
# Output format: one row per item, "<id>\t<space-separated embedding values>",
# with no header row and no index column.
results = {
    'id': all_idx,
    # ' '.join builds each line in linear time. Repeated string `+`
    # concatenation is quadratic in the (thousands-long) vector length.
    'emb': [' '.join(map(str, emb)) for emb in tqdm(all_embeddings)],
}
df = pd.DataFrame(results)
# Create the parent directory if it does not exist yet so to_csv does not fail.
save_dir = os.path.dirname(args.save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
df.to_csv(args.save_path, sep='\t', header=False, index=False)