Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| import random | |
| from urllib.request import urlopen | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import gradio as gr | |
# -------------------------------------------------------------------
# Load data
# -------------------------------------------------------------------
# Remote JSON document; the rest of the script reads its "media_list" entries.
DATA_SOURCE = "https://os.zhdk.cloud.switch.ch/115-canonical-processed-final/langident/langident-lid-ensemble_multilingual_v2-0-2/langid-ocrqa_v2-0-0.json"

# Upper bound (seconds) for any single blocking socket operation while
# downloading. Without it a stalled connection would hang start-up forever.
DATA_TIMEOUT_SECONDS = 60

with urlopen(DATA_SOURCE, timeout=DATA_TIMEOUT_SECONDS) as response:
    data = json.load(response)
# -------------------------------------------------------------------
# Flatten yearly OCRQA data
# -------------------------------------------------------------------
# Build one row per yearly statistics entry that carries an "avg_ocrqa" value.
rows = []
for media in data.get("media_list") or []:
    provider = media.get("data_provider")
    newspaper = media.get("media_title")
    for stats in media.get("media_statistics") or []:
        if stats.get("granularity") != "year":
            continue
        try:
            # The year is the text after the last "-" in "element".
            year = int(stats["element"].rsplit("-", 1)[-1])
        except (KeyError, AttributeError, TypeError, ValueError):
            # Missing, non-string or non-numeric element: skip this entry.
            continue
        # `or {}` also covers an explicit null value for "media_stats".
        media_stats = stats.get("media_stats") or {}
        avg_ocrqa = media_stats.get("avg_ocrqa")
        if avg_ocrqa is None:
            continue
        rows.append(
            {
                "provider": provider,
                "newspaper": newspaper,
                "year": year,
                "avg_ocrqa": avg_ocrqa,
                "issues": media_stats.get("issues"),
                "content_items_out": media_stats.get("content_items_out"),
            }
        )

# Check before building the frame: sorting a column-less DataFrame raises
# KeyError, which would hide the intended error message below.
if not rows:
    raise ValueError("No yearly OCRQA data found.")
df = pd.DataFrame(rows).sort_values(["provider", "newspaper", "year"])
# -------------------------------------------------------------------
# Alias lookups (ALL-ALIAS.jsonl)
# -------------------------------------------------------------------
media_title_map: dict[str, str] = {}  # media_alias → full title
provider_name_map: dict[str, str] = {}  # provider_alias → full name

# The alias file is expected next to this script.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_alias_path = os.path.join(_script_dir, "ALL-ALIAS.jsonl")

with open(_alias_path, encoding="utf-8") as _alias_file:
    _stripped_lines = (_raw.strip() for _raw in _alias_file)
    for _text in _stripped_lines:
        if _text:  # blank lines are ignored
            _record = json.loads(_text)
            # Media titles: a later line overwrites an earlier one.
            _title = _record["media_title"]
            media_title_map[_record["media_alias"].strip()] = _title
            # Provider names: the first line seen for an alias wins.
            _provider_key = _record["provider_alias"].strip()
            if _provider_key not in provider_name_map:
                provider_name_map[_provider_key] = _record["provider_name"]
def newspaper_label(alias: str) -> str:
    """Return the display label for a newspaper alias.

    The label is ``"<title> [<alias>]"`` when ``media_title_map`` holds a
    title that differs from the (stripped) alias, otherwise the stripped
    alias on its own.
    """
    key = alias.strip()
    full_title = media_title_map.get(key, key)
    if full_title == key:
        return key
    return f"{full_title} [{key}]"
def provider_label(alias: str) -> str:
    """Return ``"<name> [<alias>]"`` for a provider alias.

    The name comes from ``provider_name_map`` and falls back to the
    stripped alias itself when the alias is unknown.
    """
    key = alias.strip()
    full_name = provider_name_map.get(key, key)
    # Some names already end in "(ALIAS)"; drop it so the alias is shown once.
    trailer = f"({key})"
    if full_name.endswith(trailer):
        full_name = full_name[: len(full_name) - len(trailer)].strip()
    return f"{full_name} [{key}]"
# (label, value) pairs for the provider dropdown: "All" first, then every
# provider present in the data, ordered alphabetically by label.
_provider_pairs = [
    (provider_label(alias), alias) for alias in df["provider"].dropna().unique()
]
_provider_pairs.sort(key=lambda pair: pair[0])
provider_options = [("All", "All"), *_provider_pairs]
# -------------------------------------------------------------------
# Rankings
# -------------------------------------------------------------------
def _mean_ocrqa_by(keys):
    """Average the yearly ``avg_ocrqa`` values of ``df`` per group of *keys*."""
    averaged = df.groupby(keys, as_index=False)["avg_ocrqa"].mean()
    return averaged.rename(columns={"avg_ocrqa": "mean_ocrqa"})


