Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| import csv | |
| import json | |
| import uuid | |
| import random | |
| import pickle | |
| from langchain.vectorstores import FAISS | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from googleapiclient.discovery import build | |
| from google.oauth2 import service_account | |
# Identifier written to the results sheet with every vote. It is generated
# once at import time, so every visitor served by this process shares it.
USER_ID = uuid.uuid4()

# Service-account credentials for the Google Sheets API are supplied through
# the environment as a JSON document.
SERVICE_ACCOUNT_JSON = os.environ.get('GOOGLE_SHEET_CREDENTIALS')
if not SERVICE_ACCOUNT_JSON:
    # Fail with an actionable message instead of the opaque TypeError that
    # json.loads(None) would raise a line later.
    raise RuntimeError(
        "The GOOGLE_SHEET_CREDENTIALS environment variable must contain the "
        "service-account JSON used to write votes to Google Sheets."
    )
creds = service_account.Credentials.from_service_account_info(
    json.loads(SERVICE_ACCOUNT_JSON))
SPREADSHEET_ID = '1o0iKPxWYKYKEPjqB2YwrTgrLzvGyb9ULj9tnw_cfJb0'
service = build('sheets', 'v4', credentials=creds)

# Module-level state shared by the callbacks below: the two indexes currently
# being compared and the abstract they are being compared on.
LEFT_MODEL = None
RIGHT_MODEL = None
PROMPT = None

# NOTE(security): pickle can execute arbitrary code while loading, so
# article_list.pkl must only ever be a trusted file shipped with the app.
with open("article_list.pkl", "rb") as articles:
    article_list = tuple(pickle.load(articles))

# INDEXES[i] is the on-disk FAISS index that is searched with MODELS[i].
INDEXES = ["miread_large", "miread_contrastive", "scibert_contrastive"]
MODELS = [
    "biodatlab/MIReAD-Neuro-Large",
    "biodatlab/MIReAD-Neuro-Contrastive",
    "biodatlab/SciBERT-Neuro-Contrastive",
]
model_kwargs = {'device': 'cpu'}
encode_kwargs = {'normalize_embeddings': False}

# One embedder per model, then one vector store per (index, embedder) pair.
faiss_embedders = [
    HuggingFaceEmbeddings(
        model_name=name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs,
    )
    for name in MODELS
]
vecdbs = [
    FAISS.load_local(index_name, faiss_embedder)
    for index_name, faiss_embedder in zip(INDEXES, faiss_embedders)
]
def get_matchup():
    """
    Draw two distinct index names at random and remember them as the current pair.

    The pair is stored in the module-level LEFT_MODEL and RIGHT_MODEL so that
    a later vote can be attributed to the right index.

    Returns - tuple (left, right) of index names taken from INDEXES
    """
    global LEFT_MODEL, RIGHT_MODEL
    pair = random.sample(INDEXES, 2)
    LEFT_MODEL, RIGHT_MODEL = pair
    return LEFT_MODEL, RIGHT_MODEL
def get_comp(prompt):
    """
    Pick a fresh pair of indexes and run both on the current abstract.

    The `prompt` argument is accepted because the click handler passes it,
    but it is not read: the query is the module-level PROMPT.

    Returns - tuple (left_output, right_output) of inference() results
    """
    left_index, right_index = get_matchup()
    left_table = inference(PROMPT, left_index)
    right_table = inference(PROMPT, right_index)
    return left_table, right_table
def get_article():
    """
    Return one entry of article_list picked at random.
    """
    chosen = random.choice(article_list)
    return chosen
def send_result(l_output, r_output, prompt, pick):
    """
    Record a vote in the results spreadsheet and move on to a new abstract.

    Args:
        l_output: Contents of the left table as passed by the click handler
            (not read).
        r_output: Contents of the right table as passed by the click handler
            (not read).
        prompt: Value of the prompt state as passed by the click handler
            (not read; the module-level PROMPT is what gets logged).
        pick: 'left' to vote for LEFT_MODEL; any other value votes for
            RIGHT_MODEL.

    Returns - tuple (new_prompt, new_prompt): the first item is meant for the
    abstract textbox, the second for the prompt state.
    """
    global PROMPT

    winner = LEFT_MODEL if pick == 'left' else RIGHT_MODEL
    # The sheet is written with RAW input, so every cell is sent as a string.
    row = [str(cell)
           for cell in (USER_ID, PROMPT, LEFT_MODEL, RIGHT_MODEL, winner)]
    result = service.spreadsheets().values().append(
        spreadsheetId=SPREADSHEET_ID,
        range='A1:E1',
        valueInputOption='RAW',
        body={'values': [row]},
    ).execute()
    print(f"Appended {result['updates']['updatedCells']} cells.")

    PROMPT = get_article()
    # A State output is set by returning the new value itself; an update
    # object is only meaningful for visible components.
    return PROMPT, PROMPT
