|
|
| """ |
| input: query, query_id, candidate_ids |
| output: pred_dict: {node_id: similarity} |
| """ |
|
|
| import sys |
| from pathlib import Path |
| |
| current_file = Path(__file__).resolve() |
| project_root = current_file.parents[2] |
| |
| sys.path.append(str(project_root)) |
| import torch |
| from Reasoning.text_retrievers.stark_model import ModelForSTaRKQA |
|
|
class Ada(ModelForSTaRKQA):
    """Retriever that ranks candidates with precomputed ada-002 embeddings.

    Query and candidate embeddings are read from ``.pt`` files under
    ``<project_root>/Reasoning/data/emb/<dataset_name>/text-embedding-ada-002/``
    and compared with a plain dot product. Nothing is embedded at call time:
    the query text passed to ``score`` / ``retrieve`` is never read, only
    ``q_id`` is used to look up the stored query vector.
    """

    def __init__(self, skb, dataset_name, device):
        """Load both embedding dictionaries and stack the candidate matrix.

        Args:
            skb: Knowledge-base object; forwarded unchanged to the parent
                constructor.
            dataset_name: Name of the sub-directory of the embedding folder
                to read from.
            device: Torch device on which similarities are computed.

        Raises:
            AssertionError: If the number of stored candidate embeddings
                differs from ``len(self.candidate_ids)``.
        """
        super().__init__(skb)
        self.device = device

        # NOTE: torch.load unpickles the files; only point this at trusted data.
        self.emb_dir = f"{project_root}/Reasoning/data/emb/{dataset_name}/"
        self.query_emb_path = self.emb_dir + "text-embedding-ada-002/query/query_emb_dict.pt"
        self.query_emb_dict = torch.load(self.query_emb_path)

        self.candidate_emb_path = self.emb_dir + "text-embedding-ada-002/doc/candidate_emb_dict.pt"
        self.candidate_emb_dict = torch.load(self.candidate_emb_path)

        # `self.candidate_ids` is not assigned in this class; presumably the
        # parent constructor provides it -- confirm against ModelForSTaRKQA.
        assert len(self.candidate_emb_dict) == len(self.candidate_ids), (
            f"candidate embedding count ({len(self.candidate_emb_dict)}) does not "
            f"match candidate id count ({len(self.candidate_ids)})"
        )

        # One row per candidate, in the same order as `self.candidate_ids`, so
        # that `retrieve` can zip ids with similarity columns.
        rows = [self.candidate_emb_dict[c_id].view(1, -1) for c_id in self.candidate_ids]
        self.candidate_embs = torch.cat(rows, dim=0).to(device)

    def score(self, query, q_id, candidate_ids):
        """Score an explicit list of candidates against a stored query.

        Args:
            query: Query text. Unused; kept for interface compatibility.
            q_id: Key into the query-embedding dictionary.
            candidate_ids: Sequence of keys into the candidate-embedding
                dictionary.

        Returns:
            dict mapping each candidate id to its dot-product similarity as a
            Python float. Empty when ``candidate_ids`` is empty.

        Raises:
            KeyError: If ``q_id`` or any candidate id has no stored embedding.
        """
        query_emb = self.query_emb_dict[q_id].view(1, -1).to(self.device)

        # torch.cat rejects an empty list, so answer the empty case directly.
        if len(candidate_ids) == 0:
            return {}

        rows = [self.candidate_emb_dict[c_id].view(1, -1) for c_id in candidate_ids]
        candidate_embs = torch.cat(rows, dim=0).to(self.device)

        # (1, d) @ (d, n) -> (1, n); drop the leading axis to get one score per id.
        similarity = torch.matmul(query_emb, candidate_embs.T).squeeze(dim=0).cpu()
        return dict(zip(candidate_ids, similarity.tolist()))

    def retrieve(self, query, q_id, topk, node_type=None):
        """Return the ``topk`` best-scoring candidates for a stored query.

        Args:
            query: Query text. Unused; kept for interface compatibility. It no
                longer has to be a ``str`` (previously a non-str value left
                the result dict unassigned and the call failed).
            q_id: Key into the query-embedding dictionary.
            topk: Upper bound on the number of results; applied as a slice
                bound, so values larger than the candidate count are fine.
            node_type: Accepted but currently ignored.

        Returns:
            dict mapping candidate id to dot-product similarity (Python
            float), inserted in descending score order.

        Raises:
            KeyError: If ``q_id`` has no stored embedding.
        """
        query_emb = self.query_emb_dict[q_id].view(1, -1).to(self.device)

        similarity = torch.matmul(query_emb, self.candidate_embs.T).cpu()
        # Convert once to Python floats: sorting floats is far cheaper than
        # sorting 0-dim tensors and yields the same ordering.
        pred_dict = dict(zip(self.candidate_ids, similarity.view(-1).tolist()))

        # sorted() is stable, so ties keep their `self.candidate_ids` order.
        top_ids = sorted(pred_dict, key=pred_dict.get, reverse=True)[:topk]
        return {c_id: pred_dict[c_id] for c_id in top_ids}
|
|