# Hugging Face Space — running on ZeroGPU
| """DiffusionGemma vs Gemma-4 on post-OCR correction — ZeroGPU comparison Space. | |
| gradio.Server pattern: custom HTML frontend (index.html) + Gradio queuing | |
| backend. Side-by-side correction of 19th-century English newspaper OCR by an | |
| experimental block-diffusion LLM (google/diffusiongemma-26B-A4B-it) and an | |
| autoregressive baseline (google/gemma-4-E4B-it). | |
| """ | |
| import difflib | |
| import json | |
| import os | |
| import re | |
| import time | |
| from pathlib import Path | |
| import spaces | |
| import torch | |
| from fastapi.responses import HTMLResponse, JSONResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from gradio import Server | |
| from transformers import ( | |
| AutoModelForMultimodalLM, | |
| AutoProcessor, | |
| DiffusionGemmaForBlockDiffusion, | |
| TextDiffusionStreamer, | |
| ) | |
# Directory containing this file; index.html, examples*.json, results/ and
# images/ are all resolved relative to it.
HERE = Path(__file__).resolve().parent
# Keep in sync with benchmark.py PROMPT_TEMPLATE — the benchmark numbers in the
# results tab were produced with exactly this prompt.
# NOTE(review): this view of the file has lost its original whitespace; confirm
# any blank lines between the prompt paragraphs against benchmark.py.
PROMPT_TEMPLATE = """\
Correct the OCR errors in the following text from a 19th-century English newspaper.
Fix only recognition errors (wrong, missing, or extra characters). Do not modernise \
spelling, do not rephrase, and do not add or remove content. Preserve the original \
punctuation unless it is clearly an OCR error.
Output only the corrected text, with no commentary or preamble.
OCR text:
{ocr}"""
# Upper bound on the raw input length accepted by _validate().
MAX_INPUT_CHARS = 1200  # roughly the 220-token benchmark cap
# Markers at which extract_answer() truncates decoded model output.
STOP_MARKERS = ("<turn|>", "<eos>", "<end_of_turn>", "<pad>")
def model_path(volume_path: str, model_id: str) -> str:
    """Pick the location a checkpoint is loaded from.

    Returns ``volume_path`` when it is an existing directory (a mounted hf://
    volume) and volumes are enabled; otherwise returns ``model_id``. Setting
    the environment variable USE_VOLUMES=0 disables the volume lookup (FUSE
    reads can be slower for safetensors loading than a fresh download to
    local disk).
    """
    volumes_enabled = os.environ.get("USE_VOLUMES", "1") != "0"
    if volumes_enabled and os.path.isdir(volume_path):
        return volume_path
    return model_id
# Resolve checkpoint locations once at import time: the mounted volume when
# model_path() finds it, otherwise the Hub repo id.
DG_PATH = model_path("/models/dg", "google/diffusiongemma-26B-A4B-it")
G4_PATH = model_path("/models/gemma", "google/gemma-4-E4B-it")

# Both models are loaded in bfloat16 and moved to CUDA at import time; the
# wall-clock load time of each is printed.
t0 = time.perf_counter()
print(f"loading DiffusionGemma from {DG_PATH} ...")
dg_processor = AutoProcessor.from_pretrained(DG_PATH)
dg_model = DiffusionGemmaForBlockDiffusion.from_pretrained(DG_PATH, dtype=torch.bfloat16).to("cuda")
print(f"DiffusionGemma loaded in {time.perf_counter() - t0:.0f}s")

t0 = time.perf_counter()  # restart the timer for the second model
print(f"loading Gemma-4 from {G4_PATH} ...")
g4_processor = AutoProcessor.from_pretrained(G4_PATH)
g4_model = AutoModelForMultimodalLM.from_pretrained(G4_PATH, dtype=torch.bfloat16).to("cuda")
print(f"Gemma-4 loaded in {time.perf_counter() - t0:.0f}s")
| # ---------------------------------------------------------------- text utils | |
def extract_answer(raw: str) -> str:
    """Reduce decoded model output to the bare answer text.

    DiffusionGemma's block looks like `<|channel>thought\\n<channel|>ANSWER<turn|>...`
    even with thinking off, so the answer is whatever follows the last
    `<channel|>`. Gemma-4 emits plain text, for which only the cut at the
    earliest stop marker applies. The result is whitespace-stripped.
    """
    # Find the earliest occurrence of any stop marker; default is "no cut".
    cut = len(raw)
    for marker in STOP_MARKERS:
        pos = raw.find(marker)
        if pos != -1 and pos < cut:
            cut = pos
    head = raw[:cut]
    # rsplit leaves the string untouched when the separator is absent.
    answer = head.rsplit("<channel|>", 1)[-1]
    return answer.strip()
