Spaces:
Runtime error
Runtime error
Fix: take the baseline run from the runs table
Browse files
- app.py +53 -20
- data_access.py +12 -11
- load_ground_truth.py +0 -0
- eval_tables.py → scripts/eval_tables.py +0 -0
app.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
import asyncio
|
|
|
|
| 2 |
|
| 3 |
import gradio as gr
|
| 4 |
import pandas as pd
|
| 5 |
-
import logging
|
| 6 |
|
| 7 |
from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
|
| 8 |
get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
|
|
@@ -10,6 +10,8 @@ from data_access import get_questions, get_source_finders, get_run_ids, get_base
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
|
|
|
|
|
|
| 13 |
# Initialize data at the module level
|
| 14 |
questions = []
|
| 15 |
source_finders = []
|
|
@@ -22,9 +24,11 @@ run_ids = []
|
|
| 22 |
available_run_id_dict = {}
|
| 23 |
finder_options = []
|
| 24 |
previous_run_id = "initial_run"
|
|
|
|
| 25 |
|
| 26 |
run_id_dropdown = None
|
| 27 |
|
|
|
|
| 28 |
# Get all questions
|
| 29 |
|
| 30 |
# Initialize data in a single async function
|
|
@@ -36,7 +40,6 @@ async def initialize_data():
|
|
| 36 |
source_finders = await get_source_finders(conn)
|
| 37 |
baseline_rankers = await get_baseline_rankers(conn)
|
| 38 |
|
| 39 |
-
baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
|
| 40 |
# Convert to dictionaries for easier lookup
|
| 41 |
questions_dict = {q["text"]: q["id"] for q in questions}
|
| 42 |
baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
|
|
@@ -46,9 +49,32 @@ async def initialize_data():
|
|
| 46 |
question_options = [q['text'] for q in questions]
|
| 47 |
finder_options = [s["name"] for s in source_finders]
|
| 48 |
baseline_ranker_options = [b["name"] for b in baseline_rankers]
|
|
|
|
|
|
|
| 49 |
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
global previous_run_id
|
| 53 |
if evt:
|
| 54 |
logger.info(f"event: {evt.target.elem_id}")
|
|
@@ -70,27 +96,30 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
|
|
| 70 |
if type(baseline_ranker_name) == list:
|
| 71 |
baseline_ranker_name = baseline_ranker_name[0]
|
| 72 |
|
| 73 |
-
baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(
|
|
|
|
| 74 |
|
| 75 |
if len(source_finder_name):
|
| 76 |
finder_id_int = source_finders_dict.get(source_finder_name)
|
| 77 |
else:
|
| 78 |
finder_id_int = None
|
| 79 |
|
| 80 |
-
if question_option ==
|
| 81 |
if finder_id_int:
|
| 82 |
if run_id is None:
|
| 83 |
available_run_id_dict = await get_run_ids(conn, finder_id_int)
|
| 84 |
run_id = list(available_run_id_dict.keys())[0]
|
| 85 |
previous_run_id = run_id
|
| 86 |
run_id_int = available_run_id_dict.get(run_id)
|
| 87 |
-
all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int,
|
|
|
|
| 88 |
|
| 89 |
else:
|
| 90 |
run_id_options = list(available_run_id_dict.keys())
|
| 91 |
all_stats = None
|
| 92 |
run_id_options = list(available_run_id_dict.keys())
|
| 93 |
-
return None, all_stats, gr.Dropdown(choices=run_id_options,
|
|
|
|
| 94 |
|
| 95 |
# Extract question ID from selection
|
| 96 |
question_id = questions_dict.get(question_option)
|
|
@@ -102,8 +131,6 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
|
|
| 102 |
previous_run_id = run_id
|
| 103 |
run_id_int = available_run_id_dict.get(run_id)
|
| 104 |
|
| 105 |
-
|
| 106 |
-
|
| 107 |
source_runs = None
|
| 108 |
stats = None
|
| 109 |
# Get source runs data
|
|
@@ -116,7 +143,8 @@ async def update_sources_list_async(question_option, source_finder_name, run_id,
|
|
| 116 |
return None, None, run_id_options, "No results found for the selected filters",
|
| 117 |
|
| 118 |
# Format table columns
|
| 119 |
-
columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
|
|
|
|
| 120 |
'folio', 'reason']
|
| 121 |
df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
|
| 122 |
|
|
@@ -147,6 +175,7 @@ async def handle_row_selection_async(evt: gr.SelectData):
|
|
| 147 |
def handle_row_selection(evt: gr.SelectData):
    """Synchronous entry point for a row-select event.

    Builds the coroutine from ``handle_row_selection_async`` and drives it to
    completion with ``asyncio.run``, returning whatever the coroutine returns.
    """
    pending = handle_row_selection_async(evt)
    return asyncio.run(pending)
|
| 149 |
|
|
|
|
| 150 |
# Create Gradio app
|
| 151 |
|
| 152 |
# Ensure we clean up when done
|
|
@@ -162,7 +191,7 @@ async def main():
|
|
| 162 |
with gr.Column(scale=1):
|
| 163 |
# Main content area
|
| 164 |
question_dropdown = gr.Dropdown(
|
| 165 |
-
choices=[
|
| 166 |
label="Select Question",
|
| 167 |
value=None,
|
| 168 |
interactive=True,
|
|
@@ -186,7 +215,7 @@ async def main():
|
|
| 186 |
)
|
| 187 |
with gr.Column(scale=1):
|
| 188 |
run_id_dropdown = gr.Dropdown(
|
| 189 |
-
choices=
|
| 190 |
allow_custom_value=True,
|
| 191 |
label="Run id for Question and source finder",
|
| 192 |
interactive=True,
|
|
@@ -201,7 +230,6 @@ async def main():
|
|
| 201 |
gr.Markdown(f"Total Questions: {len(questions)}")
|
| 202 |
gr.Markdown(f"Source Finders: {len(source_finders)}")
|
| 203 |
|
| 204 |
-
|
| 205 |
with gr.Row():
|
| 206 |
result_text = gr.Markdown("Select a question to view source runs")
|
| 207 |
with gr.Row():
|
|
@@ -221,14 +249,15 @@ async def main():
|
|
| 221 |
metadata_text = gr.TextArea(
|
| 222 |
label="Metadata of Source Finder for Selected Question",
|
| 223 |
elem_id="metadata",
|
| 224 |
-
lines
|
| 225 |
)
|
| 226 |
with gr.Row():
|
| 227 |
gr.Markdown("# Sources Found")
|
| 228 |
with gr.Row():
|
| 229 |
with gr.Column(scale=3):
|
| 230 |
results_table = gr.DataFrame(
|
| 231 |
-
headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
|
|
|
|
| 232 |
interactive=False
|
| 233 |
)
|
| 234 |
with gr.Column(scale=1):
|
|
@@ -246,8 +275,6 @@ async def main():
|
|
| 246 |
# visible=True
|
| 247 |
# )
|
| 248 |
|
| 249 |
-
|
| 250 |
-
|
| 251 |
# Set up event handlers
|
| 252 |
results_table.select(
|
| 253 |
handle_row_selection,
|
|
@@ -255,15 +282,22 @@ async def main():
|
|
| 255 |
outputs=source_text
|
| 256 |
)
|
| 257 |
|
| 258 |
-
|
| 259 |
update_sources_list,
|
| 260 |
inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
|
| 261 |
outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
|
|
|
|
| 262 |
)
|
| 263 |
|
| 264 |
-
|
| 265 |
update_sources_list,
|
| 266 |
inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
# outputs=[run_id_dropdown, results_table, result_text, download_button]
|
| 268 |
outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
|
| 269 |
)
|
|
@@ -274,7 +308,6 @@ async def main():
|
|
| 274 |
outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
|
| 275 |
)
|
| 276 |
|
| 277 |
-
|
| 278 |
app.queue()
|
| 279 |
app.launch()
|
| 280 |
|
|
|
|
| 1 |
import asyncio
|
| 2 |
+
import logging
|
| 3 |
|
| 4 |
import gradio as gr
|
| 5 |
import pandas as pd
|
|
|
|
| 6 |
|
| 7 |
from data_access import get_questions, get_source_finders, get_run_ids, get_baseline_rankers, \
|
| 8 |
get_unified_sources, get_source_text, calculate_cumulative_statistics_for_all_questions, get_metadata, \
|
|
|
|
| 10 |
|
| 11 |
logger = logging.getLogger(__name__)
|
| 12 |
|
| 13 |
+
ALL_QUESTIONS_STR = "All questions"
|
| 14 |
+
|
| 15 |
# Initialize data at the module level
|
| 16 |
questions = []
|
| 17 |
source_finders = []
|
|
|
|
| 24 |
available_run_id_dict = {}
|
| 25 |
finder_options = []
|
| 26 |
previous_run_id = "initial_run"
|
| 27 |
+
run_id_options = []
|
| 28 |
|
| 29 |
run_id_dropdown = None
|
| 30 |
|
| 31 |
+
|
| 32 |
# Get all questions
|
| 33 |
|
| 34 |
# Initialize data in a single async function
|
|
|
|
| 40 |
source_finders = await get_source_finders(conn)
|
| 41 |
baseline_rankers = await get_baseline_rankers(conn)
|
| 42 |
|
|
|
|
| 43 |
# Convert to dictionaries for easier lookup
|
| 44 |
questions_dict = {q["text"]: q["id"] for q in questions}
|
| 45 |
baseline_rankers_dict = {f["name"]: f["id"] for f in baseline_rankers}
|
|
|
|
| 49 |
question_options = [q['text'] for q in questions]
|
| 50 |
finder_options = [s["name"] for s in source_finders]
|
| 51 |
baseline_ranker_options = [b["name"] for b in baseline_rankers]
|
| 52 |
+
update_run_ids(ALL_QUESTIONS_STR, list(source_finders_dict.keys())[0])
|
| 53 |
+
|
| 54 |
|
| 55 |
+
def update_run_ids(question_option, source_finder_name):
    """Synchronous wrapper around ``update_run_ids_async``.

    Args:
        question_option: Selected question text (or the "all questions" label).
        source_finder_name: Name of the selected source finder.

    Returns:
        Whatever ``update_run_ids_async`` returns for the same arguments.

    ``asyncio.run`` refuses to start when the calling thread already runs an
    event loop, which is the case when this wrapper is invoked from inside a
    coroutine. In that situation the coroutine is executed on a short-lived
    worker thread with its own loop instead. Note that the caller is blocked
    until the worker finishes.
    """
    import concurrent.futures

    def _run_in_fresh_loop():
        # Create the coroutine inside the thread that will actually run it.
        return asyncio.run(update_run_ids_async(question_option, source_finder_name))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop is running in this thread, so asyncio.run() is safe here.
        return _run_in_fresh_loop()

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(_run_in_fresh_loop).result()
|
| 57 |
|
| 58 |
+
|
| 59 |
+
async def update_run_ids_async(question_option, source_finder_name):
    """Reload the run ids available for a source finder (and optional question).

    Looks up the finder id by name, fetches the matching runs through
    ``get_run_ids`` (restricted to one question unless the "all questions"
    option or no question is selected) and stores the result in the module
    globals ``available_run_id_dict`` and ``run_id_options``.

    Args:
        question_option: Selected question text, ``ALL_QUESTIONS_STR`` or a
            falsy value for "no question filter".
        source_finder_name: Name of the selected source finder.

    Returns:
        A 5-tuple in the order the event handlers wire their outputs:
        ``(results_table, statistics_table, run_id_dropdown, result_text,
        metadata_text)``. The first two are always ``None`` and the last is an
        empty string; the dropdown carries the refreshed choices.
    """
    global previous_run_id, available_run_id_dict, run_id_options
    async with get_async_connection() as conn:
        finder_id_int = source_finders_dict.get(source_finder_name)
        if question_option and question_option != ALL_QUESTIONS_STR:
            question_id = questions_dict.get(question_option)
            available_run_id_dict = await get_run_ids(conn, finder_id_int, question_id)
        else:
            available_run_id_dict = await get_run_ids(conn, finder_id_int)

    run_id_options = list(available_run_id_dict)
    if not run_id_options:
        # Nothing to preselect: indexing [0] here used to raise IndexError.
        # previous_run_id is deliberately left untouched in this case.
        return None, None, gr.Dropdown(choices=[],
                                       value=None), "No runs found for the selected source finder", ""

    run_id = run_id_options[0]
    previous_run_id = run_id
    return None, None, gr.Dropdown(choices=run_id_options,
                                   value=run_id), "Select Question to see results", ""
|
| 75 |
+
|
| 76 |
+
def update_sources_list(question_option, source_finder_id, run_id: str, baseline_ranker_id: str,
|
| 77 |
+
evt: gr.EventData = None):
|
| 78 |
global previous_run_id
|
| 79 |
if evt:
|
| 80 |
logger.info(f"event: {evt.target.elem_id}")
|
|
|
|
| 96 |
if type(baseline_ranker_name) == list:
|
| 97 |
baseline_ranker_name = baseline_ranker_name[0]
|
| 98 |
|
| 99 |
+
baseline_ranker_id_int = 1 if len(baseline_ranker_name) == 0 else baseline_rankers_dict.get(
|
| 100 |
+
baseline_ranker_name)
|
| 101 |
|
| 102 |
if len(source_finder_name):
|
| 103 |
finder_id_int = source_finders_dict.get(source_finder_name)
|
| 104 |
else:
|
| 105 |
finder_id_int = None
|
| 106 |
|
| 107 |
+
if question_option == ALL_QUESTIONS_STR:
|
| 108 |
if finder_id_int:
|
| 109 |
if run_id is None:
|
| 110 |
available_run_id_dict = await get_run_ids(conn, finder_id_int)
|
| 111 |
run_id = list(available_run_id_dict.keys())[0]
|
| 112 |
previous_run_id = run_id
|
| 113 |
run_id_int = available_run_id_dict.get(run_id)
|
| 114 |
+
all_stats = await calculate_cumulative_statistics_for_all_questions(conn, run_id_int,
|
| 115 |
+
baseline_ranker_id_int)
|
| 116 |
|
| 117 |
else:
|
| 118 |
run_id_options = list(available_run_id_dict.keys())
|
| 119 |
all_stats = None
|
| 120 |
run_id_options = list(available_run_id_dict.keys())
|
| 121 |
+
return None, all_stats, gr.Dropdown(choices=run_id_options,
|
| 122 |
+
value=run_id), "Select Run Id and source finder to see results", ""
|
| 123 |
|
| 124 |
# Extract question ID from selection
|
| 125 |
question_id = questions_dict.get(question_option)
|
|
|
|
| 131 |
previous_run_id = run_id
|
| 132 |
run_id_int = available_run_id_dict.get(run_id)
|
| 133 |
|
|
|
|
|
|
|
| 134 |
source_runs = None
|
| 135 |
stats = None
|
| 136 |
# Get source runs data
|
|
|
|
| 143 |
return None, None, run_id_options, "No results found for the selected filters",
|
| 144 |
|
| 145 |
# Format table columns
|
| 146 |
+
columns_to_display = ['sugya_id', 'in_baseline', 'baseline_rank', 'in_source_run', 'source_run_rank',
|
| 147 |
+
'tractate',
|
| 148 |
'folio', 'reason']
|
| 149 |
df_display = df[columns_to_display] if all(col in df.columns for col in columns_to_display) else df
|
| 150 |
|
|
|
|
| 175 |
def handle_row_selection(evt: gr.SelectData):
    """Synchronous entry point for a row-select event.

    Builds the coroutine from ``handle_row_selection_async`` and drives it to
    completion with ``asyncio.run``, returning whatever the coroutine returns.
    """
    pending = handle_row_selection_async(evt)
    return asyncio.run(pending)
|
| 177 |
|
| 178 |
+
|
| 179 |
# Create Gradio app
|
| 180 |
|
| 181 |
# Ensure we clean up when done
|
|
|
|
| 191 |
with gr.Column(scale=1):
|
| 192 |
# Main content area
|
| 193 |
question_dropdown = gr.Dropdown(
|
| 194 |
+
choices=[ALL_QUESTIONS_STR] + question_options,
|
| 195 |
label="Select Question",
|
| 196 |
value=None,
|
| 197 |
interactive=True,
|
|
|
|
| 215 |
)
|
| 216 |
with gr.Column(scale=1):
|
| 217 |
run_id_dropdown = gr.Dropdown(
|
| 218 |
+
choices=run_id_options,
|
| 219 |
allow_custom_value=True,
|
| 220 |
label="Run id for Question and source finder",
|
| 221 |
interactive=True,
|
|
|
|
| 230 |
gr.Markdown(f"Total Questions: {len(questions)}")
|
| 231 |
gr.Markdown(f"Source Finders: {len(source_finders)}")
|
| 232 |
|
|
|
|
| 233 |
with gr.Row():
|
| 234 |
result_text = gr.Markdown("Select a question to view source runs")
|
| 235 |
with gr.Row():
|
|
|
|
| 249 |
metadata_text = gr.TextArea(
|
| 250 |
label="Metadata of Source Finder for Selected Question",
|
| 251 |
elem_id="metadata",
|
| 252 |
+
lines=2
|
| 253 |
)
|
| 254 |
with gr.Row():
|
| 255 |
gr.Markdown("# Sources Found")
|
| 256 |
with gr.Row():
|
| 257 |
with gr.Column(scale=3):
|
| 258 |
results_table = gr.DataFrame(
|
| 259 |
+
headers=['id', 'tractate', 'folio', 'in_baseline', 'baseline_rank', 'in_source_run',
|
| 260 |
+
'source_run_rank', 'source_reason', 'metadata'],
|
| 261 |
interactive=False
|
| 262 |
)
|
| 263 |
with gr.Column(scale=1):
|
|
|
|
| 275 |
# visible=True
|
| 276 |
# )
|
| 277 |
|
|
|
|
|
|
|
| 278 |
# Set up event handlers
|
| 279 |
results_table.select(
|
| 280 |
handle_row_selection,
|
|
|
|
| 282 |
outputs=source_text
|
| 283 |
)
|
| 284 |
|
| 285 |
+
baseline_rankers_dropdown.change(
|
| 286 |
update_sources_list,
|
| 287 |
inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
|
| 288 |
outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
|
| 289 |
+
|
| 290 |
)
|
| 291 |
|
| 292 |
+
question_dropdown.change(
|
| 293 |
update_sources_list,
|
| 294 |
inputs=[question_dropdown, source_finder_dropdown, run_id_dropdown, baseline_rankers_dropdown],
|
| 295 |
+
outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
source_finder_dropdown.change(
|
| 299 |
+
update_run_ids,
|
| 300 |
+
inputs=[question_dropdown, source_finder_dropdown],
|
| 301 |
# outputs=[run_id_dropdown, results_table, result_text, download_button]
|
| 302 |
outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
|
| 303 |
)
|
|
|
|
| 308 |
outputs=[results_table, statistics_table, run_id_dropdown, result_text, metadata_text]
|
| 309 |
)
|
| 310 |
|
|
|
|
| 311 |
app.queue()
|
| 312 |
app.launch()
|
| 313 |
|
data_access.py
CHANGED
|
@@ -15,6 +15,7 @@ load_dotenv()
|
|
| 15 |
@asynccontextmanager
|
| 16 |
async def get_async_connection(schema="talmudexplore"):
|
| 17 |
"""Get a connection for the current request."""
|
|
|
|
| 18 |
try:
|
| 19 |
# Create a single connection without relying on a shared pool
|
| 20 |
conn = await asyncpg.connect(
|
|
@@ -27,7 +28,8 @@ async def get_async_connection(schema="talmudexplore"):
|
|
| 27 |
await conn.execute(f'SET search_path TO {schema}')
|
| 28 |
yield conn
|
| 29 |
finally:
|
| 30 |
-
|
|
|
|
| 31 |
|
| 32 |
|
| 33 |
async def get_questions(conn: asyncpg.Connection):
|
|
@@ -73,8 +75,13 @@ async def get_run_ids(conn: asyncpg.Connection, source_finder_id: int, question_
|
|
| 73 |
|
| 74 |
|
| 75 |
async def get_baseline_rankers(conn: asyncpg.Connection):
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
|
| 79 |
async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection, baseline_sources , source_runs_sources):
|
| 80 |
# for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
|
|
@@ -203,14 +210,8 @@ async def get_unified_sources(conn: asyncpg.Connection, question_id: int, source
|
|
| 203 |
"""
|
| 204 |
source_runs = await conn.fetch(query_runs, question_id, source_finder_run_id)
|
| 205 |
# Get sources from baseline_sources
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
FROM baseline_sources bs
|
| 209 |
-
join talmud_bavli tb on bs.sugya_id = tb.xml_id
|
| 210 |
-
WHERE bs.question_id = $1
|
| 211 |
-
AND bs.ranker_id = $2
|
| 212 |
-
"""
|
| 213 |
-
baseline_sources = await conn.fetch(query_baseline, question_id, ranker_id)
|
| 214 |
stats_df = await calculate_baseline_vs_source_stats_for_question(conn, baseline_sources, source_runs)
|
| 215 |
# Convert to dictionaries for easier lookup
|
| 216 |
source_runs_dict = {s["id"]: dict(s) for s in source_runs}
|
|
|
|
| 15 |
@asynccontextmanager
|
| 16 |
async def get_async_connection(schema="talmudexplore"):
|
| 17 |
"""Get a connection for the current request."""
|
| 18 |
+
conn = None
|
| 19 |
try:
|
| 20 |
# Create a single connection without relying on a shared pool
|
| 21 |
conn = await asyncpg.connect(
|
|
|
|
| 28 |
await conn.execute(f'SET search_path TO {schema}')
|
| 29 |
yield conn
|
| 30 |
finally:
|
| 31 |
+
if conn:
|
| 32 |
+
await conn.close()
|
| 33 |
|
| 34 |
|
| 35 |
async def get_questions(conn: asyncpg.Connection):
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
async def get_baseline_rankers(conn: asyncpg.Connection):
    """Return the source-finder runs that can be picked as a baseline.

    Runs one query joining ``source_finder_runs`` to ``source_finders`` and
    maps every row to ``{"id": <run id>, "name": "<finder type> : <description>"}``.

    Args:
        conn: Open database connection used to run the query.

    Returns:
        A list of dicts with the keys ``id`` and ``name``, in query order.
    """
    query = """
select sfr.id, sf.source_finder_type, sfr.description from talmudexplore.source_finder_runs sfr
join source_finders sf on sf.id = sfr.source_finder_id
order by sf.id
"""
    rows = await conn.fetch(query)
    options = []
    for row in rows:
        label = f"{row['source_finder_type']} : {row['description']}"
        options.append({"id": row["id"], "name": label})
    return options
|
| 85 |
|
| 86 |
async def calculate_baseline_vs_source_stats_for_question(conn: asyncpg.Connection, baseline_sources , source_runs_sources):
|
| 87 |
# for a given question_id and source_finder_id and run_id calculate the baseline vs source stats
|
|
|
|
| 210 |
"""
|
| 211 |
source_runs = await conn.fetch(query_runs, question_id, source_finder_run_id)
|
| 212 |
# Get sources from baseline_sources
|
| 213 |
+
baseline_query = query_runs.replace("source_rank", "baseline_rank")
|
| 214 |
+
baseline_sources = await conn.fetch(baseline_query, question_id, ranker_id)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
stats_df = await calculate_baseline_vs_source_stats_for_question(conn, baseline_sources, source_runs)
|
| 216 |
# Convert to dictionaries for easier lookup
|
| 217 |
source_runs_dict = {s["id"]: dict(s) for s in source_runs}
|
load_ground_truth.py
DELETED
|
File without changes
|
eval_tables.py → scripts/eval_tables.py
RENAMED
|
File without changes
|