| import sys |
| import os |
| sys.path.append(os.path.dirname(os.path.abspath(os.getcwd()))) |
| from stark_qa import load_skb |
|
|
|
|
| from torch.utils.data import Dataset, DataLoader |
| import torch |
| from tqdm import tqdm |
| import numpy as np |
| import torch.nn as nn |
|
|
| from Reranking.utils import move_to_cuda, seed_everything |
| from Reranking.rerankers.path import PathReranker |
| import torch.nn.functional as F |
| import argparse |
| import pickle as pkl |
|
|
|
|
|
|
|
|
class TestDataset(Dataset):
    """
    Dataset over reranking test examples.

    Each element of ``saved_data['data']`` is a dict of the form:
        {
            "query": query,
            "pred_dict": {node_id: score},
            "score_vector_dict": {node_id: [bm25, bm25, bm25, ada]},
            "symb_enc_dict": {node_id: symbolic encoding},
            "text_emb_dict": {node_id: text-embedding key},
            "ans_ids": [node_id, ...],
        }

    ``saved_data['text2emb_dict']`` maps embedding keys to (1, dim) tensors;
    they are stacked once into ``self.text_emb_matrix`` so each item only has
    to carry integer row indices into that matrix.
    """

    def __init__(self, saved_data, args):
        """Stack the shared text embeddings and keep the raw example list.

        Args:
            saved_data: dict with keys ``'text2emb_dict'`` and ``'data'``
                (see class docstring).
            args: namespace providing at least ``dataset_name`` and
                ``vector_dim``.
        """
        print(f"Start processing test dataset...")
        self.text2emb_dict = saved_data['text2emb_dict']
        self.data = saved_data['data']

        # One (num_texts, dim) matrix so collate can gather embeddings with a
        # single fancy-index per batch instead of per-item tensor ops.
        self.text_emb_matrix = torch.concat(list(self.text2emb_dict.values()), dim=0)

        # Map each text key to its row index in text_emb_matrix.
        self.text2idx_dict = {key: idx for idx, key in enumerate(self.text2emb_dict.keys())}

        self.args = args

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return a copy of example ``idx`` with text keys resolved to rows.

        Bug fix: the original implementation rewrote ``self.data[idx]`` in
        place, replacing text-embedding keys with integer indices, so a second
        access to the same item raised ``KeyError``. A fresh dict is built
        instead and the stored example is left untouched.
        """
        item = self.data[idx]
        out = dict(item)

        if self.args.dataset_name == 'amazon':
            # Amazon items are assumed to be pre-truncated upstream — only the
            # text keys need resolving. TODO confirm against the data builder.
            out['text_emb_dict'] = {key: self.text2idx_dict[value]
                                    for key, value in item['text_emb_dict'].items()}
        else:
            # Keep only the top-50 candidates by first-stage retriever score.
            pred_dict = item['pred_dict']
            sorted_ids = sorted(pred_dict.keys(), key=lambda x: pred_dict[x], reverse=True)[:50]

            out['score_vector_dict'] = {key: item['score_vector_dict'][key] for key in sorted_ids}
            out['symb_enc_dict'] = {key: item['symb_enc_dict'][key] for key in sorted_ids}
            out['text_emb_dict'] = {key: self.text2idx_dict[item['text_emb_dict'][key]]
                                    for key in sorted_ids}

        return out

    def collate_batch(self, batch):
        """Collate a list of items into dense batch tensors.

        Returns a dict with:
            query:          list[str], length B
            c_score_vector: (B, K, vector_dim) float tensor
            c_text_emb:     (B, K, dim) float tensor
            c_symb_enc:     (B, K, 3) tensor
            ans_ids:        list of answer-id lists
            pred_ids:       list of candidate-id lists (B x K)
        """
        q_text = [example['query'] for example in batch]

        batch_c = torch.tensor([list(example['score_vector_dict'].keys()) for example in batch])
        c_score_vector = torch.tensor([list(example['score_vector_dict'].values()) for example in batch])
        # Keep only the first vector_dim similarity scores per candidate.
        c_score_vector = c_score_vector[:, :, :self.args.vector_dim]

        c_symb_enc = torch.tensor([list(example['symb_enc_dict'].values()) for example in batch])

        # Gather text embeddings via the shared matrix: (B, K, dim).
        c_text_emb = torch.concat(
            [self.text_emb_matrix[list(example['text_emb_dict'].values())].unsqueeze(0)
             for example in batch],
            dim=0)

        ans_ids = [example['ans_ids'] for example in batch]
        pred_ids = batch_c.tolist()

        return {
            'query': q_text,
            'c_score_vector': c_score_vector,
            'c_text_emb': c_text_emb,
            'c_symb_enc': c_symb_enc,
            'ans_ids': ans_ids,
            'pred_ids': pred_ids,
        }
| |
| |
| |
def batch_evaluator(skb, scores_cand, ans_ids, batch):
    """Compute per-query hit@1, hit@5, recall@20 and MRR for one batch.

    Args:
        skb: knowledge base providing ``candidate_ids`` (full candidate pool).
        scores_cand: (B, K, 1) reranker scores for the K predicted candidates.
        ans_ids: list (length B) of ground-truth answer-id lists.
        batch: feed dict containing ``'pred_ids'`` — the (B, K) candidate ids
            matching ``scores_cand``.

    Returns:
        dict mapping metric name -> list of B per-query values.
    """
    batch_size = scores_cand.shape[0]
    device = scores_cand.device
    pool = skb.candidate_ids
    idx_of = {cand_id: i for i, cand_id in enumerate(pool)}

    # Scatter reranker scores into a dense (B, |pool|) matrix; candidates the
    # reranker never scored keep an implicit score of zero.
    flat_pred_ids = torch.tensor(batch['pred_ids']).flatten().tolist()
    pred_cols = torch.tensor([idx_of[pid] for pid in flat_pred_ids]).reshape(batch_size, -1)
    dense_scores = torch.zeros((batch_size, len(pool)), device=device)
    dense_scores[torch.arange(batch_size).unsqueeze(1), pred_cols] = scores_cand.squeeze(-1)

    # Dense 0/1 answer-membership matrix of the same shape.
    ans_cols = [idx_of[a_id] for answers in ans_ids for a_id in answers]
    ans_rows = torch.repeat_interleave(
        torch.arange(len(ans_ids)),
        torch.tensor([len(answers) for answers in ans_ids]))
    ans_matrix = torch.zeros((batch_size, len(pool)), device=device)
    ans_matrix[ans_rows, torch.tensor(ans_cols, device=device)] = 1

    rows = torch.arange(batch_size)

    # hit@1: is the single best-scored candidate an answer?
    _, best_col = torch.max(dense_scores, dim=1)
    hit1 = ans_matrix[rows, best_col]

    # hit@5: does any of the top-5 candidates hit an answer?
    _, top5_cols = torch.topk(dense_scores, 5, dim=1)
    hit5 = torch.max(ans_matrix[rows.unsqueeze(1), top5_cols], dim=1)[0]

    # recall@20: fraction of answers recovered in the top-20.
    _, top20_cols = torch.topk(dense_scores, 20, dim=1)
    found20 = torch.sum(ans_matrix[rows.unsqueeze(1), top20_cols], dim=1)
    recall20 = found20 / torch.sum(ans_matrix, dim=1)

    # MRR: reciprocal rank of the first answer in the full descending ranking.
    _, ranking = torch.sort(dense_scores, dim=1, descending=True)
    first_hit_rank = torch.argmax(ans_matrix[rows.unsqueeze(1), ranking], dim=1) + 1
    mrr = 1 / first_hit_rank.float()

    return {
        'hit@1': hit1.tolist(),
        'hit@5': hit5.tolist(),
        'recall@20': recall20.tolist(),
        'mrr': mrr.tolist(),
    }
| |
|
|
|
|
| |
@torch.no_grad()
def evaluate(router, test_loader, skb):
    """Run the reranker over ``test_loader`` and return averaged metrics.

    Args:
        router: reranker model exposing ``eval_batch`` (possibly wrapped in
            ``nn.DataParallel``).
        test_loader: DataLoader yielding collated feed dicts (see
            ``TestDataset.collate_batch``).
        skb: knowledge base passed through to ``batch_evaluator``.

    Returns:
        dict mapping each metric name ('hit@1', 'hit@5', 'recall@20', 'mrr')
        to its mean over all test queries.

    Bug fix: the original accumulated ``pred_list``, ``scores_cand_list`` and
    ``ans_ids_list`` on every batch but never used or returned them — removed
    to avoid pointless memory growth over the test set.
    """
    router.eval()

    all_results = {
        "hit@1": [],
        "hit@5": [],
        "recall@20": [],
        "mrr": []
    }

    print(f"Start evaluating...")
    for batch in tqdm(test_loader, desc='Evaluating', position=0):
        batch = move_to_cuda(batch)

        # DataParallel hides custom methods behind .module.
        if isinstance(router, nn.DataParallel):
            scores_cand = router.module.eval_batch(batch)
        else:
            scores_cand = router.eval_batch(batch)

        results = batch_evaluator(skb, scores_cand, batch['ans_ids'], batch)

        # Collect the per-query metric values from this batch.
        for key in results.keys():
            all_results[key].extend(results[key])

    avg_results = {key: np.mean(vals) for key, vals in all_results.items()}
    print(f"Results: {avg_results}")

    return avg_results
|
|
|
|
def parse_args():
    """Build the CLI parser for the reranking run and parse ``sys.argv``.

    Returns:
        argparse.Namespace with dataset_name, device, concat_num,
        checkpoint_path and vector_dim.
    """
    parser = argparse.ArgumentParser(description="Run PathRouter with dynamic combinations of embeddings.")

    # (flag, type, default, help) — one row per CLI option.
    option_table = [
        ("--dataset_name", str, "mag", "Name of the dataset."),
        ("--device", str, "cuda", "Device to run the model (e.g., 'cuda' or 'cpu')."),
        ("--concat_num", int, 0, "Number of concatenation of embeddings."),
        ("--checkpoint_path", str, "./data/checkpoints", "Path saves the checkpoints."),
        ("--vector_dim", int, 4, "Dimension of the similarity vector."),
    ]
    for flag, flag_type, default, help_text in option_table:
        parser.add_argument(flag, type=flag_type, default=default, help=help_text)

    return parser.parse_args()
|
|
|
|
def get_concat_num(combo):
    """Return the number of embedding slots implied by a feature combination.

    Per-feature weights: ``score_vec`` and ``text_emb`` contribute 1 slot
    each; ``symb_enc`` contributes 3. Missing keys count as disabled.
    """
    slot_weights = (("score_vec", 1), ("text_emb", 1), ("symb_enc", 3))
    return sum(weight for name, weight in slot_weights if combo.get(name, False))
|
|
|
|
def run(test_data, skb, dataset_name):
    """Load the best checkpoint for ``dataset_name`` and evaluate it.

    Args:
        test_data: saved test split (dict with 'text2emb_dict' and 'data').
        skb: loaded semi-structured knowledge base.
        dataset_name: dataset identifier used to locate the checkpoint.

    Returns:
        dict of averaged test metrics from ``evaluate``.

    NOTE(review): relies on the module-level ``args`` namespace (set in the
    ``__main__`` block) for vector_dim, device, checkpoint_path, etc.
    """
    test_size = 64
    test_dataset = TestDataset(test_data, args=args)
    test_loader = DataLoader(test_dataset, batch_size=test_size, num_workers=32,
                             collate_fn=test_dataset.collate_batch)

    print(f"Load the model...")
    # Bug fix: build the checkpoint path in a local variable instead of
    # appending to args.checkpoint_path in place — the old in-place mutation
    # corrupted the path on any repeat call (".../best.pth/<name>/best.pth").
    checkpoint_file = args.checkpoint_path + f"/{dataset_name}/best.pth"
    router = PathReranker(socre_vector_input_dim=4, text_emb_input_dim=768,
                          symb_enc_dim=3, args=args)
    checkpoint = torch.load(checkpoint_file)
    router.load_state_dict(checkpoint)
    router = router.to(args.device)

    test_results = evaluate(router, test_loader, skb)
    print(f"Test evaluation")
    print(test_results)

    return test_results
|
|
if __name__ == "__main__":
    # Which candidate representations the reranker consumes; all three are
    # enabled for this evaluation run.
    combo = {
        "text_emb": True,
        "score_vec": True,
        "symb_enc": True
    }
    concat_num = get_concat_num(combo)

    # Merge the CLI options with the combo flags into one namespace.
    # NOTE: `args` is deliberately module-global — `run` reads it directly.
    base_args = parse_args()
    args = argparse.Namespace(**vars(base_args), **combo)
    args.concat_num = concat_num
    dataset_name = args.dataset_name

    # Load the pickled test split and the knowledge base, then evaluate.
    # NOTE(review): pickle is only safe for trusted, locally produced files.
    test_data_path = f"../{dataset_name}_test.pkl"
    with open(test_data_path, 'rb') as f:
        test_data = pkl.load(f)
    skb = load_skb(dataset_name)
    results = run(test_data, skb, dataset_name)
| |