chitrark committed on
Commit
778e3eb
·
verified ·
1 Parent(s): 097f30f

updated ocr path

Browse files
Files changed (1) hide show
  1. app.py +56 -33
app.py CHANGED
@@ -10,23 +10,33 @@ from PIL import Image
10
  import gradio as gr
11
  from transformers import AutoProcessor, AutoModelForVision2Seq
12
 
13
- # Suppress ALL startup noise BEFORE any imports
 
 
14
  os.environ["OMP_NUM_THREADS"] = "1"
15
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
16
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
  warnings.filterwarnings("ignore")
18
 
 
 
 
19
  MODEL_ID = "allenai/olmOCR-2-7B-1025"
 
20
  processor = None
21
  model = None
22
 
23
 
24
  def load_model():
 
25
  global processor, model
26
  if processor is not None and model is not None:
27
  return
28
 
29
- processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
30
  model = AutoModelForVision2Seq.from_pretrained(
31
  MODEL_ID,
32
  dtype=torch.float16,
@@ -34,9 +44,13 @@ def load_model():
34
  low_cpu_mem_usage=True,
35
  trust_remote_code=True,
36
  ).eval()
37
- print("✅ Model loaded successfully!")
 
38
 
39
 
 
 
 
40
  def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
41
  w, h = img.size
42
  m = max(w, h)
@@ -52,16 +66,15 @@ def build_prompt(width: int, height: int) -> str:
52
  "Return ONLY the extracted text (no explanations, no markdown).\n"
53
  "Do not hallucinate.\n"
54
  "RAW_TEXT_START\n"
55
- f"Page dimensions: {width:.1f}x{height:.1f} [Image 0x0 to {width:.1f}x{height:.1f}]\n"
 
56
  "RAW_TEXT_END"
57
  )
58
 
59
 
60
  def _coerce_to_pil(img: Union[Image.Image, dict, str]) -> Image.Image:
61
  """
62
- Gradio UI often passes a PIL Image.
63
- gradio_client often passes a dict like {"path": "..."} or a string path.
64
- This function normalizes everything into a PIL Image.
65
  """
66
  if isinstance(img, Image.Image):
67
  return img
@@ -70,50 +83,54 @@ def _coerce_to_pil(img: Union[Image.Image, dict, str]) -> Image.Image:
70
  return Image.open(img)
71
 
72
  if isinstance(img, dict):
73
- # gradio_client image payload typically includes "path"
74
  path = img.get("path")
75
  if path:
76
  return Image.open(path)
77
 
78
- # sometimes it may include "url" (less common)
79
  url = img.get("url")
80
  if url and url.startswith("data:image"):
81
- header, b64 = url.split(",", 1)
82
- data = base64.b64decode(b64)
83
- return Image.open(BytesIO(data))
84
 
85
- raise ValueError(f"Unsupported image input type: {type(img)} / {img}")
86
 
87
 
 
 
 
88
  def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
89
  if img is None:
90
  return "No image uploaded.", "0.0s"
91
 
92
- start_time = time.perf_counter()
93
  load_model()
94
 
95
- # ✅ Normalize input (fixes API calls crashing)
96
  try:
97
  img = _coerce_to_pil(img)
98
  except Exception as e:
99
- return f"Bad image input: {e}", "0.0s"
100
 
101
  img = img.convert("RGB")
102
- img = _resize_max_side(img, max_side=896)
103
  w, h = img.size
104
 
 
 
 
 
105
  buf = BytesIO()
106
  img.save(buf, format="PNG")
107
- image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
108
-
109
- prompt = build_prompt(w, h)
110
 
111
  messages = [
112
  {
113
  "role": "user",
114
  "content": [
115
  {"type": "text", "text": prompt},
116
- {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
 
 
 
117
  ],
118
  }
119
  ]
@@ -131,8 +148,10 @@ def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
131
  return_tensors="pt",
132
  )
133
 
134
- # Move inputs to model device
135
- inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
 
 
136
 
137
  with torch.inference_mode():
