Spaces:
Paused
Paused
Sean MacAvaney committed on
Commit ·
1051b11
1
Parent(s): 68730a3
fixups: base64 query components and quoting agg in code sample
Browse files
app.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import re
|
| 2 |
import json
|
| 3 |
import pandas as pd
|
|
@@ -23,7 +24,13 @@ def generate_vis(df, mode='Document'):
|
|
| 23 |
max_score = max(max(t.values()) for t in df['toks'])
|
| 24 |
for row in df.itertuples(index=False):
|
| 25 |
if mode == 'Query':
|
| 26 |
-
tok_scores = {m.group(2): float(m.group(1)) for m in re.finditer(r'combine:0=([0-9.]+)\(([^)]+)\)', row.query)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
max_score = max(tok_scores.values())
|
| 28 |
orig_tokens = factory_max.tokenizer.tokenize(row.query_0)
|
| 29 |
id = row.qid
|
|
@@ -51,7 +58,7 @@ def predict_query(input, agg):
|
|
| 51 |
import pyterrier as pt ; pt.init()
|
| 52 |
import pyt_splade
|
| 53 |
|
| 54 |
-
factory = pyt_splade.SpladeFactory(agg={agg})
|
| 55 |
|
| 56 |
query_pipeline = factory.query()
|
| 57 |
|
|
@@ -70,7 +77,7 @@ def predict_doc(input, agg):
|
|
| 70 |
import pyterrier as pt ; pt.init()
|
| 71 |
import pyt_splade
|
| 72 |
|
| 73 |
-
factory = pyt_splade.SpladeFactory(agg={agg})
|
| 74 |
|
| 75 |
doc_pipeline = factory.indexing()
|
| 76 |
|
|
|
|
| 1 |
+
import base64
|
| 2 |
import re
|
| 3 |
import json
|
| 4 |
import pandas as pd
|
|
|
|
| 24 |
max_score = max(max(t.values()) for t in df['toks'])
|
| 25 |
for row in df.itertuples(index=False):
|
| 26 |
if mode == 'Query':
|
| 27 |
+
tok_scores = {m.group(2): float(m.group(1)) for m in re.finditer(r'#combine:0=([0-9.]+)\((#base64\([^)]+\)|[^)]+)\)', row.query)}
|
| 28 |
+
for key, value in list(tok_scores.items()):
|
| 29 |
+
if key.startswith('#base64('):
|
| 30 |
+
b64 = re.search('#base64\(([^)]+)\)', key).group(1)
|
| 31 |
+
del tok_scores[key]
|
| 32 |
+
key = base64.b64decode(b64).decode()
|
| 33 |
+
tok_scores[key] = value
|
| 34 |
max_score = max(tok_scores.values())
|
| 35 |
orig_tokens = factory_max.tokenizer.tokenize(row.query_0)
|
| 36 |
id = row.qid
|
|
|
|
| 58 |
import pyterrier as pt ; pt.init()
|
| 59 |
import pyt_splade
|
| 60 |
|
| 61 |
+
factory = pyt_splade.SpladeFactory(agg={repr(agg)})
|
| 62 |
|
| 63 |
query_pipeline = factory.query()
|
| 64 |
|
|
|
|
| 77 |
import pyterrier as pt ; pt.init()
|
| 78 |
import pyt_splade
|
| 79 |
|
| 80 |
+
factory = pyt_splade.SpladeFactory(agg={repr(agg)})
|
| 81 |
|
| 82 |
doc_pipeline = factory.indexing()
|
| 83 |
|