Simon Clematide commited on
Commit
f526e5a
·
1 Parent(s): 6c82aa2

Enhance data handling and UI for OCRQA exploration tool

Browse files

- Added error handling for empty dataframes.
- Improved ranking calculations for newspapers.
- Updated newspaper selection logic to include random choices.
- Enhanced UI components for better user interaction.

Files changed (1) hide show
  1. app.py +170 -32
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import json
 
2
  from urllib.request import urlopen
3
 
4
  import pandas as pd
@@ -51,31 +52,117 @@ for media in data.get("media_list", []):
51
 
52
  df = pd.DataFrame(rows).sort_values(["provider", "newspaper", "year"])
53
 
 
 
 
54
  provider_options = ["All"] + sorted(df["provider"].dropna().unique().tolist())
55
 
 
 
 
 
 
 
 
 
 
56
 
57
- def newspapers_for_provider(provider):
58
- subset = df if provider == "All" else df[df["provider"] == provider]
59
- ranking = (
60
- subset.groupby("newspaper", as_index=False)["avg_ocrqa"]
61
- .mean()
62
- .sort_values("avg_ocrqa", ascending=False)
63
- )
64
- return ranking["newspaper"].tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
 
 
66
 
67
- def update_newspaper_choices(provider):
68
- choices = newspapers_for_provider(provider)
69
- return gr.update(choices=choices, value=choices[:10])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
 
72
  def make_plot(provider, selected_newspapers):
73
- subset = df if provider == "All" else df[df["provider"] == provider]
 
 
 
 
 
 
 
 
 
 
 
 
74
  subset = subset[subset["newspaper"].isin(selected_newspapers)]
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  fig = go.Figure()
77
 
78
- for newspaper in selected_newspapers:
79
  dfn = subset[subset["newspaper"] == newspaper].sort_values("year")
80
  if dfn.empty:
81
  continue
@@ -109,34 +196,85 @@ def make_plot(provider, selected_newspapers):
109
  return fig
110
 
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  with gr.Blocks() as demo:
113
  gr.Markdown("## OCRQA exploration")
114
 
115
- provider = gr.Dropdown(
116
- choices=provider_options,
117
- value="All",
118
- label="Provider",
119
- )
 
 
 
 
 
 
120
 
121
- newspaper = gr.CheckboxGroup(
122
- choices=newspapers_for_provider("All"),
123
- value=newspapers_for_provider("All")[:10],
124
- label="Newspapers (ranked by mean OCRQA)",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  )
126
 
127
  plot = gr.Plot()
128
 
129
- provider.change(
130
- fn=update_newspaper_choices,
131
- inputs=provider,
132
- outputs=newspaper,
133
- )
134
 
135
- provider.change(
136
- fn=make_plot,
137
- inputs=[provider, newspaper],
138
- outputs=plot,
139
- )
 
 
 
 
 
 
140
 