def get_matches(query, db_name="miread_contrastive"):
    """
    Run a scored similarity search (k=30) for `query` on the index called `db_name`.

    `db_name` must be one of INDEXES; its position there selects the vector store.
    """
    position = INDEXES.index(db_name)
    store = vecdbs[position]
    return store.similarity_search_with_score(query, k=30)
def inference(query, model="miread_contrastive"):
    """
    This function processes information retrieved by the get_matches() function

    Each match's raw score x is rescaled to 1 - (x - min_score) / max_score and
    rounded to 3 decimals. At most two rows are kept per first author, rows
    are numbered from 1 in match order, and the first ten rows are shown.

    Args:
        query: Text to search for.
        model: Name of the index to search; forwarded to get_matches().

    Returns - Gradio update command for a table with the columns
    [No., Score, Name, Title, Link, Date]; the table is empty when the search
    returns no matches.
    """
    matches = get_matches(query, model)
    if not matches:
        # min()/max() below would raise ValueError on an empty result set.
        return gr.Dataframe.update(value=[], visible=True)

    raw_scores = [round(match[1].item(), 3) for match in matches]
    min_score = min(raw_scores)
    # A maximum of exactly 0 would make the rescaling divide by zero; fall
    # back to a denominator of 1 in that case.
    denominator = max(raw_scores) or 1.0

    auth_counts = {}
    n_table = []
    for match, raw in zip(matches, raw_scores):
        metadata = match[0].metadata
        score = round(1 - (raw - min_score) / denominator, 3)
        title = metadata['title']
        # Only the first listed author is shown and counted.
        author = metadata['authors'][0].title()
        date = metadata.get('date', 'None')
        link = metadata.get('link', 'None')

        seen = auth_counts.get(author, 0)
        if seen < 2:
            n_table.append([len(n_table) + 1, score, author, title, link, date])
            auth_counts[author] = seen + 1

    return gr.Dataframe.update(value=n_table[:10], visible=True)
# Arena UI: one abstract, two result tables from two randomly chosen indexes,
# and one vote button per table.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# NBDT Recommendation Engine Arena")
    # Adjacent literals keep a space between sentences; the button name in the
    # text matches the label of the button defined below.
    gr.Markdown(
        "NBDT Recommendation Engine Arena is a tool designed to compare neuroscience abstract recommendations by our models. "
        "We will use this data to compare the performance of models and their preference by various neuroscientists. "
        "Click on the 'Get comparison' button to run two random models on the displayed prompt. Then use the correct 'Model X is Better' button to give your vote. "
        "All models were trained on data provided to us by the NBDT Journal."
    )

    # Initial abstract and matchup. get_matchup() also sets LEFT_MODEL and
    # RIGHT_MODEL, and PROMPT is the value the callbacks actually query with.
    article = get_article()
    models = gr.State(value=get_matchup())
    prompt = gr.State(value=article)
    PROMPT = article

    abst = gr.Textbox(value=article, label="Abstract", lines=10)
    action_btn = gr.Button(value="Get comparison")

    with gr.Group():
        with gr.Row().style(equal_height=True):
            with gr.Column(scale=1):
                l_output = gr.Dataframe(
                    headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                    datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                    col_count=(6, "fixed"),
                    wrap=True,
                    visible=True,
                    label='Model A',
                    show_label=True,
                    overflow_row_behaviour='paginate',
                    scale=1
                )
            with gr.Column(scale=1):
                r_output = gr.Dataframe(
                    headers=['No.', 'Score', 'Name', 'Title', 'Link', 'Date'],
                    datatype=['number', 'number', 'str', 'str', 'str', 'str'],
                    col_count=(6, "fixed"),
                    wrap=True,
                    visible=True,
                    label='Model B',
                    show_label=True,
                    overflow_row_behaviour='paginate',
                    scale=1
                )
        with gr.Row().style(equal_height=True):
            l_btn = gr.Button(value="Model A is better", scale=1)
            r_btn = gr.Button(value="Model B is better", scale=1)

    action_btn.click(fn=get_comp,
                     inputs=[prompt,],
                     outputs=[l_output, r_output],
                     api_name="arena")
    # send_result returns two values (textbox text and prompt state), so both
    # vote buttons must declare both outputs.
    l_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'left'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedleft")
    r_btn.click(fn=lambda x, y, z: send_result(x, y, z, 'right'),
                inputs=[l_output, r_output, prompt],
                outputs=[abst, prompt],
                api_name="feedright")

demo.launch(debug=True)