Spaces:
Runtime error
Runtime error
| import base64 | |
| import re | |
| import json | |
| import pandas as pd | |
| import gradio as gr | |
| import pyterrier as pt | |
| pt.init() | |
| import pyt_splade | |
| from pyterrier_gradio import Demo, MarkdownFile, interface, df2code, code2md, EX_Q, EX_D | |
# One SPLADE factory per aggregation option; predict_query / predict_doc pick
# between them based on their ``agg`` argument.
factory_max, factory_sum = (
    pyt_splade.SpladeFactory(agg=agg_name) for agg_name in ('max', 'sum')
)

# Notebook name and install cell that predict_query / predict_doc hand to code2md.
COLAB_NAME = 'pyterrier_splade.ipynb'
COLAB_INSTALL = '''
!pip install -q git+https://github.com/naver/splade
!pip install -q git+https://github.com/seanmacavaney/pyt_splade@misc
'''.strip()
# One weighted query term: group 1 is the weight, group 2 the token, which is
# either plain text or a ``#base64(...)`` wrapper.
_QUERY_TERM_RE = re.compile(r'#combine:0=([0-9.]+)\((#base64\([^)]+\)|[^)]+)\)')
_BASE64_TOKEN_RE = re.compile(r'#base64\(([^)]+)\)')


def _query_token_weights(query):
    """Parse the ``#combine:0=<weight>(<token>)`` terms of ``query`` into {token: weight}."""
    weights = {}
    for term in _QUERY_TERM_RE.finditer(query):
        token = term.group(2)
        wrapped = _BASE64_TOKEN_RE.fullmatch(token)
        if wrapped:
            # Tokens wrapped as #base64(...) are displayed in decoded form.
            token = base64.b64decode(wrapped.group(1)).decode()
        weights[token] = float(term.group(1))
    return weights


def _token_spans(tokens, tok_scores, max_score):
    """Render ``tokens`` as ``<kbd>`` tags with background opacity weight / max_score."""
    from html import escape  # local import: the block stays self-contained

    spans = []
    for tok in tokens:
        # A zero or missing maximum would divide by zero; render transparent instead.
        alpha = tok_scores.get(tok, 0) / max_score if max_score else 0
        spans.append(
            f'<kbd style="background-color: rgba(66, 135, 245, {alpha});">{escape(tok)}</kbd>'
        )
    return '<kbd> </kbd>'.join(spans)


def generate_vis(df, mode='Document'):
    """Render SPLADE token weights as HTML, one bordered panel per row of ``df``.

    Each panel shows the tokens of the original text shaded by weight, followed
    by the "expansion" tokens: weighted tokens that do not occur in the original
    text, ordered by descending weight and then alphabetically.

    Args:
        df: Frame to visualise. In ``'Query'`` mode each row is read through
            ``qid``, ``query_0`` and ``query``; in any other mode through
            ``docno``, ``text`` and ``toks`` (a token -> weight mapping).
        mode: ``'Query'`` or ``'Document'``; also printed as the panel label.

    Returns:
        The panels joined by newlines, or ``''`` when ``df`` is empty.
    """
    from html import escape  # local import: the block stays self-contained

    if len(df) == 0:
        return ''

    is_query = mode == 'Query'
    if not is_query:
        # Non-query rows share one scale: the largest weight over all rows.
        # default=0 covers rows whose ``toks`` mapping is empty.
        max_score = max((w for toks in df['toks'] for w in toks.values()), default=0)

    panels = []
    for row in df.itertuples(index=False):
        if is_query:
            tok_scores = _query_token_weights(row.query)
            # Queries are scaled per row; default=0 covers queries with no weighted terms.
            max_score = max(tok_scores.values(), default=0)
            orig_tokens = factory_max.tokenizer.tokenize(row.query_0)
            row_id = row.qid
        else:
            tok_scores = row.toks
            orig_tokens = factory_max.tokenizer.tokenize(row.text)
            row_id = row.docno

        seen = set(orig_tokens)
        ranked = sorted(tok_scores.items(), key=lambda item: (-item[1], item[0]))
        exp_tokens = [tok for tok, _ in ranked if tok not in seen]

        original_html = _token_spans(orig_tokens, tok_scores, max_score)
        expansion_html = _token_spans(exp_tokens, tok_scores, max_score)
        # The markup is kept flush-left as in the original template; ids and
        # tokens are escaped because they originate from user-editable input.
        panels.append(f'''
<div style="font-size: 1.2em;">{mode}: <strong>{escape(str(row_id))}</strong></div>
<div style="margin: 4px 0 16px; padding: 4px; border: 1px solid black;">
<div>
{original_html}
</div>
<div><strong>Expansion Tokens:</strong> {expansion_html}</div>
</div>
''')
    return '\n'.join(panels)
