himipo committed on
Commit
036aa17
·
verified ·
1 Parent(s): efa70a6

Deploy OCR Model Workbench

Browse files
Files changed (3) hide show
  1. app.py +10 -4
  2. ocr_workbench/client.py +15 -2
  3. ocr_workbench/registry.py +14 -5
app.py CHANGED
@@ -61,13 +61,16 @@ def _model_info(spec: ModelSpec) -> str:
61
  )
62
 
63
 
64
- def on_model_change(model_id: str) -> tuple[str, str, str, int, bool, str]:
65
  spec = _spec(model_id)
66
  return (
67
  spec.default_prompt,
68
  spec.endpoint(),
69
  _model_info(spec),
70
  spec.default_max_tokens,
 
 
 
71
  spec.default_layout_as_thought,
72
  spec.default_image_mode,
73
  )
@@ -221,10 +224,10 @@ with gr.Blocks(title="OCR Model Workbench") as demo:
221
  minimum=1,
222
  maximum=ABSOLUTE_MAX_PAGES,
223
  step=1,
224
- value=min(DEFAULT_MAX_PAGES, ABSOLUTE_MAX_PAGES),
225
  label="最大ページ数",
226
  )
227
- dpi = gr.Slider(96, 300, value=180, step=12, label="PDF rasterize DPI")
228
  max_new_tokens = gr.Slider(
229
  256,
230
  32768,
@@ -244,7 +247,7 @@ with gr.Blocks(title="OCR Model Workbench") as demo:
244
  request_timeout = gr.Slider(
245
  60,
246
  1800,
247
- value=600,
248
  step=30,
249
  label="1ページのタイムアウト(秒)",
250
  )
@@ -293,6 +296,9 @@ with gr.Blocks(title="OCR Model Workbench") as demo:
293
  endpoint_override,
294
  model_info,
295
  max_new_tokens,
 
 
 
296
  layout_as_thought,
297
  unlimited_image_mode,
298
  ],
 
61
  )
62
 
63
 
64
+ def on_model_change(model_id: str) -> tuple[str, str, str, int, int, int, int, bool, str]:
65
  spec = _spec(model_id)
66
  return (
67
  spec.default_prompt,
68
  spec.endpoint(),
69
  _model_info(spec),
70
  spec.default_max_tokens,
71
+ min(spec.default_max_pages, ABSOLUTE_MAX_PAGES),
72
+ spec.default_dpi,
73
+ spec.default_request_timeout,
74
  spec.default_layout_as_thought,
75
  spec.default_image_mode,
76
  )
 
224
  minimum=1,
225
  maximum=ABSOLUTE_MAX_PAGES,
226
  step=1,
227
+ value=min(default_spec.default_max_pages, DEFAULT_MAX_PAGES, ABSOLUTE_MAX_PAGES),
228
  label="最大ページ数",
229
  )
230
+ dpi = gr.Slider(96, 300, value=default_spec.default_dpi, step=12, label="PDF rasterize DPI")
231
  max_new_tokens = gr.Slider(
232
  256,
233
  32768,
 
247
  request_timeout = gr.Slider(
248
  60,
249
  1800,
250
+ value=default_spec.default_request_timeout,
251
  step=30,
252
  label="1ページのタイムアウト(秒)",
253
  )
 
296
  endpoint_override,
297
  model_info,
298
  max_new_tokens,
299
+ max_pages,
300
+ dpi,
301
+ request_timeout,
302
  layout_as_thought,
303
  unlimited_image_mode,
304
  ],
ocr_workbench/client.py CHANGED
@@ -4,6 +4,7 @@ import json
4
  import os
5
  import inspect
6
  import time
 
7
  from pathlib import Path
8
  from typing import Any
9
 
@@ -43,10 +44,13 @@ def check_health(endpoint: str, timeout_seconds: float = 20.0) -> dict[str, Any]
43
  if not endpoint:
44
  raise WorkerError("Worker URL is not configured.")
45
  try:
46
- payload = _client(endpoint).predict(api_name="/health")
 
47
  if not isinstance(payload, dict):
48
  raise WorkerError("Worker health response is not a JSON object.")
49
  return payload
 
 
50
  except Exception as exc:
51
  raise WorkerError(f"Worker health check failed: {exc}") from exc
52
 
