chitrark committed on
Commit
097f30f
·
verified ·
1 Parent(s): ffb2e43

updated to fix some issues

Browse files
Files changed (1) hide show
  1. app.py +42 -9
app.py CHANGED
@@ -2,7 +2,8 @@ import os
2
  import base64
3
  from io import BytesIO
4
  import warnings
5
- import time # For timing
 
6
 
7
  import torch
8
  from PIL import Image
@@ -28,7 +29,7 @@ def load_model():
28
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
29
  model = AutoModelForVision2Seq.from_pretrained(
30
  MODEL_ID,
31
- dtype=torch.float16, # Fixed deprecation
32
  device_map="auto",
33
  low_cpu_mem_usage=True,
34
  trust_remote_code=True,
@@ -56,13 +57,47 @@ def build_prompt(width: int, height: int) -> str:
56
  )
57
 
58
 
59
- def ocr_image(img: Image.Image) -> tuple[str, str]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  if img is None:
61
  return "No image uploaded.", "0.0s"
62
 
63
- start_time = time.perf_counter() # High-precision timer
64
  load_model()
65
 
 
 
 
 
 
 
66
  img = img.convert("RGB")
67
  img = _resize_max_side(img, max_side=896)
68
  w, h = img.size
@@ -95,7 +130,7 @@ def ocr_image(img: Image.Image) -> tuple[str, str]:
95
  padding=True,
96
  return_tensors="pt",
97
  )
98
-
99
  # Move inputs to model device
100
  inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
101
 
@@ -113,7 +148,6 @@ def ocr_image(img: Image.Image) -> tuple[str, str]:
113
 
114
  elapsed = time.perf_counter() - start_time
115
  timing = f"{elapsed:.2f}s"
116
-
117
  return result, timing
118
 
119
 
@@ -121,7 +155,7 @@ with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
121
  gr.Markdown(
122
  "# BookReader OCR API (olmOCR2)\n"
123
  "Upload an image → get extracted text + timing.\n\n"
124
- "**API endpoint:** `/ocr`"
125
  )
126
 
127
  with gr.Row():
@@ -136,9 +170,8 @@ with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
136
  fn=ocr_image,
137
  inputs=[image_input],
138
  outputs=[output, timing],
139
- api_name="/ocr",
140
  )
141
 
142
  if __name__ == "__main__":
143
  demo.queue().launch(show_error=True)
144
-
 
2
  import base64
3
  from io import BytesIO
4
  import warnings
5
+ import time
6
+ from typing import Union
7
 
8
  import torch
9
  from PIL import Image
 
29
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
30
  model = AutoModelForVision2Seq.from_pretrained(
31
  MODEL_ID,
32
+ dtype=torch.float16,
33
  device_map="auto",
34
  low_cpu_mem_usage=True,
35
  trust_remote_code=True,
 
57
  )
58
 
59
 
60
+ def _coerce_to_pil(img: Union[Image.Image, dict, str]) -> Image.Image:
61
+ """
62
+ Gradio UI often passes a PIL Image.
63
+ gradio_client often passes a dict like {"path": "..."} or a string path.
64
+ This function normalizes everything into a PIL Image.
65
+ """
66
+ if isinstance(img, Image.Image):
67
+ return img
68
+
69
+ if isinstance(img, str):
70
+ return Image.open(img)
71
+
72
+ if isinstance(img, dict):
73
+ # gradio_client image payload typically includes "path"
74
+ path = img.get("path")
75
+ if path:
76
+ return Image.open(path)
77
+
78
+ # sometimes it may include "url" (less common)
79
+ url = img.get("url")
80
+ if url and url.startswith("data:image"):
81
+ header, b64 = url.split(",", 1)
82
+ data = base64.b64decode(b64)
83
+ return Image.open(BytesIO(data))
84
+
85
+ raise ValueError(f"Unsupported image input type: {type(img)} / {img}")
86
+
87
+
88
+ def ocr_image(img: Union[Image.Image, dict, str]) -> tuple[str, str]:
89
  if img is None:
90
  return "No image uploaded.", "0.0s"
91
 
92
+ start_time = time.perf_counter()
93
  load_model()
94
 
95
+ # ✅ Normalize input (fixes API calls crashing)
96
+ try:
97
+ img = _coerce_to_pil(img)
98
+ except Exception as e:
99
+ return f"Bad image input: {e}", "0.0s"
100
+
101
  img = img.convert("RGB")
102
  img = _resize_max_side(img, max_side=896)
103
  w, h = img.size
 
130
  padding=True,
131
  return_tensors="pt",
132
  )
133
+
134
  # Move inputs to model device
135
  inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
136
 
 
148
 
149
  elapsed = time.perf_counter() - start_time
150
  timing = f"{elapsed:.2f}s"
 
151
  return result, timing
152
 
153
 
 
155
  gr.Markdown(
156
  "# BookReader OCR API (olmOCR2)\n"
157
  "Upload an image → get extracted text + timing.\n\n"
158
+ "**API endpoint:** `//ocr` (note the double slash)"
159
  )
160
 
161
  with gr.Row():
 
170
  fn=ocr_image,
171
  inputs=[image_input],
172
  outputs=[output, timing],
173
+ api_name="/ocr", # ✅ match what your client discovered
174
  )
175
 
176
  if __name__ == "__main__":
177
  demo.queue().launch(show_error=True)