138
  output_ids = model.generate(
@@ -143,19 +162,22 @@ def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
143
 
144
  prompt_len = inputs["input_ids"].shape[1]
145
  gen_ids = output_ids[:, prompt_len:]
146
- text_out = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
147
- result = text_out[0].strip() if text_out else "No text extracted."
 
148
 
149
- elapsed = time.perf_counter() - start_time
150
- timing = f"{elapsed:.2f}s"
151
- return result, timing
152
 
153
 
 
 
 
154
  with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
155
  gr.Markdown(
156
- "# BookReader OCR API (olmOCR2)\n"
157
- "Upload an image get extracted text + timing.\n\n"
158
- "**API endpoint:** `//ocr` (note the double slash)"
159
  )
160
 
161
  with gr.Row():
@@ -168,10 +190,11 @@ with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
168
 
169
  run_btn.click(
170
  fn=ocr_image,
171
- inputs=[image_input],
172
  outputs=[output, timing],
173
- api_name="/ocr", # ✅ match what your client discovered
174
  )
175
 
 
176
  if __name__ == "__main__":
177
  demo.queue().launch(show_error=True)
 
10
  import gradio as gr
11
  from transformers import AutoProcessor, AutoModelForVision2Seq
12
 
13
+ # -----------------------------------------------------------------------------
14
+ # Environment + warnings (quiet startup)
15
+ # -----------------------------------------------------------------------------
16
  os.environ["OMP_NUM_THREADS"] = "1"
17
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
18
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
19
  warnings.filterwarnings("ignore")
20
 
21
+ # -----------------------------------------------------------------------------
22
+ # Model config
23
+ # -----------------------------------------------------------------------------
24
  MODEL_ID = "allenai/olmOCR-2-7B-1025"
25
+
26
  processor = None
27
  model = None
28
 
29
 
30
  def load_model():
31
+ """Lazy-load model so Space boots fast."""
32
  global processor, model
33
  if processor is not None and model is not None:
34
  return
35
 
36
+ processor = AutoProcessor.from_pretrained(
37
+ MODEL_ID,
38
+ trust_remote_code=True,
39
+ )
40
  model = AutoModelForVision2Seq.from_pretrained(
41
  MODEL_ID,
42
  dtype=torch.float16,
 
44
  low_cpu_mem_usage=True,
45
  trust_remote_code=True,
46
  ).eval()
47
+
48
+ print("✅ olmOCR-2 model loaded")
49
 
50
 
51
+ # -----------------------------------------------------------------------------
52
+ # Helpers
53
+ # -----------------------------------------------------------------------------
54
  def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
55
  w, h = img.size
56
  m = max(w, h)
 
66
  "Return ONLY the extracted text (no explanations, no markdown).\n"
67
  "Do not hallucinate.\n"
68
  "RAW_TEXT_START\n"
69
+ f"Page dimensions: {width:.1f}x{height:.1f} "
70
+ f"[Image 0x0 to {width:.1f}x{height:.1f}]\n"
71
  "RAW_TEXT_END"
72
  )
73
 
74
 
75
  def _coerce_to_pil(img: Union[Image.Image, dict, str]) -> Image.Image:
76
  """
77
+ Normalize Gradio UI input and gradio_client input into a PIL Image.
 
 
78
  """
79
  if isinstance(img, Image.Image):
80
  return img
 
83
  return Image.open(img)
84
 
85
  if isinstance(img, dict):
 
86
  path = img.get("path")
87
  if path:
88
  return Image.open(path)
89
 
 
90
  url = img.get("url")
91
  if url and url.startswith("data:image"):
92
+ _, b64 = url.split(",", 1)
93
+ return Image.open(BytesIO(base64.b64decode(b64)))
 
94
 
95
+ raise ValueError(f"Unsupported image input: {type(img)}")
96
 
97
 
98
+ # -----------------------------------------------------------------------------
99
+ # OCR function (API)
100
+ # -----------------------------------------------------------------------------
101
  def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
102
  if img is None:
103
  return "No image uploaded.", "0.0s"
104
 
105
+ start = time.perf_counter()
106
  load_model()
107
 
 
108
  try:
109
  img = _coerce_to_pil(img)
110
  except Exception as e:
111
+ return f"Invalid image input: {e}", "0.0s"
112
 
113
  img = img.convert("RGB")
114
+ img = _resize_max_side(img)
115
  w, h = img.size
116
 
117
+ # Build prompt
118
+ prompt = build_prompt(w, h)
119
+
120
+ # Encode image for VLM message
121
  buf = BytesIO()
122
  img.save(buf, format="PNG")
123
+ image_b64 = base64.b64encode(buf.getvalue()).decode()
 
 
124
 
125
  messages = [
126
  {
127
  "role": "user",
128
  "content": [
129
  {"type": "text", "text": prompt},
130
+ {
131
+ "type": "image_url",
132
+ "image_url": {"url": f"data:image/png;base64,{image_b64}"},
133
+ },
134
  ],
135
  }
136
  ]
 
148
  return_tensors="pt",
149
  )
150
 
151
+ inputs = {
152
+ k: v.to(model.device) if torch.is_tensor(v) else v
153
+ for k, v in inputs.items()
154
+ }
155
 
156
  with torch.inference_mode():
157
  output_ids = model.generate(
 
162
 
163
  prompt_len = inputs["input_ids"].shape[1]
164
  gen_ids = output_ids[:, prompt_len:]
165
+ text = processor.tokenizer.batch_decode(
166
+ gen_ids, skip_special_tokens=True
167
+ )
168
 
169
+ elapsed = time.perf_counter() - start
170
+ return (text[0].strip() if text else "No text extracted.", f"{elapsed:.2f}s")
 
171
 
172
 
173
+ # -----------------------------------------------------------------------------
174
+ # Gradio UI + API
175
+ # -----------------------------------------------------------------------------
176
  with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
177
  gr.Markdown(
178
+ "# 📖 BookReader OCR API (olmOCR2)\n"
179
+ "Upload an image and extract text using **olmOCR-2-7B**.\n\n"
180
+ "**API endpoint:** `/ocr`"
181
  )
182
 
183
  with gr.Row():
 
190
 
191
  run_btn.click(
192
  fn=ocr_image,
193
+ inputs=image_input,
194
  outputs=[output, timing],
195
+ api_name="/ocr",
196
  )
197
 
198
+
199
  if __name__ == "__main__":
200
  demo.queue().launch(show_error=True)