def diff_segments(input_text: str, output_text: str) -> list[dict]:
    """Diff ``output_text`` against ``input_text`` at word/whitespace granularity.

    Both strings are split into alternating runs of non-whitespace and
    whitespace, then aligned with difflib. Returns a list of
    ``{"text", "op"}`` dicts with ``op`` in {same, changed, added, removed};
    "removed" segments carry the input's text, all others the output's.
    Rendered by the frontend.
    """
    splitter = re.compile(r"\S+|\s+")
    old_tokens = splitter.findall(input_text)
    new_tokens = splitter.findall(output_text)
    matcher = difflib.SequenceMatcher(None, old_tokens, new_tokens, autojunk=False)

    # Opcode tags whose text comes from the output side, and their labels.
    output_side_labels = {"equal": "same", "replace": "changed", "insert": "added"}

    result = []
    for tag, a_lo, a_hi, b_lo, b_hi in matcher.get_opcodes():
        if tag == "delete":
            pieces, label = old_tokens[a_lo:a_hi], "removed"
        else:
            pieces, label = new_tokens[b_lo:b_hi], output_side_labels[tag]
        result.append({"text": "".join(pieces), "op": label})
    return result
class SnapshotStreamer(TextDiffusionStreamer):
    """Streamer that records the decoded canvas at each denoising step.

    Every ``put_draft`` call appends one decoded string to ``snapshots``.
    ``put`` and ``end`` are overridden as no-ops, which suppresses the
    parent's ANSI console printing.
    """

    def __init__(self, tokenizer):
        super().__init__(tokenizer=tokenizer)
        self.tok = tokenizer
        self.snapshots: list[str] = []

    def put_draft(self, value, **kwargs):
        # Best-effort capture: a failure while decoding a draft is ignored
        # rather than allowed to propagate out of the streamer.
        try:
            # Batched drafts: keep only the first row.
            first_row = value if value.ndim <= 1 else value[0]
            decoded = self.tok.decode(first_row, skip_special_tokens=False)
            self.snapshots.append(decoded)
        except Exception:
            pass

    def put(self, value):
        return None

    def end(self):
        return None
