Spaces:
Runtime error
Runtime error
| import time | |
| import json | |
| import numpy as np | |
| import streamlit as st | |
| from pathlib import Path | |
| from collections import defaultdict | |
| import sys | |
# Put the app's own directory on the import path so modules that live next to
# this script can be imported by the code further down.
path_root = Path("./")
sys.path.append(str(path_root))

st.set_page_config(
    page_title="PSC Runtime",
    page_icon='🌸',
    layout="centered",
)

# Both pickers start empty (index=None), so `name` and `model_name` are None
# until the user has made a choice.
name = st.selectbox(
    label="Choose a dataset",
    options=["dl19", "dl20"],
    index=None,
    placeholder="Choose a dataset...",
)
model_name = st.selectbox(
    label="Choose a model",
    options=["gpt-3.5", "gpt-4"],
    index=None,
    placeholder="Choose a model...",
)
# Nothing below can run until both pickers have a value. st.stop() ends this
# script run right here, so the remainder of the file does not have to be
# nested inside an `if name and model_name:` body.
if not (name and model_name):
    st.stop()

import torch

# Result files are looked up relative to the working directory and are named
# "<dataset>-<model>.pt".
fn = f"{name}-{model_name}.pt"
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load .pt files that ship with this app, never user-supplied ones.
loaded = torch.load(fn)
# Element 2 of the stored object is iterated below as a sequence of groups;
# each group is a sequence of dicts carrying a "query" and a "hits" entry.
outputs = loaded[2]

# Map each query string to the list of "hits" rankings stored for it.
query2outputs = {}
for output in outputs:
    all_queries = {x["query"] for x in output}
    # Every group must belong to exactly one query; anything else means the
    # file does not have the layout this app expects.
    if len(all_queries) != 1:
        raise ValueError(
            f"expected exactly one query per output group in {fn!r}, "
            f"found {len(all_queries)}"
        )
    (query,) = all_queries
    query2outputs[query] = [x["hits"] for x in output]

# Unlike the pickers above, this one keeps Streamlit's default selection, so
# the first query (in sorted order) is chosen as soon as the data is loaded.
search_query = st.selectbox(
    "Choose a query from the list",
    sorted(query2outputs),
)
def preferences_from_hits(list_of_hits):
    """Convert ranked hit lists into rows of dense integer document indices.

    Each distinct ``docid`` gets the next free integer index the first time
    it is seen, scanning the rankings in order. Every ranking is then
    rewritten as the list of indices of its documents.

    Args:
        list_of_hits: Iterable of rankings. Each ranking is a sequence of
            dicts that have a ``"docid"`` key.

    Returns:
        A ``(preferences, id2doc)`` tuple. ``preferences`` is
        ``np.array`` applied to the per-ranking index lists (one row per
        ranking). ``id2doc`` maps each integer index to the first hit dict
        that was seen with the corresponding ``docid``.
    """
    docid2idx = {}
    id2doc = {}
    preferences = []
    for ranking in list_of_hits:
        row = []
        for doc in ranking:
            docid = doc["docid"]
            idx = docid2idx.get(docid)
            if idx is None:
                # First occurrence: allocate the next index and remember the
                # hit so the index can be mapped back to a document later.
                idx = len(docid2idx)
                docid2idx[docid] = idx
                id2doc[idx] = doc
            row.append(idx)
        preferences.append(row)
    return np.array(preferences), id2doc
def load_qrels(name):
    """Load relevance judgements for one of the supported datasets.

    Args:
        name: Dataset key, either ``"dl19"`` or ``"dl20"``.

    Returns:
        A ``defaultdict(dict)`` mapping ``query_id`` to a
        ``{doc_id: relevance}`` dict, filled from the dataset's
        ``qrels_iter()``. Looking up an unknown query id yields (and
        stores) an empty dict.

    Raises:
        ValueError: If ``name`` is not a supported dataset key. The
            exception argument is ``name`` itself.
    """
    ds_names = {
        "dl19": "msmarco-passage/trec-dl-2019/judged",
        "dl20": "msmarco-passage/trec-dl-2020/judged",
    }
    # Validate before touching ir_datasets so a bad key fails fast with the
    # intended ValueError instead of first paying for the import.
    ds_name = ds_names.get(name) if isinstance(name, str) else None
    if ds_name is None:
        raise ValueError(name)

    import ir_datasets

    dataset = ir_datasets.load(ds_name)
    qrels = defaultdict(dict)
    for qrel in dataset.qrels_iter():
        qrels[qrel.query_id][qrel.doc_id] = qrel.relevance
    return qrels
