| import sys | |
| import os | |
| import os.path as osp | |
| from typing import Any, Union, List, Dict | |
| import torch | |
| import torch.nn as nn | |
| from stark_qa.tools.api import get_api_embeddings, get_sentence_transformer_embeddings, get_contriever_embeddings | |
| from stark_qa.tools.local_encoder import get_llm2vec_embeddings, get_gritlm_embeddings | |
| from stark_qa.evaluator import Evaluator | |
class ModelForSTaRKQA(nn.Module):
    """
    Base class for models evaluated on the STaRK QA benchmark.

    Provides query-embedding caching (in memory and on disk) and evaluation
    plumbing; subclasses implement :meth:`forward` to score candidates.
    """

    # Filename of the on-disk query-embedding cache inside `query_emb_dir`.
    # Single source of truth for both load (__init__) and save (get_query_emb).
    QUERY_EMB_FILENAME = 'query_emb_dict.pt'

    def __init__(self, skb, query_emb_dir='.'):
        """
        Initializes the model with the given knowledge base.

        Args:
            skb: Knowledge base containing candidate information; must expose
                `candidate_ids` and `num_candidates`.
            query_emb_dir (str): Directory from which a previously saved
                query-embedding cache is loaded and to which it is saved.
        """
        super().__init__()
        self.skb = skb
        self.candidate_ids = skb.candidate_ids
        self.num_candidates = skb.num_candidates
        self.query_emb_dir = query_emb_dir

        query_emb_path = osp.join(self.query_emb_dir, self.QUERY_EMB_FILENAME)
        if osp.exists(query_emb_path):
            print(f'Load query embeddings from {query_emb_path}')
            # NOTE(review): torch.load unpickles arbitrary objects — assumes
            # the cache file is trusted (it is written by this class below).
            self.query_emb_dict = torch.load(query_emb_path)
        else:
            self.query_emb_dict = {}
        self.evaluator = Evaluator(self.candidate_ids)

    def forward(self,
                query: Union[str, List[str]],
                candidates: List[int] = None,
                query_id: Union[int, List[int]] = None,
                **kwargs: Any) -> Dict[str, Any]:
        """
        Forward pass to compute predictions for the given query.

        Args:
            query (Union[str, list]): Query string or a list of query strings.
            candidates (Union[list, None]): A list of candidate ids (optional).
            query_id (Union[int, list, None]): Query index (optional).

        Returns:
            pred_dict (dict): A dictionary of predicted scores or answer ids.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError

    def get_query_emb(self,
                      query: Union[str, List[str]],
                      query_id: Union[int, List[int]],
                      emb_model: str = 'text-embedding-ada-002',
                      **encode_kwargs) -> torch.Tensor:
        """
        Retrieves cached query embeddings or computes and caches them.

        Args:
            query (Union[str, List[str]]): Query string(s).
            query_id (Union[int, List[int], None]): Query index/indices used
                as cache keys; when None, embeddings are computed but never
                cached.
            emb_model (str): Embedding model to use.

        Returns:
            query_emb (torch.Tensor): Query embeddings with shape
                (num_queries, hidden_dim).
        """
        # Normalize scalar inputs to one-element batches.
        if isinstance(query_id, int):
            query_id = [query_id]
        if isinstance(query, str):
            query = [query]

        if query_id is None:
            # No cache keys available: compute without touching the cache.
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
        elif set(query_id).issubset(self.query_emb_dict.keys()):
            # Full cache hit: stack the stored per-query (1, dim) embeddings.
            query_emb = torch.concat(
                [self.query_emb_dict[qid] for qid in query_id], dim=0)
        else:
            # Any miss recomputes the whole batch, then refreshes both the
            # in-memory dict and the on-disk cache file.
            query_emb = get_embeddings(query, emb_model, **encode_kwargs)
            for qid, emb in zip(query_id, query_emb):
                self.query_emb_dict[qid] = emb.view(1, -1)
            torch.save(self.query_emb_dict,
                       osp.join(self.query_emb_dir, self.QUERY_EMB_FILENAME))
        return query_emb.view(len(query), -1)

    def evaluate(self,
                 pred_dict: Dict[int, float],
                 answer_ids: Union[torch.LongTensor, List[Any]],
                 metrics: List[str] = None,
                 **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates the predictions using the specified metrics.

        Args:
            pred_dict (Dict[int, float]): Predicted answer ids or scores.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str]): Metrics to evaluate, including 'mrr',
                'hit@k', 'recall@k', 'precision@k', 'map@k', 'ndcg@k'.
                Defaults to ['mrr', 'hit@3', 'recall@20'].

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        # None sentinel instead of a shared mutable default argument.
        if metrics is None:
            metrics = ['mrr', 'hit@3', 'recall@20']
        return self.evaluator(pred_dict, answer_ids, metrics)

    def evaluate_batch(self,
                       pred_ids: List[int],
                       pred: torch.Tensor,
                       answer_ids: Union[torch.LongTensor, List[Any]],
                       metrics: List[str] = None,
                       **kwargs: Any) -> Dict[str, float]:
        """
        Evaluates a batch of predictions using the specified metrics.

        Args:
            pred_ids (List[int]): Candidate ids corresponding to the columns
                of `pred`.
            pred (torch.Tensor): Predicted scores for the candidates.
            answer_ids (torch.LongTensor): Ground truth answer ids.
            metrics (List[str]): Metrics to evaluate; defaults to
                ['mrr', 'hit@3', 'recall@20'].

        Returns:
            Dict[str, float]: A dictionary of evaluation metrics.
        """
        # None sentinel instead of a shared mutable default argument.
        if metrics is None:
            metrics = ['mrr', 'hit@3', 'recall@20']
        return self.evaluator.evaluate_batch(pred_ids, pred, answer_ids, metrics)
def get_embeddings(text, model_name, **encode_kwargs):
    """
    Embed the given text using the encoder selected by the model name.

    Args:
        text (Union[str, List[str]]): The input text(s) to be embedded.
        model_name (str): Model name; a substring match decides which
            encoder backend is used, falling back to the API encoder.

    Returns:
        torch.Tensor: Embeddings of the input text, shape (num_texts, dim).
    """
    texts = [text] if isinstance(text, str) else text

    # Local-model backends keyed by a marker substring of the model name;
    # checked in insertion order before the remaining backends.
    local_backends = {
        'GritLM': get_gritlm_embeddings,
        'LLM2Vec': get_llm2vec_embeddings,
    }
    for marker, encoder in local_backends.items():
        if marker in model_name:
            emb = encoder(texts, model_name, **encode_kwargs)
            break
    else:
        if 'all-mpnet-base-v2' in model_name:
            emb = get_sentence_transformer_embeddings(texts)
        elif 'contriever' in model_name:
            emb = get_contriever_embeddings(texts)
        else:
            emb = get_api_embeddings(texts, model_name, **encode_kwargs)
    return emb.view(len(texts), -1)