| import os.path as osp | |
| import torch | |
| from typing import Any, Union, List, Dict | |
| from models.model import ModelForSTaRKQA | |
| from tqdm import tqdm | |
| from stark_qa.evaluator import Evaluator | |
| import sys | |
| sys.path.append("stark/") | |
class VSS(ModelForSTaRKQA):
    """Vector Similarity Search (VSS) retriever.

    Scores every candidate in the knowledge base by the inner product between
    a query embedding and a matrix of precomputed candidate embeddings that is
    loaded once at construction time.
    """

    def __init__(self,
                 skb,
                 query_emb_dir: str,
                 candidates_emb_dir: str,
                 emb_model: str = 'text-embedding-ada-002',
                 device: str = 'cuda'):
        """
        Vector Similarity Search

        Args:
            skb (SemiStruct): Knowledge base.
            query_emb_dir (str): Directory to query embeddings.
            candidates_emb_dir (str): Directory to candidate embeddings. It must
                contain a 'candidate_emb_dict.pt' file mapping each candidate id
                to its embedding tensor.
            emb_model (str): Embedding model name.
            device (str): Device that holds the candidate embedding matrix and
                on which similarities are computed.

        Raises:
            ValueError: If the number of stored embeddings differs from the
                number of candidate ids.
        """
        super().__init__(skb, query_emb_dir=query_emb_dir)
        self.emb_model = emb_model
        self.candidates_emb_dir = candidates_emb_dir
        self.device = device
        self.evaluator = Evaluator(self.candidate_ids, device)

        candidate_emb_path = osp.join(candidates_emb_dir, 'candidate_emb_dict.pt')
        # Deserialize onto the CPU: this works even when the file was saved from
        # a CUDA device that is not available here, and avoids keeping a second
        # copy of every embedding on the target device next to the stacked
        # matrix built below.
        candidate_emb_dict = torch.load(candidate_emb_path, map_location='cpu')
        print(f'Loaded candidate_emb_dict from {candidate_emb_path}!')

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(candidate_emb_dict) != len(self.candidate_ids):
            raise ValueError(
                f'{candidate_emb_path} holds {len(candidate_emb_dict)} embeddings '
                f'but the knowledge base has {len(self.candidate_ids)} candidates.'
            )

        # Row i of the matrix is the embedding of self.candidate_ids[i], so the
        # column order of every similarity result follows self.candidate_ids.
        candidate_embs = [candidate_emb_dict[idx].view(1, -1) for idx in self.candidate_ids]
        self.candidate_embs = torch.cat(candidate_embs, dim=0).to(device)

    def forward(self,
                query: Union[str, List[str]],
                query_id: Union[int, List[int]],
                **kwargs: Any) -> Union[Dict[int, torch.Tensor], tuple]:
        """
        Forward pass to compute similarity scores for the given query.

        Args:
            query (Union[str, List[str]]): Query string, or a list of query strings.
            query_id (Union[int, List[int]]): Query index, or a list of query indices.
            **kwargs: Forwarded unchanged to `get_query_emb`.

        Returns:
            If `query` is a single string:
                pred_dict (dict): A dictionary of candidate ids and their
                corresponding similarity scores (each score is a 0-dim CPU tensor).
            Otherwise:
                A tuple `(candidate_ids, scores)` where `candidate_ids` is a
                `torch.LongTensor` of all candidate ids and `scores` is the
                transposed similarity matrix on the CPU (presumably one row per
                candidate and one column per query; this depends on the shape
                returned by `get_query_emb` -- TODO confirm).
        """
        query_emb = self.get_query_emb(query, query_id, emb_model=self.emb_model, **kwargs)
        # Inner product against every candidate; results are moved back to the
        # CPU so callers never have to care which device was used for scoring.
        similarity = torch.matmul(query_emb.to(self.device), self.candidate_embs.T).cpu()

        if isinstance(query, str):
            return dict(zip(self.candidate_ids, similarity.view(-1)))
        return torch.LongTensor(self.candidate_ids), similarity.t()