def _prepare_inputs(processor, model, ocr_text: str):
    """Build tokenized chat inputs for one correction request.

    The stripped OCR text is inserted into PROMPT_TEMPLATE, wrapped as a
    single user turn, run through the processor's chat template (with the
    generation prompt appended, returned as a dict of PyTorch tensors) and
    moved to the model's device.
    """
    prompt = PROMPT_TEMPLATE.format(ocr=ocr_text.strip())
    chat = [{"role": "user", "content": prompt}]
    encoded = processor.apply_chat_template(
        chat,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    return encoded.to(model.device)
def _decode_generated(processor, output, input_len) -> str:
    """Decode the newly generated part of ``output`` and extract the answer.

    ``output`` is either an object with a ``.sequences`` attribute or the
    sequence tensor itself. The first ``input_len`` positions (the prompt)
    are dropped when the sequence is longer than the prompt; otherwise the
    whole first row is decoded.
    """
    # DiffusionGemma returns a DiffusionGemmaGenerationOutput whose .sequences
    # includes the prompt (like AR generate, which returns a plain tensor).
    sequences = getattr(output, "sequences", output)
    first_row = sequences[0]
    if sequences.shape[-1] > input_len:
        first_row = first_row[input_len:]
    decoded = processor.tokenizer.decode(first_row, skip_special_tokens=False)
    return extract_answer(decoded)
def _validate(ocr_text: str) -> str | None:
    """Check a request's OCR text; return an error message, or None if it is usable.

    Input is rejected when it is empty or whitespace-only, or when its raw
    length exceeds MAX_INPUT_CHARS.
    """
    is_blank = not ocr_text or not ocr_text.strip()
    if is_blank:
        return "Empty input."
    n_chars = len(ocr_text)
    if n_chars <= MAX_INPUT_CHARS:
        return None
    return (
        f"Input too long ({n_chars} chars). DiffusionGemma generates a single "
        f"256-token block, so inputs are capped at ~{MAX_INPUT_CHARS} characters."
    )
| # ---------------------------------------------------------------- API | |
# The gradio Server instance; the /static mount and app.launch() at the bottom
# of the file operate on it.
# NOTE(review): in this view no decorator or call registers the handler
# functions below (run_*, homepage, get_*) with `app`, and the imported
# `spaces` and `HTMLResponse` names are unused — confirm the registrations
# exist in the deployed file.
app = Server()
def run_diffusiongemma(ocr_text: str, canvas_init: bool = False, gold: str = "") -> dict:
    """Correct OCR text with DiffusionGemma.

    Args:
        ocr_text: The OCR passage to correct; validated by ``_validate``.
        canvas_init: When True, seed the first denoising canvas with the OCR
            text itself (experimental — under-corrects; see the results tab)
            instead of random noise.
        gold: Optional gold transcription (demo examples). When non-empty, a
            diff of the output against it is returned as ``diff_gold``.

    Returns:
        ``{"error": msg}`` when validation fails. Otherwise a dict with
        ``text``, ``diff``, ``diff_gold``, ``seconds``, ``tokens_per_second``,
        ``denoising_steps``, ``snapshots``, ``canvas_init`` and ``error=None``.
    """
    if err := _validate(ocr_text):
        return {"error": err}
    inputs = _prepare_inputs(dg_processor, dg_model, ocr_text)
    streamer = SnapshotStreamer(dg_processor.tokenizer)
    gen_kwargs: dict = {"max_new_tokens": 256, "streamer": streamer}
    if canvas_init:
        # Canvas = OCR token ids, truncated to the canvas length and padded
        # on the right with uniformly random token ids.
        canvas_length = getattr(dg_model.generation_config, "canvas_length", None) or 256
        ids = dg_processor.tokenizer(ocr_text, add_special_tokens=False)["input_ids"]
        ids = ids[:canvas_length]
        vocab = dg_model.config.text_config.vocab_size
        pad = torch.randint(vocab, (canvas_length - len(ids),))
        canvas = torch.cat([torch.tensor(ids, dtype=torch.long), pad])
        gen_kwargs["decoder_input_ids"] = canvas.unsqueeze(0).to(dg_model.device)
    t0 = time.perf_counter()
    output = dg_model.generate(**inputs, **gen_kwargs)
    torch.cuda.synchronize()  # make sure queued GPU work is included in the timing
    seconds = time.perf_counter() - t0
    text = _decode_generated(dg_processor, output, inputs["input_ids"].shape[-1])
    return {
        **_build_result(dg_processor, ocr_text, gold, text, seconds),
        "denoising_steps": len(streamer.snapshots),
        "snapshots": [extract_answer(s) for s in streamer.snapshots],
        "canvas_init": canvas_init,
        "error": None,
    }


def run_gemma4(ocr_text: str, gold: str = "") -> dict:
    """Correct OCR text with the autoregressive Gemma-4-E4B baseline (greedy).

    Args:
        ocr_text: The OCR passage to correct; validated by ``_validate``.
        gold: Optional gold transcription. When non-empty, a diff of the
            output against it is returned as ``diff_gold``.

    Returns:
        ``{"error": msg}`` when validation fails. Otherwise a dict with
        ``text``, ``diff``, ``diff_gold``, ``seconds``, ``tokens_per_second``
        and ``error=None``.
    """
    if err := _validate(ocr_text):
        return {"error": err}
    inputs = _prepare_inputs(g4_processor, g4_model, ocr_text)
    t0 = time.perf_counter()
    output = g4_model.generate(**inputs, max_new_tokens=256, do_sample=False)
    torch.cuda.synchronize()  # make sure queued GPU work is included in the timing
    seconds = time.perf_counter() - t0
    text = _decode_generated(g4_processor, output, inputs["input_ids"].shape[-1])
    return {
        **_build_result(g4_processor, ocr_text, gold, text, seconds),
        "error": None,
    }


def _build_result(processor, ocr_text: str, gold: str | None, text: str, seconds: float) -> dict:
    """Assemble the result fields shared by both model endpoints.

    Throughput is computed by re-tokenizing the cleaned answer WITHOUT special
    tokens, so a tokenizer-added BOS/EOS does not inflate the count. Both
    endpoints go through this helper so their tokens/second are comparable.
    """
    gold = (gold or "").strip()  # tolerate a null gold from the client
    n_tokens = len(processor.tokenizer(text, add_special_tokens=False)["input_ids"])
    return {
        "text": text,
        "diff": diff_segments(ocr_text.strip(), text),
        "diff_gold": diff_segments(gold, text) if gold else None,
        "seconds": round(seconds, 2),
        "tokens_per_second": round(n_tokens / seconds, 1) if seconds > 0 else 0.0,
    }
| # ---------------------------------------------------------------- static data | |
async def homepage():
    """Return the contents of index.html (located next to this file), read as UTF-8."""
    index_file = HERE / "index.html"
    return index_file.read_text(encoding="utf-8")
async def get_examples():
    """Return the demo examples, enriched with cached model outputs and diffs.

    Reads examples.json and, when present, examples_cached.json (both next to
    this file, decoded as UTF-8 like index.html). For each example:

    * ``cached`` is the per-model output dict from the cache (``_raw`` keys
      removed), or None when the example has no cache entry;
    * ``gold`` is the cached gold transcription, or "";
    * each cached output gets a ``diff`` against the example's stripped
      ``ocr_input`` and, when a gold exists, a ``diff_gold`` against it.

    Returns:
        A JSONResponse wrapping the list of example dicts.
    """
    examples = json.loads((HERE / "examples.json").read_text(encoding="utf-8"))
    cached, golds = {}, {}
    cached_path = HERE / "examples_cached.json"
    if cached_path.exists():
        for entry in json.loads(cached_path.read_text(encoding="utf-8")):
            # Drop the raw decoded model output; it is not sent to the client.
            for out in entry["output"].values():
                out.pop("_raw", None)
            cached[entry["id"]] = entry["output"]
            golds[entry["id"]] = entry.get("gold", "")
    for example in examples:
        example["cached"] = cached.get(example["id"])
        example["gold"] = golds.get(example["id"], "")
        if not example["cached"]:
            continue
        ocr_input = example["ocr_input"].strip()
        for out in example["cached"].values():
            out["diff"] = diff_segments(ocr_input, out["text"])
            if example["gold"]:
                out["diff_gold"] = diff_segments(example["gold"].strip(), out["text"])
    return JSONResponse(examples)
async def get_results():
    """Return the benchmark summary and per-passage metrics.

    Reads results/summary.md and results/per_passage_metrics.jsonl (next to
    this file), both decoded explicitly as UTF-8 rather than with the
    locale's default encoding.

    Returns:
        A JSONResponse with ``summary_md`` (the markdown text) and
        ``per_passage`` (one parsed JSON object per non-blank JSONL line).
    """
    results_dir = HERE / "results"
    summary = (results_dir / "summary.md").read_text(encoding="utf-8")
    metrics_text = (results_dir / "per_passage_metrics.jsonl").read_text(encoding="utf-8")
    # Split on "\n" only: str.splitlines() also breaks on U+2028/U+2029/U+0085,
    # which may legally appear unescaped inside a JSON string and would cut a
    # record in half. A trailing "\r" is whitespace and is ignored by json.loads.
    rows = [json.loads(line) for line in metrics_text.split("\n") if line.strip()]
    return JSONResponse({"summary_md": summary, "per_passage": rows})
# Serve the images/ directory (next to this file) under /static; the mount is
# skipped when the directory does not exist.
_images_dir = HERE / "images"
if _images_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(_images_dir)), name="static")

# Launch the server only when this file is executed as a script.
if __name__ == "__main__":
    app.launch(show_error=True)