chitrark committed on
Commit
a82af97
·
verified ·
1 Parent(s): ce86bd1

refactor: replace broken olmocr convert_files with VLM-based OCR pipeline

Browse files
Files changed (1) hide show
  1. app.py +107 -36
app.py CHANGED
@@ -1,57 +1,128 @@
1
- import tempfile
2
- import gradio as gr
3
 
4
- # Updated import: convert_files is exposed from the top-level package
5
- from olmocr.runner import convert_files # replace with the real module name
 
6
 
 
7
 
 
8
  MODEL_NAME = "allenai/olmOCR-2-7B-1025"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
 
11
- def ocr(file_obj):
12
  if file_obj is None:
13
- return "No file uploaded."
14
-
15
- in_path = file_obj.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- with tempfile.TemporaryDirectory() as tmpdir:
18
- results = convert_files(
19
- inputs=[in_path],
20
- output_dir=tmpdir,
21
- model_name=MODEL_NAME,
 
 
 
 
 
 
 
 
 
 
 
22
  )
23
 
24
- if not results:
25
- return "No output."
26
-
27
- r0 = results[0]
28
-
29
- # Try direct text
30
- text = getattr(r0, "text", None)
31
-
32
- # Fallback: read from output file
33
- if not text:
34
- out_path = getattr(r0, "output_path", None)
35
- if out_path:
36
- with open(out_path, "r", encoding="utf-8") as f:
37
- text = f.read()
38
 
39
- return (text or "No text extracted.").strip()
40
 
41
 
42
- with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
43
  gr.Markdown(
44
- "# BookReader OCR API (olmOCR2)\n"
45
- "Upload an image or PDF → get extracted text.\n\n"
46
- "**API endpoint:** `/ocr`"
47
  )
48
 
49
- upload = gr.File(label="Upload PDF or image", file_count="single")
50
- output = gr.Textbox(label="Extracted text", lines=18)
51
 
52
  gr.Button("Run OCR").click(
53
- fn=ocr,
54
- inputs=[upload],
55
  outputs=[output],
56
  api_name="/ocr",
57
  )
 
1
+ import base64
2
+ from io import BytesIO
3
 
4
+ import torch
5
+ from PIL import Image
6
+ import gradio as gr
7
 
8
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
9
 
10
+ # Vision-language model used by olmOCR-2
11
  MODEL_NAME = "allenai/olmOCR-2-7B-1025"
12
+ PROCESSOR_NAME = "Qwen/Qwen2-VL-7B-Instruct"
13
+
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
+ print("Loading model on", device)
17
+ model = Qwen2VLForConditionalGeneration.from_pretrained(
18
+ MODEL_NAME,
19
+ torch_dtype=torch.bfloat16,
20
+ ).to(device).eval()
21
+
22
+ processor = AutoProcessor.from_pretrained(PROCESSOR_NAME)
23
+
24
+
25
+ def build_image_prompt(width: int, height: int) -> str:
26
+ """
27
+ Minimal 'document anchoring' style prompt for a single image on a page.
28
+ This follows the structure described in olmOCR docs/blogs:
29
+ a page-dimensions line with the image box, enclosed between 'RAW_TEXT_START' and 'RAW_TEXT_END'.
30
+ """
31
+ prompt = (
32
+ "Below is the image of one page of a document, as well as some raw textual "
33
+ "content that was previously extracted for it. "
34
+ "Just return the plain text representation of this document as if you "
35
+ "were reading it naturally. Do not hallucinate.\n"
36
+ "RAW_TEXT_START\n"
37
+ f"Page dimensions: {width:.1f}x{height:.1f} [Image 0x0 to {width:.1f}x{height:.1f}]\n"
38
+ "RAW_TEXT_END"
39
+ )
40
+ return prompt
41
 
42
 
43
+ def ocr_image(file_obj: gr.File):
44
  if file_obj is None:
45
+ return "No image uploaded."
46
+
47
+ # Load the uploaded image
48
+ img = Image.open(file_obj).convert("RGB")
49
+
50
+ # Downscale (never upscale) so the longest side is at most 1024 px, for performance/quality
51
+ max_side = 1024
52
+ w, h = img.size
53
+ scale = min(max_side / max(w, h), 1.0)
54
+ if scale < 1.0:
55
+ img = img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
56
+ w, h = img.size
57
+
58
+ # Encode the image as a base64 PNG data URL for the 'image_url' entry of the chat message
59
+ buf = BytesIO()
60
+ img.save(buf, format="PNG")
61
+ image_bytes = buf.getvalue()
62
+ image_b64 = base64.b64encode(image_bytes).decode("utf-8")
63
+
64
+ # Build prompt for this image
65
+ prompt = build_image_prompt(w, h)
66
+
67
+ messages = [
68
+ {
69
+ "role": "user",
70
+ "content": [
71
+ {"type": "text", "text": prompt},
72
+ {
73
+ "type": "image_url",
74
+ "image_url": {"url": f"data:image/png;base64,{image_b64}"},
75
+ },
76
+ ],
77
+ }
78
+ ]
79
+
80
+ # Apply chat template and preprocess
81
+ text = processor.apply_chat_template(
82
+ messages,
83
+ tokenize=False,
84
+ add_generation_prompt=True,
85
+ )
86
 
87
+ inputs = processor(
88
+ text=[text],
89
+ images=[img],
90
+ padding=True,
91
+ return_tensors="pt",
92
+ )
93
+ inputs = {k: v.to(device) for k, v in inputs.items()}
94
+
95
+ # Generate output
96
+ with torch.no_grad():
97
+ output = model.generate(
98
+ **inputs,
99
+ temperature=0.6,
100
+ max_new_tokens=512,
101
+ num_return_sequences=1,
102
+ do_sample=True,
103
  )
104
 
105
+ prompt_len = inputs["input_ids"].shape[1]
106
+ new_tokens = output[:, prompt_len:]
107
+ text_output = processor.tokenizer.batch_decode(
108
+ new_tokens, skip_special_tokens=True
109
+ )
 
 
 
 
 
 
 
 
 
110
 
111
+ return text_output[0].strip() if text_output else "No text extracted."
112
 
113
 
114
+ with gr.Blocks(title="olmOCR‑2 Image OCR") as demo:
115
  gr.Markdown(
116
+ "# olmOCR‑2 Image OCR\n"
117
+ "Upload an image and get extracted text using the olmOCR‑2‑7B model."
 
118
  )
119
 
120
+ image_input = gr.Image(type="pil", label="Upload image")
121
+ output = gr.Textbox(label="Extracted text", lines=20)
122
 
123
  gr.Button("Run OCR").click(
124
+ fn=ocr_image,
125
+ inputs=[image_input],
126
  outputs=[output],
127
  api_name="/ocr",
128
  )