def aggregate(list_of_hits):
    """Fuse several rankings of the same query into a single ranking.

    The rankings are converted to integer index rows with
    ``preferences_from_hits``, handed to permsc's
    ``KemenyOptimalAggregator``, and the aggregator's output is mapped back
    to the original hit dicts.

    Args:
        list_of_hits: Rankings to aggregate, in the structure accepted by
            ``preferences_from_hits``.

    Returns:
        A list of hit dicts, one per element of the aggregator's result, in
        the order the aggregator returned them.
    """
    # Only the aggregator actually used is imported; the import stays local
    # to the function, as it was originally.
    from permsc import KemenyOptimalAggregator

    preferences, id2doc = preferences_from_hits(list_of_hits)
    # NOTE(review): presumably a sequence of document indices in aggregated
    # rank order -- confirm against the permsc documentation.
    y_optimal = KemenyOptimalAggregator().aggregate(preferences)
    return [id2doc[idx] for idx in y_optimal]
def write_ranking(search_results, text):
    """Render one ranked list of hits into the current Streamlit container.

    A right-aligned grey caption is written first. Then, for every hit, a
    "Rank | Document ID" row and the passage text are written, coloured by
    the hit's relevance label from the module-level ``qrels`` mapping
    (unjudged documents count as label 0), followed by a horizontal rule.

    Args:
        search_results: Sequence of hit dicts with ``"qid"``, ``"docid"``
            and ``"content"`` keys. All hits must share one ``qid``. An
            empty sequence renders only the caption.
        text: Caption shown above the ranking. Callers pass a column label
            such as "w/o PSC", so it is shown verbatim (the former " ms"
            suffix produced captions like "w/o PSC ms").

    Raises:
        ValueError: If the hits carry more than one distinct ``qid``.
    """
    import html

    st.write(
        f'<p align="right" style="color:grey;"> {text}</p>',
        unsafe_allow_html=True,
    )
    if not search_results:
        return

    qids = {result["qid"] for result in search_results}
    if len(qids) != 1:
        raise ValueError(
            f"expected hits for exactly one query, got {len(qids)} distinct qids"
        )
    (qid,) = qids
    # qrels is keyed by strings, so query and document ids are normalised
    # with str() before each lookup.
    judged = qrels[str(qid)]

    # Relevance label -> text colour; any other label is drawn in grey.
    label_colors = {
        3: "rgb(231, 95, 43)",
        2: "rgb(238, 147, 49)",
        1: "rgb(241, 177, 118)",
    }

    for rank, result in enumerate(search_results, start=1):
        result_id = result["docid"]
        label = judged.get(str(result_id), 0)
        style = f'style="color:{label_colors.get(label, "grey")};"'

        # These strings are rendered with unsafe_allow_html=True, so the
        # dynamic parts are escaped to keep markup-like passage text from
        # being interpreted as HTML.
        safe_id = html.escape(str(result_id))
        safe_contents = html.escape(str(result["content"]))
        header = (
            f'<div class="row" {style}> <b>Rank</b>: {rank} | '
            f'<b>Document ID</b>: {safe_id}</div>'
        )
        body = f'<div class="row" {style}>{safe_contents}</div>'
        try:
            st.write(header, unsafe_allow_html=True)
            st.write(body, unsafe_allow_html=True)
        except Exception as exc:
            # Rendering stays best-effort per document, but a failure is now
            # shown to the user instead of being silently dropped.
            st.warning(f"Could not render document {result_id}: {exc}")
        # Separator after each document.
        st.write('---')
# Render the stored ranking and the aggregated ranking side by side.
# The guard short-circuits, so `search_query` is only looked at once a dataset
# and a model have been chosen, and no aggregation or qrels loading happens
# without a selected query.
if name and model_name and search_query:
    aggregated_ranking = aggregate(query2outputs[search_query])
    # write_ranking reads this module-level mapping to colour each document.
    qrels = load_qrels(name)
    col1, col2 = st.columns([5, 5])
    with col1:
        # The first stored ranking for the query is shown as the baseline.
        # `text` is passed by keyword: a positional argument after
        # `search_results=` is a SyntaxError.
        write_ranking(search_results=query2outputs[search_query][0], text="w/o PSC")
    with col2:
        write_ranking(search_results=aggregated_ranking, text="w/ PSC")