davanstrien HF Staff commited on
Commit
163634f
·
verified ·
1 Parent(s): 7b9f964

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. README.md +47 -7
  2. app.py +262 -0
  3. diff_utils.py +39 -0
  4. requirements.txt +4 -0
README.md CHANGED
@@ -1,13 +1,53 @@
1
  ---
2
- title: Diffusiongemma Ocr
3
- emoji: 🏃
4
- colorFrom: purple
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 6.17.3
8
- python_version: '3.13'
9
  app_file: app.py
10
  pinned: false
 
 
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: DiffusionGemma vs Gemma-4 — Post-OCR Correction
3
+ emoji: 📰
4
+ colorFrom: yellow
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: "5.49.1"
 
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
+ short_description: Diffusion vs autoregressive LLM on historical OCR cleanup
12
+ models:
13
+ - google/diffusiongemma-26B-A4B-it
14
+ - google/gemma-4-E4B-it
15
  ---
16
 
17
+ # DiffusionGemma vs Gemma-4: post-OCR correction
18
+
19
+ A pragmatic first-pass comparison of Google's **experimental diffusion LLM**
20
+ [DiffusionGemma-26B-A4B-it](https://huggingface.co/google/diffusiongemma-26B-A4B-it)
21
+ (released 2026-06-10; 26B MoE, 3.8B active; generates 256-token blocks by iterative
22
+ denoising) against an autoregressive baseline,
23
+ [Gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it) (~4.5B effective),
24
+ on **post-OCR correction of 19th-century English newspaper text**.
25
+
26
+ **Hypothesis**: a diffusion LM treats correction as denoising, so it may be
27
+ (a) faster and (b) less prone to *over-correction* — rewriting text that was
28
+ already correct — than an autoregressive model, possibly at some accuracy cost.
29
+
30
+ ## Method (v1, pragmatic)
31
+
32
+ - 75 passages from [BLN600](https://doi.org/10.15131/shef.data.25439023)
33
+ (19th-c British Library newspapers, aligned OCR + human gold transcription),
34
+ align-trimmed to ≤220 Gemma tokens so outputs fit DiffusionGemma's single
35
+ 256-token block. Identical prompt for both models; thinking mode off; bf16;
36
+ batch size 1; A100-80GB.
37
+ - Gemma-4 decodes greedily. DiffusionGemma uses its generation-config default
38
+ entropy sampler (**no greedy equivalent exists** for the diffusion sampler —
39
+ this is an unavoidable asymmetry, not a tuning choice).
40
+ - **Over-correction rate**: of input characters that were already correct
41
+ (per input↔gold character alignment), the fraction the model changed
42
+ (per input↔output alignment). **Fix rate**: of input characters that were
43
+ wrong, the fraction the model changed. Text NFC-normalized, whitespace
44
+ collapsed, before all metrics. CER/WER via jiwer.
45
+
46
+ ## Limitations
47
+
48
+ n=75, single prompt, one run (no seeds/significance testing), 256-token block
49
+ caps passage length, tokens/sec for DiffusionGemma is computed over denoising
50
+ the whole block, DiffusionGemma is experimental and one day old at benchmark
51
+ time. Live demo examples are from ICDAR2019 post-OCR (CC-BY-4.0) because
52
+ BLN600's CC-BY-NC license doesn't permit redistribution here; benchmark passage
53
+ texts are likewise not republished — only per-passage metrics.
app.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DiffusionGemma vs Gemma-4 on post-OCR correction — ZeroGPU comparison Space.
2
+
3
+ Side-by-side correction of 19th-century English newspaper OCR by an
4
+ experimental block-diffusion LLM (google/diffusiongemma-26B-A4B-it) and an
5
+ autoregressive baseline (google/gemma-4-E4B-it).
6
+ """
7
+
8
+ import json
9
+ import os
10
+ import time
11
+ from pathlib import Path
12
+
13
+ import gradio as gr
14
+ import spaces
15
+ import torch
16
+ from transformers import (
17
+ AutoModelForMultimodalLM,
18
+ AutoProcessor,
19
+ DiffusionGemmaForBlockDiffusion,
20
+ TextDiffusionStreamer,
21
+ )
22
+
23
+ from diff_utils import COLOR_MAP, diff_highlight
24
+
25
+ # Keep in sync with benchmark.py PROMPT_TEMPLATE — the benchmark numbers in the
26
+ # results tab were produced with exactly this prompt.
27
+ PROMPT_TEMPLATE = """\
28
+ Correct the OCR errors in the following text from a 19th-century English newspaper.
29
+ Fix only recognition errors (wrong, missing, or extra characters). Do not modernise \
30
+ spelling, do not rephrase, and do not add or remove content. Preserve the original \
31
+ punctuation unless it is clearly an OCR error.
32
+ Output only the corrected text, with no commentary or preamble.
33
+
34
+ OCR text:
35
+ {ocr}"""
36
+
37
+ MAX_INPUT_CHARS = 1200 # roughly the 220-token benchmark cap
38
+
39
+
40
+ def model_path(volume_path: str, model_id: str) -> str:
41
+ """Prefer a mounted hf:// volume (see `hf spaces volumes`) over a download."""
42
+ return volume_path if os.path.isdir(volume_path) else model_id
43
+
44
+
45
+ DG_PATH = model_path("/models/dg", "google/diffusiongemma-26B-A4B-it")
46
+ G4_PATH = model_path("/models/gemma", "google/gemma-4-E4B-it")
47
+
48
+ print(f"loading DiffusionGemma from {DG_PATH} ...")
49
+ dg_processor = AutoProcessor.from_pretrained(DG_PATH)
50
+ dg_model = DiffusionGemmaForBlockDiffusion.from_pretrained(DG_PATH, dtype=torch.bfloat16).to("cuda")
51
+ print(f"loading Gemma-4 from {G4_PATH} ...")
52
+ g4_processor = AutoProcessor.from_pretrained(G4_PATH)
53
+ g4_model = AutoModelForMultimodalLM.from_pretrained(G4_PATH, dtype=torch.bfloat16).to("cuda")
54
+ print("models loaded")
55
+
56
+
57
+ STOP_MARKERS = ("<turn|>", "<eos>", "<end_of_turn>", "<pad>")
58
+
59
+
60
+ def extract_answer(raw: str) -> str:
61
+ """DiffusionGemma's block looks like `<|channel>thought\\n<channel|>ANSWER<turn|>...`
62
+ even with thinking off — the answer is the text after the last `<channel|>`.
63
+ Gemma-4 emits plain text; we just cut at the first stop marker."""
64
+ stops = [i for m in STOP_MARKERS if (i := raw.find(m)) != -1]
65
+ if stops:
66
+ raw = raw[: min(stops)]
67
+ if "<channel|>" in raw:
68
+ raw = raw.rpartition("<channel|>")[2]
69
+ return raw.strip()
70
+
71
+
72
+ class SnapshotStreamer(TextDiffusionStreamer):
73
+ """Captures the decoded canvas at each denoising step; suppresses the
74
+ parent's ANSI console printing."""
75
+
76
+ def __init__(self, tokenizer):
77
+ super().__init__(tokenizer=tokenizer)
78
+ self.tok = tokenizer
79
+ self.snapshots: list[str] = []
80
+
81
+ def put_draft(self, value, **kwargs):
82
+ try:
83
+ ids = value[0] if value.ndim > 1 else value
84
+ self.snapshots.append(self.tok.decode(ids, skip_special_tokens=False))
85
+ except Exception:
86
+ pass
87
+
88
+ def put(self, value):
89
+ pass
90
+
91
+ def end(self):
92
+ pass
93
+
94
+
95
+ def _prepare_inputs(processor, model, ocr_text: str):
96
+ message = [{"role": "user", "content": PROMPT_TEMPLATE.format(ocr=ocr_text.strip())}]
97
+ return processor.apply_chat_template(
98
+ message, tokenize=True, add_generation_prompt=True, return_dict=True, return_tensors="pt"
99
+ ).to(model.device)
100
+
101
+
102
+ def _decode_generated(processor, output, input_len) -> str:
103
+ # DiffusionGemma returns a DiffusionGemmaGenerationOutput whose .sequences
104
+ # includes the prompt (like AR generate, which returns a plain tensor).
105
+ seq = output.sequences if hasattr(output, "sequences") else output
106
+ generated = seq[0][input_len:] if seq.shape[-1] > input_len else seq[0]
107
+ raw = processor.tokenizer.decode(generated, skip_special_tokens=False)
108
+ return extract_answer(raw)
109
+
110
+
111
+ # size="xlarge" (96GB) on both: total module-level CUDA state is ~68GB bf16,
112
+ # which exceeds the default 48GB ZeroGPU slice.
113
+ @spaces.GPU(duration=120, size="xlarge")
114
+ def run_diffusiongemma(ocr_text: str):
115
+ inputs = _prepare_inputs(dg_processor, dg_model, ocr_text)
116
+ streamer = SnapshotStreamer(dg_processor.tokenizer)
117
+ t0 = time.perf_counter()
118
+ output = dg_model.generate(**inputs, max_new_tokens=256, streamer=streamer)
119
+ torch.cuda.synchronize()
120
+ seconds = time.perf_counter() - t0
121
+ text = _decode_generated(dg_processor, output, inputs["input_ids"].shape[-1])
122
+ n_tokens = len(dg_processor.tokenizer(text)["input_ids"])
123
+ timing = (
124
+ f"**{seconds:.1f}s** · ~{n_tokens / seconds:.0f} tok/s · "
125
+ f"{len(streamer.snapshots)} denoising steps"
126
+ )
127
+ return text, diff_highlight(ocr_text, text), timing, streamer.snapshots
128
+
129
+
130
+ @spaces.GPU(duration=60, size="xlarge")
131
+ def run_gemma4(ocr_text: str):
132
+ inputs = _prepare_inputs(g4_processor, g4_model, ocr_text)
133
+ t0 = time.perf_counter()
134
+ output = g4_model.generate(**inputs, max_new_tokens=256, do_sample=False)
135
+ torch.cuda.synchronize()
136
+ seconds = time.perf_counter() - t0
137
+ text = _decode_generated(g4_processor, output, inputs["input_ids"].shape[-1])
138
+ n_tokens = len(g4_processor.tokenizer(text)["input_ids"])
139
+ timing = f"**{seconds:.1f}s** · ~{n_tokens / seconds:.0f} tok/s (greedy)"
140
+ return text, diff_highlight(ocr_text, text), timing
141
+
142
+
143
+ # ---------------------------------------------------------------- UI data
144
+
145
+ examples: list[dict] = []
146
+ examples_path = Path("examples.json")
147
+ if examples_path.exists():
148
+ examples = json.loads(examples_path.read_text())
149
+ example_choices = {e["label"]: e["ocr_input"] for e in examples}
150
+
151
+ summary_md = "*Benchmark results pending — see the repo for methodology.*"
152
+ if Path("results/summary.md").exists():
153
+ summary_md = Path("results/summary.md").read_text()
154
+
155
+ per_passage_rows = []
156
+ if Path("results/per_passage_metrics.jsonl").exists():
157
+ per_passage_rows = [
158
+ json.loads(line)
159
+ for line in Path("results/per_passage_metrics.jsonl").read_text().splitlines()
160
+ if line.strip()
161
+ ]
162
+
163
+
164
+ def load_example(label: str) -> str:
165
+ return example_choices.get(label, "")
166
+
167
+
168
+ def check_length(text: str):
169
+ if len(text) > MAX_INPUT_CHARS:
170
+ raise gr.Error(
171
+ f"Input too long ({len(text)} chars). DiffusionGemma generates a single "
172
+ f"256-token block, so inputs are capped at ~{MAX_INPUT_CHARS} characters."
173
+ )
174
+ return text
175
+
176
+
177
+ def update_snapshot(snapshots: list[str], step: int) -> str:
178
+ if not snapshots:
179
+ return ""
180
+ return snapshots[min(int(step), len(snapshots) - 1)]
181
+
182
+
183
+ with gr.Blocks(title="DiffusionGemma vs Gemma-4: post-OCR correction") as demo:
184
+ gr.Markdown(
185
+ "# DiffusionGemma vs Gemma-4: post-OCR correction\n"
186
+ "Compare Google's **experimental diffusion LLM** "
187
+ "([google/diffusiongemma-26B-A4B-it](https://huggingface.co/google/diffusiongemma-26B-A4B-it), "
188
+ "26B MoE / 3.8B active, released 2026-06-10) against an autoregressive baseline "
189
+ "([google/gemma-4-E4B-it](https://huggingface.co/google/gemma-4-E4B-it)) on correcting "
190
+ "19th-century English newspaper OCR. Both run in bf16. Highlights show what each model "
191
+ "**changed relative to the OCR input** (yellow = changed, green = added, red ⌫ = deleted)."
192
+ )
193
+
194
+ with gr.Tab("Live comparison"):
195
+ with gr.Row():
196
+ example_dd = gr.Dropdown(
197
+ label="Example passages (ICDAR2019 post-OCR, CC-BY-4.0)",
198
+ choices=list(example_choices),
199
+ value=None,
200
+ scale=2,
201
+ )
202
+ ocr_box = gr.Textbox(
203
+ label="Noisy OCR text",
204
+ lines=6,
205
+ value=next(iter(example_choices.values()), ""),
206
+ max_length=MAX_INPUT_CHARS,
207
+ )
208
+ run_btn = gr.Button("Run both models", variant="primary")
209
+ with gr.Row():
210
+ with gr.Column():
211
+ gr.Markdown("### DiffusionGemma 26B-A4B (diffusion)")
212
+ dg_timing = gr.Markdown("")
213
+ dg_diff = gr.HighlightedText(
214
+ label="Output (diff vs input)", color_map=COLOR_MAP, combine_adjacent=True
215
+ )
216
+ with gr.Accordion("Raw output", open=False):
217
+ dg_raw = gr.Textbox(lines=6, show_label=False)
218
+ with gr.Column():
219
+ gr.Markdown("### Gemma-4-E4B (autoregressive)")
220
+ g4_timing = gr.Markdown("")
221
+ g4_diff = gr.HighlightedText(
222
+ label="Output (diff vs input)", color_map=COLOR_MAP, combine_adjacent=True
223
+ )
224
+ with gr.Accordion("Raw output", open=False):
225
+ g4_raw = gr.Textbox(lines=6, show_label=False)
226
+
227
+ snapshots_state = gr.State([])
228
+ example_dd.change(load_example, example_dd, ocr_box)
229
+ run_btn.click(check_length, ocr_box, ocr_box).success(
230
+ run_diffusiongemma, ocr_box, [dg_raw, dg_diff, dg_timing, snapshots_state]
231
+ ).then(run_gemma4, ocr_box, [g4_raw, g4_diff, g4_timing])
232
+
233
+ with gr.Tab("Denoising progression"):
234
+ gr.Markdown(
235
+ "DiffusionGemma starts from a random 256-token canvas and iteratively denoises it. "
236
+ "Run a comparison first, then scrub through the intermediate canvas states."
237
+ )
238
+ step_slider = gr.Slider(0, 47, step=1, value=0, label="Denoising step")
239
+ snapshot_box = gr.Textbox(lines=10, label="Canvas at step", interactive=False)
240
+ step_slider.change(update_snapshot, [snapshots_state, step_slider], snapshot_box)
241
+ snapshots_state.change(
242
+ lambda s: (gr.Slider(0, max(len(s) - 1, 1), step=1, value=0), update_snapshot(s, 0)),
243
+ snapshots_state,
244
+ [step_slider, snapshot_box],
245
+ )
246
+
247
+ with gr.Tab("Benchmark results"):
248
+ gr.Markdown(summary_md)
249
+ if per_passage_rows:
250
+ gr.Markdown("### Per-passage metrics (BLN600, n=75)")
251
+ gr.DataFrame(
252
+ value=[[row.get(k) for k in per_passage_rows[0]] for row in per_passage_rows],
253
+ headers=list(per_passage_rows[0]),
254
+ interactive=False,
255
+ )
256
+ gr.Markdown(
257
+ "Benchmark texts come from [BLN600](https://doi.org/10.15131/shef.data.25439023) "
258
+ "(CC-BY-NC-4.0), so passage texts are not redistributed here — only metrics. "
259
+ "See the Space README for methodology and limitations."
260
+ )
261
+
262
+ demo.launch()
diff_utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Token-level diff between OCR input and model output for gr.HighlightedText."""
2
+
3
+ import difflib
4
+ import re
5
+
6
+ COLOR_MAP = {"changed": "yellow", "added": "green", "removed": "red"}
7
+
8
+
9
+ def diff_highlight(input_text: str, output_text: str) -> list[tuple[str, str | None]]:
10
+ """Segments of `output_text` labelled by how they differ from `input_text`.
11
+
12
+ Word + whitespace tokenization (lossless), so highlights align with what
13
+ the reader sees. Deleted input text is marked with a small marker segment.
14
+ """
15
+ tokens_in = re.findall(r"\S+|\s+", input_text)
16
+ tokens_out = re.findall(r"\S+|\s+", output_text)
17
+ sm = difflib.SequenceMatcher(None, tokens_in, tokens_out, autojunk=False)
18
+ segments: list[tuple[str, str | None]] = []
19
+ for op, i1, i2, j1, j2 in sm.get_opcodes():
20
+ if op == "equal":
21
+ segments.append(("".join(tokens_out[j1:j2]), None))
22
+ elif op == "replace":
23
+ segments.append(("".join(tokens_out[j1:j2]), "changed"))
24
+ elif op == "insert":
25
+ segments.append(("".join(tokens_out[j1:j2]), "added"))
26
+ elif op == "delete":
27
+ segments.append((" ⌫ ", "removed"))
28
+ return segments
29
+
30
+
31
+ if __name__ == "__main__":
32
+ segs = diff_highlight("the qvick brown fox jumps", "the quick brown fox")
33
+ print(segs)
34
+ assert ("the ", None) in segs or segs[0][1] is None
35
+ assert any(label == "changed" for _, label in segs)
36
+ assert any(label == "removed" for _, label in segs)
37
+ out = "".join(s for s, label in segs if label != "removed")
38
+ assert out == "the quick brown fox"
39
+ print("diff_utils ok")
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers>=5.11,<6
2
+ accelerate
3
+ pillow
4
+ torchvision