# Mean score per newspaper within each provider.
ranking_by_provider = _mean_ocrqa_by(["provider", "newspaper"])
# Mean score per newspaper across all providers.
ranking_global = _mean_ocrqa_by("newspaper")
def get_ranked_df(provider="All", query=""):
    """Return newspapers ranked by mean OCRQA, best first.

    Args:
        provider: Provider alias to restrict to, or ``"All"`` for the
            global ranking.
        query: Optional case-sensitive substring. A newspaper is kept when
            the stripped query occurs in its alias or in its full title
            from ``media_title_map``.

    Returns:
        A DataFrame with ``newspaper`` and ``mean_ocrqa`` columns, sorted by
        descending score (ties broken by alias) and re-indexed from 0.
    """
    if provider != "All":
        of_provider = ranking_by_provider["provider"] == provider
        table = ranking_by_provider.loc[
            of_provider, ["newspaper", "mean_ocrqa"]
        ].copy()
    else:
        table = ranking_global.copy()

    table = table.sort_values(
        by=["mean_ocrqa", "newspaper"], ascending=[False, True]
    )
    table = table.reset_index(drop=True)

    if not query:
        return table

    needle = query.strip()

    def _hit(alias: str) -> bool:
        return needle in alias or needle in media_title_map.get(alias.strip(), "")

    keep = table["newspaper"].apply(_hit)
    return table[keep].reset_index(drop=True)
def choose_newspapers(ranked, n_best, n_worst, n_random, seed=13):
    """Pick a default selection from a ranked newspaper table.

    Args:
        ranked: Table with a ``newspaper`` column, ordered best to worst.
        n_best: Number of top-ranked newspapers to select.
        n_worst: Number of bottom-ranked newspapers to select.
        n_random: Number of additional newspapers to sample from those that
            are neither in the best nor in the worst group.
        seed: Seed for the sampler, so the same inputs give the same picks.

    Returns:
        ``(choices, selected)``: every newspaper in ranked order, and the
        selected ones (best, then worst, then random) without duplicates.
    """
    ranked_names = ranked["newspaper"].tolist()

    # Normalise the counts once. Truncating *before* testing for zero matters:
    # a fractional count such as 0.5 would otherwise become ``names[-0:]``,
    # which is the whole list rather than nothing.
    n_best = max(int(n_best), 0)
    n_worst = max(int(n_worst), 0)
    n_random = max(int(n_random), 0)

    best = ranked_names[:n_best]
    worst = ranked_names[-n_worst:] if n_worst else []

    # Build the exclusion set once instead of once per candidate.
    taken = set(best) | set(worst)
    remaining_for_random = [name for name in ranked_names if name not in taken]

    rng = random.Random(seed)
    random_pick = rng.sample(
        remaining_for_random, min(n_random, len(remaining_for_random))
    )

    # Best and worst may overlap on short lists; deduplicate, keeping order.
    selected = list(dict.fromkeys(best + worst + random_pick))

    # Choices stay in OCRQA rank order, not in selection order.
    return ranked_names, selected
def update_newspapers(provider, query, n_best, n_worst, n_random):
    """Rebuild the newspaper dropdown for the current selector values.

    Returns a ``gr.update`` carrying the labelled choices in ranked order and
    the default selection produced by ``choose_newspapers``.
    """
    names, preselected = choose_newspapers(
        get_ranked_df(provider, query), n_best, n_worst, n_random
    )
    options = list(zip(map(newspaper_label, names), names))
    return gr.update(choices=options, value=preselected)
def _empty_figure(title):
    """Return a figure without traces that uses the standard axes and *title*."""
    fig = go.Figure()
    fig.update_layout(
        title=title,
        xaxis_title="Year",
        yaxis_title="Average OCRQA",
        yaxis=dict(range=[0, 1.05]),
        template="plotly_white",
        height=650,
    )
    return fig