141
  newspaper.change(
142
  fn=make_plot,
 
1
  import json
2
+ import random
3
  from urllib.request import urlopen
4
 
5
  import pandas as pd
 
52
 
53
  df = pd.DataFrame(rows).sort_values(["provider", "newspaper", "year"])
54
 
55
+ if df.empty:
56
+ raise ValueError("No yearly OCRQA data found.")
57
+
58
  provider_options = ["All"] + sorted(df["provider"].dropna().unique().tolist())
59
 
60
+ # -------------------------------------------------------------------
61
+ # Rankings
62
+ # -------------------------------------------------------------------
63
+
64
+ ranking_by_provider = (
65
+ df.groupby(["provider", "newspaper"], as_index=False)["avg_ocrqa"]
66
+ .mean()
67
+ .rename(columns={"avg_ocrqa": "mean_ocrqa"})
68
+ )
69
 
70
+ ranking_global = (
71
+ df.groupby("newspaper", as_index=False)["avg_ocrqa"]
72
+ .mean()
73
+ .rename(columns={"avg_ocrqa": "mean_ocrqa"})
74
+ )
75
+
76
+
77
+ def get_ranked_df(provider="All", query=""):
78
+ if provider == "All":
79
+ ranked = ranking_global.copy()
80
+ else:
81
+ ranked = ranking_by_provider.loc[
82
+ ranking_by_provider["provider"] == provider, ["newspaper", "mean_ocrqa"]
83
+ ].copy()
84
+
85
+ ranked = ranked.sort_values(
86
+ ["mean_ocrqa", "newspaper"], ascending=[False, True]
87
+ ).reset_index(drop=True)
88
+
89
+ if query:
90
+ q = query.strip().lower()
91
+ ranked = ranked[
92
+ ranked["newspaper"].str.lower().str.contains(q, na=False)
93
+ ].reset_index(drop=True)
94
+
95
+ return ranked
96
+
97
+
98
+ def choose_newspapers(ranked, n_best, n_worst, n_random, seed=13):
99
+ ranked_names = ranked["newspaper"].tolist()
100
 
101
+ best = ranked_names[: int(n_best)] if n_best > 0 else []
102
+ worst = ranked_names[-int(n_worst) :] if n_worst > 0 else []
103
 
104
+ remaining_for_random = [
105
+ n for n in ranked_names if n not in set(best) and n not in set(worst)
106
+ ]
107
+
108
+ rng = random.Random(seed)
109
+ n_random = min(int(n_random), len(remaining_for_random))
110
+ random_pick = rng.sample(remaining_for_random, n_random) if n_random > 0 else []
111
+
112
+ selected = best + worst + random_pick
113
+
114
+ # Deduplicate while preserving order
115
+ selected = list(dict.fromkeys(selected))
116
+
117
+ # Choices should remain OCRQA-ranked, not in selection order
118
+ choices = ranked_names
119
+
120
+ return choices, selected
121
+
122
+
123
+ def update_newspapers(provider, query, n_best, n_worst, n_random):
124
+ ranked = get_ranked_df(provider, query)
125
+ choices, selected = choose_newspapers(ranked, n_best, n_worst, n_random)
126
+ return gr.update(choices=choices, value=selected)
127
 
128
 
129
  def make_plot(provider, selected_newspapers):
130
+ if not selected_newspapers:
131
+ fig = go.Figure()
132
+ fig.update_layout(
133
+ title="Select one or more newspapers",
134
+ xaxis_title="Year",
135
+ yaxis_title="Average OCRQA",
136
+ yaxis=dict(range=[0, 1.05]),
137
+ template="plotly_white",
138
+ height=650,
139
+ )
140
+ return fig
141
+
142
+ subset = df.copy() if provider == "All" else df[df["provider"] == provider].copy()
143
  subset = subset[subset["newspaper"].isin(selected_newspapers)]
144
 
145
+ if subset.empty:
146
+ fig = go.Figure()
147
+ fig.update_layout(
148
+ title="No data for the current selection",
149
+ xaxis_title="Year",
150
+ yaxis_title="Average OCRQA",
151
+ yaxis=dict(range=[0, 1.05]),
152
+ template="plotly_white",
153
+ height=650,
154
+ )
155
+ return fig
156
+
157
+ # Preserve ranking order in legend/traces
158
+ ranked = get_ranked_df(provider, "")
159
+ ranked_order = [
160
+ n for n in ranked["newspaper"].tolist() if n in set(selected_newspapers)
161
+ ]
162
+
163
  fig = go.Figure()
164
 
165
+ for newspaper in ranked_order:
166
  dfn = subset[subset["newspaper"] == newspaper].sort_values("year")
167
  if dfn.empty:
168
  continue
 
196
  return fig
197
 
198
 
199
+ # -------------------------------------------------------------------
200
+ # Initial state
201
+ # -------------------------------------------------------------------
202
+
203
+ initial_provider = "All"
204
+ initial_query = ""
205
+ initial_best = 10
206
+ initial_worst = 0
207
+ initial_random = 0
208
+
209
+ initial_ranked = get_ranked_df(initial_provider, initial_query)
210
+ initial_choices, initial_selected = choose_newspapers(
211
+ initial_ranked, initial_best, initial_worst, initial_random
212
+ )
213
+
214
+ # -------------------------------------------------------------------
215
+ # UI
216
+ # -------------------------------------------------------------------
217
+
218
  with gr.Blocks() as demo:
219
  gr.Markdown("## OCRQA exploration")
220
 
221
+ with gr.Row():
222
+ provider = gr.Dropdown(
223
+ choices=provider_options,
224
+ value=initial_provider,
225
+ label="Provider",
226
+ )
227
+ query = gr.Textbox(
228
+ value=initial_query,
229
+ label="Filter newspapers",
230
+ placeholder="Type part of a newspaper title",
231
+ )
232
 
233
+ with gr.Row():
234
+ n_best = gr.Slider(
235
+ minimum=0,
236
+ maximum=400,
237
+ value=initial_best,
238
+ step=1,
239
+ label="Best OCRQA",
240
+ )
241
+ n_worst = gr.Slider(
242
+ minimum=0,
243
+ maximum=400,
244
+ value=initial_worst,
245
+ step=1,
246
+ label="Worst OCRQA",
247
+ )
248
+ n_random = gr.Slider(
249
+ minimum=0,
250
+ maximum=400,
251
+ value=initial_random,
252
+ step=1,
253
+ label="Random OCRQA",
254
+ )
255
+
256
+ newspaper = gr.Dropdown(
257
+ choices=initial_choices,
258
+ value=initial_selected,
259
+ multiselect=True,
260
+ label="Newspapers (filtered and ranked)",
261
  )
262
 
263
  plot = gr.Plot()
264
 
265
+ selector_inputs = [provider, query, n_best, n_worst, n_random]
 
 
 
 
266
 
267
+ for trigger in selector_inputs:
268
+ trigger.change(
269
+ fn=update_newspapers,
270
+ inputs=selector_inputs,
271
+ outputs=newspaper,
272
+ )
273
+ trigger.change(
274
+ fn=lambda provider, newspaper: make_plot(provider, newspaper),
275
+ inputs=[provider, newspaper],
276
+ outputs=plot,
277
+ )
278
 
279
  newspaper.change(
280
  fn=make_plot,