def predict_query(input, agg):
    """Run the SPLADE query pipeline on ``input`` and prepare the demo outputs.

    Args:
        input: Frame passed unchanged to the query pipeline and to ``df2code``.
        agg: ``'max'`` or ``'sum'``; selects ``factory_max`` or ``factory_sum``.

    Returns:
        A 3-tuple of the pipeline result, the ``code2md`` rendering of an
        equivalent stand-alone snippet, and the HTML from ``generate_vis``.

    Raises:
        KeyError: if ``agg`` is neither ``'max'`` nor ``'sum'``.
    """
    # Stand-alone snippet reproducing this call; shown to the user via code2md.
    snippet = f'''import pandas as pd
import pyterrier as pt ; pt.init()
import pyt_splade
splade = pyt_splade.SpladeFactory(agg={repr(agg)})
query_pipeline = splade.query()
query_pipeline({df2code(input)})
'''
    factories = {'max': factory_max, 'sum': factory_sum}
    encoder = factories[agg].query()
    encoded = encoder(input)
    html_vis = generate_vis(encoded, mode='Query')
    return (encoded, code2md(snippet, COLAB_INSTALL, COLAB_NAME), html_vis)
def predict_doc(input, agg):
    """Run the SPLADE indexing pipeline on ``input`` and prepare the demo outputs.

    Args:
        input: Frame passed unchanged to the indexing pipeline and to ``df2code``.
        agg: ``'max'`` or ``'sum'``; selects ``factory_max`` or ``factory_sum``.

    Returns:
        A 3-tuple of the pipeline result (with its ``toks`` column replaced by
        JSON strings of the weights rounded to 4 places), the ``code2md``
        rendering of an equivalent stand-alone snippet, and the HTML from
        ``generate_vis``.

    Raises:
        KeyError: if ``agg`` is neither ``'max'`` nor ``'sum'``.
    """
    # Stand-alone snippet reproducing this call; shown to the user via code2md.
    snippet = f'''import pandas as pd
import pyterrier as pt ; pt.init()
import pyt_splade
splade = pyt_splade.SpladeFactory(agg={repr(agg)})
doc_pipeline = splade.indexing()
doc_pipeline({df2code(input)})
'''
    factories = {'max': factory_max, 'sum': factory_sum}
    encoder = factories[agg].indexing()
    encoded = encoder(input)

    # The visualisation reads the raw weight mappings, so it must be built
    # before ``toks`` is overwritten with JSON text below.
    html_vis = generate_vis(encoded, mode='Document')

    serialised = []
    for weights in encoded['toks']:
        rounded = {tok: round(weight, 4) for tok, weight in weights.items()}
        serialised.append(json.dumps(rounded))
    encoded['toks'] = serialised

    return (encoded, code2md(snippet, COLAB_INSTALL, COLAB_NAME), html_vis)
# Page sections in display order: markdown files interleaved with the query
# and document demos. Each demo gets its own aggregation dropdown instance.
_sections = [
    MarkdownFile('README.md'),
    MarkdownFile('query.md'),
    Demo(
        predict_query,
        EX_Q,
        [gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation')],
        scale=2/3,
    ),
    MarkdownFile('doc.md'),
    Demo(
        predict_doc,
        EX_D,
        [gr.Dropdown(choices=['max', 'sum'], value='max', label='Aggregation')],
        scale=2/3,
    ),
    MarkdownFile('wrapup.md'),
]

# Build the interface from the sections and start it without a public share link.
interface(*_sections).launch(share=False)