@@ -71,7 +75,7 @@ def run_page(
71
  if delay:
72
  time.sleep(delay)
73
  try:
74
- payload = _client(endpoint).predict(
75
  handle_file(str(page_path)),
76
  model_id,
77
  prompt or "",
@@ -79,6 +83,7 @@ def run_page(
79
  os.getenv("WORKER_API_TOKEN", "").strip(),
80
  api_name="/ocr",
81
  )
 
82
  if not isinstance(payload, dict):
83
  raise WorkerError("Worker response is not a JSON object.")
84
  required = {"model", "text", "markdown", "metrics"}
@@ -86,6 +91,14 @@ def run_page(
86
  if missing:
87
  raise WorkerError(f"Worker response is missing fields: {missing}")
88
  return payload
 
 
 
 
 
 
 
 
89
  except (ValueError, OSError, WorkerError, Exception) as exc:
90
  last_error = exc
91
  if attempt >= len(retry_delays):
 
4
  import os
5
  import inspect
6
  import time
7
+ from concurrent.futures import TimeoutError
8
  from pathlib import Path
9
  from typing import Any
10
 
 
44
  if not endpoint:
45
  raise WorkerError("Worker URL is not configured.")
46
  try:
47
+ job = _client(endpoint).submit(api_name="/health")
48
+ payload = job.result(timeout=timeout_seconds)
49
  if not isinstance(payload, dict):
50
  raise WorkerError("Worker health response is not a JSON object.")
51
  return payload
52
+ except TimeoutError as exc:
53
+ raise WorkerError(f"Worker health check timed out after {timeout_seconds:.0f}s.") from exc
54
  except Exception as exc:
55
  raise WorkerError(f"Worker health check failed: {exc}") from exc
56
 
 
75
  if delay:
76
  time.sleep(delay)
77
  try:
78
+ job = _client(endpoint).submit(
79
  handle_file(str(page_path)),
80
  model_id,
81
  prompt or "",
 
83
  os.getenv("WORKER_API_TOKEN", "").strip(),
84
  api_name="/ocr",
85
  )
86
+ payload = job.result(timeout=timeout_seconds)
87
  if not isinstance(payload, dict):
88
  raise WorkerError("Worker response is not a JSON object.")
89
  required = {"model", "text", "markdown", "metrics"}
 
91
  if missing:
92
  raise WorkerError(f"Worker response is missing fields: {missing}")
93
  return payload
94
+ except TimeoutError as exc:
95
+ last_error = WorkerError(f"Worker request timed out after {timeout_seconds:.0f}s.")
96
+ try:
97
+ job.cancel()
98
+ except Exception:
99
+ pass
100
+ if attempt >= len(retry_delays):
101
+ break
102
  except (ValueError, OSError, WorkerError, Exception) as exc:
103
  last_error = exc
104
  if attempt >= len(retry_delays):
ocr_workbench/registry.py CHANGED
@@ -15,6 +15,9 @@ class ModelSpec:
15
  description: str
16
  result_note: str
17
  default_max_tokens: int = 4096
 
 
 
18
  default_layout_as_thought: bool = False
19
  default_image_mode: str = "gundam"
20
 
@@ -32,13 +35,16 @@ _BUILTIN_MODELS: list[ModelSpec] = [
32
  id="paddleocr-vl-1.6",
33
  label="PaddleOCR-VL 1.6",
34
  endpoint_env="PADDLEOCR_VL_WORKER_URL",
35
- default_prompt="Parse this document to Markdown.",
36
  description=(
37
- "Compact document parser for text, layout, tables, formulas, charts and seals. "
38
- "The Storage Bucket can be mounted on this worker as a persistent model/cache volume."
39
  ),
40
- result_note="Returns PaddleOCR export images plus Markdown/JSON where available.",
41
- default_max_tokens=4096,
 
 
 
42
  ),
43
  ModelSpec(
44
  id="qianfan-ocr",
@@ -105,6 +111,9 @@ def _custom_models() -> list[ModelSpec]:
105
  description=str(item.get("description", "Custom OCR worker.")),
106
  result_note=str(item.get("result_note", "Uses the common OCR worker response schema.")),
107
  default_max_tokens=int(item.get("default_max_tokens", 4096)),
 
 
 
108
  default_layout_as_thought=bool(item.get("default_layout_as_thought", False)),
109
  default_image_mode=str(item.get("default_image_mode", "gundam")),
110
  )
 
15
  description: str
16
  result_note: str
17
  default_max_tokens: int = 4096
18
+ default_max_pages: int = 8
19
+ default_dpi: int = 180
20
+ default_request_timeout: int = 600
21
  default_layout_as_thought: bool = False
22
  default_image_mode: str = "gundam"
23
 
 
35
  id="paddleocr-vl-1.6",
36
  label="PaddleOCR-VL 1.6",
37
  endpoint_env="PADDLEOCR_VL_WORKER_URL",
38
+ default_prompt="OCR:",
39
  description=(
40
+ "Compact OCR/document VLM. This worker uses the Transformers/PyTorch backend so "
41
+ "inference runs inside ZeroGPU."
42
  ),
43
+ result_note="Returns model text plus the normalized input image in ZeroGPU mode.",
44
+ default_max_tokens=512,
45
+ default_max_pages=1,
46
+ default_dpi=120,
47
+ default_request_timeout=900,
48
  ),
49
  ModelSpec(
50
  id="qianfan-ocr",
 
111
  description=str(item.get("description", "Custom OCR worker.")),
112
  result_note=str(item.get("result_note", "Uses the common OCR worker response schema.")),
113
  default_max_tokens=int(item.get("default_max_tokens", 4096)),
114
+ default_max_pages=int(item.get("default_max_pages", 8)),
115
+ default_dpi=int(item.get("default_dpi", 180)),
116
+ default_request_timeout=int(item.get("default_request_timeout", 600)),
117
  default_layout_as_thought=bool(item.get("default_layout_as_thought", False)),
118
  default_image_mode=str(item.get("default_image_mode", "gundam")),
119
  )