def make_plot(provider, selected_newspapers):
    """Build the year-vs-OCRQA scatter plot for the selected newspapers.

    Args:
        provider: Provider alias, or ``"All"`` for no provider restriction.
        selected_newspapers: Newspaper aliases to draw; may be empty or None.

    Returns:
        A plotly figure with one marker trace per newspaper, ordered by the
        provider's OCRQA ranking. A placeholder figure is returned when
        nothing is selected or when the selection has no data.
    """
    if not selected_newspapers:
        return _empty_figure("Select one or more newspapers")

    # Nothing below mutates the frame, so no defensive copy is needed.
    in_scope = df if provider == "All" else df[df["provider"] == provider]
    subset = in_scope[in_scope["newspaper"].isin(selected_newspapers)]
    if subset.empty:
        return _empty_figure("No data for the current selection")

    # Preserve ranking order in legend/traces; build the lookup set once.
    wanted = set(selected_newspapers)
    ranked = get_ranked_df(provider, "")
    ranked_order = [n for n in ranked["newspaper"].tolist() if n in wanted]

    fig = go.Figure()
    for newspaper in ranked_order:
        dfn = subset[subset["newspaper"] == newspaper].sort_values("year")
        if dfn.empty:
            continue
        fig.add_trace(
            go.Scatter(
                x=dfn["year"],
                y=dfn["avg_ocrqa"],
                mode="markers",
                name=newspaper_label(newspaper),
                customdata=dfn[["issues", "content_items_out"]].values,
                hovertemplate=(
                    "<b>%{fullData.name}</b><br>"
                    "Year: %{x}<br>"
                    "Average OCRQA: %{y:.3f}<br>"
                    "Issues: %{customdata[0]}<br>"
                    "Content items: %{customdata[1]}"
                    "<extra></extra>"
                ),
            )
        )

    year_min = int(subset["year"].min())
    year_max = int(subset["year"].max())
    # Widen short spans to about a decade around their midpoint so that a
    # handful of points is not stretched across the whole x axis.
    if year_max - year_min < 10:
        mid = (year_min + year_max) / 2
        year_min = int(mid - 5)
        year_max = int(mid + 5)

    provider_display = provider if provider == "All" else provider_label(provider)
    fig.update_layout(
        title=f"OCRQA by newspaper — provider: {provider_display}",
        xaxis_title="Year",
        xaxis=dict(range=[year_min - 1, year_max + 1]),
        yaxis_title="Average OCRQA",
        yaxis=dict(range=[0, 1.05]),
        template="plotly_white",
        height=650,
    )
    return fig
# -------------------------------------------------------------------
# Initial state
# -------------------------------------------------------------------
# Values the widgets start with: no provider restriction, no text filter,
# and the ten best-ranked newspapers pre-selected.
initial_provider = "All"
initial_query = ""
initial_best, initial_worst, initial_random = 10, 0, 0

initial_ranked = get_ranked_df(provider=initial_provider, query=initial_query)
initial_choices, initial_selected = choose_newspapers(
    ranked=initial_ranked,
    n_best=initial_best,
    n_worst=initial_worst,
    n_random=initial_random,
)
# -------------------------------------------------------------------
# UI
# -------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## OCR Quality Assessment exploration")
    gr.Markdown(
        "For details on how OCRQA scores are computed, see the <a"
        ' href="https://huggingface.co/spaces/impresso-project/ocrqa-demo"'
        ' target="_blank">OCRQA demo</a>.'
    )

    # Row 1: which newspapers are candidates (provider + text filter).
    with gr.Row():
        provider = gr.Dropdown(
            choices=provider_options,
            value=initial_provider,
            label="Provider",
        )
        query = gr.Textbox(
            value=initial_query,
            label="Filter newspapers (case-sensitive)",
            placeholder="Type a newspaper title",
        )

    # Row 2: how many candidates to pre-select from each end of the ranking.
    with gr.Row():
        n_best = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_best,
            step=1,
            label="Best OCRQA",
        )
        n_worst = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_worst,
            step=1,
            label="Worst OCRQA",
        )
        n_random = gr.Slider(
            minimum=0,
            maximum=400,
            value=initial_random,
            step=1,
            label="Random OCRQA",
        )

    newspaper = gr.Dropdown(
        choices=[(newspaper_label(n), n) for n in initial_choices],
        value=initial_selected,
        multiselect=True,
        label="Newspapers (filtered and ranked)",
    )
    plot = gr.Plot()

    selector_inputs = [provider, query, n_best, n_worst, n_random]
    for trigger in selector_inputs:
        # Refresh the newspaper list first and redraw only afterwards.
        # Registering the redraw as an independent handler would run it with
        # the newspaper selection from before the refresh. Chaining also
        # redraws when the refreshed selection happens to be unchanged (for
        # example after switching provider), so the plot title stays current.
        trigger.change(
            fn=update_newspapers,
            inputs=selector_inputs,
            outputs=newspaper,
        ).then(
            fn=make_plot,
            inputs=[provider, newspaper],
            outputs=plot,
        )

    # Manual edits of the newspaper selection redraw directly.
    newspaper.change(
        fn=make_plot,
        inputs=[provider, newspaper],
        outputs=plot,
    )

    # Draw the initial selection when the page loads.
    demo.load(
        fn=make_plot,
        inputs=[provider, newspaper],
        outputs=plot,
    )

demo.launch(ssr_mode=False)