Spaces:
Runtime error
Runtime error
Sean MacAvaney
commited on
Commit
·
a8fb57f
1
Parent(s):
492a748
more settings
Browse files
app.py
CHANGED
|
@@ -15,6 +15,7 @@ COLORS = ['rgb(252, 132, 100)','rgb(252, 148, 116)','rgb(252, 166, 137)','rgb(25
|
|
| 15 |
doc2query = Doc2Query(MODEL, append=True, num_samples=5)
|
| 16 |
electra = ElectraScorer()
|
| 17 |
query_scorer = QueryScorer(electra)
|
|
|
|
| 18 |
|
| 19 |
COLAB_NAME = 'pyterrier_doc2query.ipynb'
|
| 20 |
COLAB_INSTALL = '''
|
|
@@ -75,13 +76,28 @@ def generate_vis(df):
|
|
| 75 |
''')
|
| 76 |
return '\n'.join(result)
|
| 77 |
|
| 78 |
-
def predict_mm(input, model, num_samples, score_model):
|
| 79 |
assert model == MODEL
|
| 80 |
assert score_model == SCORE_MODEL
|
| 81 |
doc2query.append = False
|
| 82 |
doc2query.num_samples = num_samples
|
| 83 |
-
|
| 84 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
import pandas as pd
|
| 86 |
from pyterrier_doc2query import Doc2Query, QueryScorer
|
| 87 |
from pyterrier_dr import ElectraScorer
|
|
@@ -138,8 +154,14 @@ interface(
|
|
| 138 |
), gr.Dropdown(
|
| 139 |
choices=[SCORE_MODEL],
|
| 140 |
value=SCORE_MODEL,
|
| 141 |
-
label='
|
| 142 |
interactive=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
)],
|
| 144 |
),
|
| 145 |
MarkdownFile('wrapup.md'),
|
|
|
|
| 15 |
doc2query = Doc2Query(MODEL, append=True, num_samples=5)
|
| 16 |
electra = ElectraScorer()
|
| 17 |
query_scorer = QueryScorer(electra)
|
| 18 |
+
query_filter = QueryFilter(p=0.5, append=False)
|
| 19 |
|
| 20 |
COLAB_NAME = 'pyterrier_doc2query.ipynb'
|
| 21 |
COLAB_INSTALL = '''
|
|
|
|
| 76 |
''')
|
| 77 |
return '\n'.join(result)
|
| 78 |
|
| 79 |
+
def predict_mm(input, model, num_samples, score_model, filter_pct):
|
| 80 |
assert model == MODEL
|
| 81 |
assert score_model == SCORE_MODEL
|
| 82 |
doc2query.append = False
|
| 83 |
doc2query.num_samples = num_samples
|
| 84 |
+
if filter_pct > 0:
|
| 85 |
+
query_filter.t = PERCENTILES_BY_5[filter_pct//5-1]
|
| 86 |
+
pipeline = doc2query >> query_scorer >> query_filter
|
| 87 |
+
code = f'''import pyterrier as pt ; pt.init()
|
| 88 |
+
import pandas as pd
|
| 89 |
+
from pyterrier_doc2query import Doc2Query, QueryScorer, QueryFilter
|
| 90 |
+
from pyterrier_dr import ElectraScorer
|
| 91 |
+
|
| 92 |
+
doc2query = Doc2Query({repr(model)}, append=False, num_samples={num_samples})
|
| 93 |
+
scorer = ElectraScorer({repr(score_model)})
|
| 94 |
+
pipeline = doc2query >> QueryScorer(scorer) >> QueryFilter(t={query_filter.t})
|
| 95 |
+
|
| 96 |
+
pipeline({df2code(input)})
|
| 97 |
+
'''
|
| 98 |
+
else:
|
| 99 |
+
pipeline = doc2query >> query_scorer
|
| 100 |
+
code = f'''import pyterrier as pt ; pt.init()
|
| 101 |
import pandas as pd
|
| 102 |
from pyterrier_doc2query import Doc2Query, QueryScorer
|
| 103 |
from pyterrier_dr import ElectraScorer
|
|
|
|
| 154 |
), gr.Dropdown(
|
| 155 |
choices=[SCORE_MODEL],
|
| 156 |
value=SCORE_MODEL,
|
| 157 |
+
label='Scorer',
|
| 158 |
interactive=False,
|
| 159 |
+
), gr.Slider(
|
| 160 |
+
minimum=0,
|
| 161 |
+
maximum=95,
|
| 162 |
+
value=10,
|
| 163 |
+
step=5,
|
| 164 |
+
label='Filter (top % of queries)'
|
| 165 |
)],
|
| 166 |
),
|
| 167 |
MarkdownFile('wrapup.md'),
|