| import math | |
def get_topk_results(predictions, scores, targets, k, all_items=None):
    """Rank each target's ``k`` candidate predictions by score and mark hits.

    ``predictions`` and ``scores`` are flat and index-aligned. They are split
    into consecutive chunks of ``k``, where chunk ``b`` holds the candidates
    for ``targets[b]``. Before comparison, each prediction is reduced to the
    text after its last ``"Response:"`` marker. Surrounding whitespace is
    stripped and inner spaces are removed.

    Args:
        predictions: Flat sequence of decoded candidate strings.
        scores: Scores aligned with ``predictions``; higher ranks first.
            The caller's object is not modified.
        targets: One ground-truth item string per group.
        k: Number of candidates per target.
        all_items: Optional collection of valid item strings. Any candidate
            not in it is ranked below every finitely-scored candidate.

    Returns:
        One list per target of 0/1 flags in descending-score order. A 1
        marks a candidate equal to the target.
    """
    cleaned = [
        p.split("Response:")[-1].strip().replace(" ", "") for p in predictions
    ]
    # Work on a copy so the caller's scores are never overwritten.
    adjusted = list(scores)
    if all_items is not None:
        # Lists/tuples would make each membership test O(len(all_items)).
        valid = set(all_items) if isinstance(all_items, (list, tuple)) else all_items
        for i, seq in enumerate(cleaned):
            if seq not in valid:
                # -inf (not a fixed constant such as -1000) guarantees invalid
                # candidates sink below valid ones whatever the score scale.
                adjusted[i] = float("-inf")

    results = []
    for b, target in enumerate(targets):
        group = zip(cleaned[b * k:(b + 1) * k], adjusted[b * k:(b + 1) * k])
        # sorted() is stable, so equal scores keep their generation order.
        ranked = sorted(group, key=lambda pair: pair[1], reverse=True)
        results.append([1 if seq == target else 0 for seq, _ in ranked])
    return results
def get_metrics_results(topk_results, metrics):
    """Compute each requested ranking metric over ``topk_results``.

    Args:
        topk_results: Per-target ranked lists of 0/1 hit flags.
        metrics: Metric names of the form ``"hit@K"`` or ``"ndcg@K"``. The
            prefix is matched case-insensitively.

    Returns:
        Dict mapping each metric name, exactly as given, to the value
        returned by ``hit_k`` / ``ndcg_k`` for its ``K``.

    Raises:
        NotImplementedError: If a metric name has an unsupported prefix.
    """
    res = {}
    for m in metrics:
        name = m.lower()
        if name.startswith("hit"):
            res[m] = hit_k(topk_results, int(m.split("@")[1]))
        elif name.startswith("ndcg"):
            res[m] = ndcg_k(topk_results, int(m.split("@")[1]))
        else:
            # Name the offending metric so config typos are easy to spot.
            raise NotImplementedError(
                "Unsupported metric: {!r} (expected 'hit@K' or 'ndcg@K')".format(m)
            )
    return res
def ndcg_k(topk_results, k):
    """Return NDCG@k summed over all rows (a total, not a mean).

    Each row is a ranked list of relevance flags for a single target. With
    one relevant item the ideal DCG is 1. A row therefore scores
    ``rel / log2(i + 2)`` for its first nonzero entry at 0-based position
    ``i < k``, and 0 if there is none. Later duplicate hits are ignored, so a
    row never contributes more than one hit's worth. This matches ``hit_k``,
    which also counts each row at most once.

    Args:
        topk_results: Iterable of ranked relevance rows.
        k: Cutoff; only the first ``k`` entries of each row are considered.

    Returns:
        The summed per-row NDCG@k as a float.
    """
    total = 0.0
    for row in topk_results:
        for i, rel in enumerate(row[:k]):
            if rel:
                # log2 is exact at powers of two, unlike log(x, 2).
                total += rel / math.log2(i + 2)
                break
    return total
def hit_k(topk_results, k):
    """Count the rows whose first ``k`` entries sum to more than zero.

    Args:
        topk_results: Iterable of ranked hit-flag rows.
        k: Cutoff; only the first ``k`` entries of each row are inspected.

    Returns:
        The number of such rows as a float (a total, not a rate).
    """
    hits = sum(1 for row in topk_results if sum(row[:k]) > 0)
    return float(hits)