chitrark committed on
Commit
b2fc952
·
verified ·
1 Parent(s): 6ba0575

fixed a few HF warnings

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -2,20 +2,20 @@ import os
2
  import base64
3
  from io import BytesIO
4
  import warnings
 
5
 
6
  import torch
7
  from PIL import Image
8
  import gradio as gr
9
  from transformers import AutoProcessor, AutoModelForVision2Seq
10
 
11
- # Suppress warnings at startup
12
  os.environ["OMP_NUM_THREADS"] = "1"
13
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
14
- warnings.filterwarnings("ignore", category=FutureWarning)
 
15
 
16
- # IMPORTANT: Load processor+model from the olmOCR checkpoint itself
17
  MODEL_ID = "allenai/olmOCR-2-7B-1025"
18
-
19
  processor = None
20
  model = None
21
 
@@ -25,17 +25,15 @@ def load_model():
25
  if processor is not None and model is not None:
26
  return
27
 
28
- # trust_remote_code is often required for VLM checkpoints
29
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
30
-
31
- # T4: use fp16 + device_map auto to avoid OOM
32
  model = AutoModelForVision2Seq.from_pretrained(
33
  MODEL_ID,
34
- torch_dtype=torch.float16,
35
  device_map="auto",
36
  low_cpu_mem_usage=True,
37
  trust_remote_code=True,
38
  ).eval()
 
39
 
40
 
41
  def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
@@ -48,7 +46,6 @@ def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
48
 
49
 
50
  def build_prompt(width: int, height: int) -> str:
51
- # Keep it short + strict to reduce hallucinations
52
  return (
53
  "Extract all readable text from this page image.\n"
54
  "Return ONLY the extracted text (no explanations, no markdown).\n"
@@ -59,17 +56,17 @@ def build_prompt(width: int, height: int) -> str:
59
  )
60
 
61
 
62
- def ocr_image(img: Image.Image) -> str:
63
  if img is None:
64
- return "No image uploaded."
65
 
 
66
  load_model()
67
 
68
  img = img.convert("RGB")
69
  img = _resize_max_side(img, max_side=896)
70
  w, h = img.size
71
 
72
- # Base64 image_url message (works with many Qwen-style chat templates)
73
  buf = BytesIO()
74
  img.save(buf, format="PNG")
75
  image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
@@ -99,27 +96,31 @@ def ocr_image(img: Image.Image) -> str:
99
  return_tensors="pt",
100
  )
101
 
102
- # FIX: Move inputs to model's device (eliminates the warning)
103
  inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
104
 
105
  with torch.inference_mode():
106
  output_ids = model.generate(
107
  **inputs,
108
  max_new_tokens=512,
109
- do_sample=False, # deterministic OCR
110
  )
111
 
112
- # Remove prompt tokens, keep only generated text
113
  prompt_len = inputs["input_ids"].shape[1]
114
  gen_ids = output_ids[:, prompt_len:]
115
  text_out = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
116
- return text_out[0].strip() if text_out else "No text extracted."
 
 
 
 
 
117
 
118
 
119
  with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
120
  gr.Markdown(
121
  "# BookReader OCR API (olmOCR2)\n"
122
- "Upload an image → get extracted text.\n\n"
123
  "**API endpoint:** `/ocr`"
124
  )
125
 
@@ -128,12 +129,13 @@ with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
128
  image_input = gr.Image(type="pil", label="Upload image")
129
  run_btn = gr.Button("Run OCR", variant="primary")
130
  with gr.Column():
131
- output = gr.Textbox(label="Extracted text", lines=20)
 
132
 
133
  run_btn.click(
134
  fn=ocr_image,
135
  inputs=[image_input],
136
- outputs=[output],
137
  api_name="/ocr",
138
  )
139
 
 
2
  import base64
3
  from io import BytesIO
4
  import warnings
5
+ import time # For timing
6
 
7
  import torch
8
  from PIL import Image
9
  import gradio as gr
10
  from transformers import AutoProcessor, AutoModelForVision2Seq
11
 
12
+ # Suppress ALL startup noise BEFORE any imports
13
  os.environ["OMP_NUM_THREADS"] = "1"
14
  os.environ["TRANSFORMERS_VERBOSITY"] = "error"
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+ warnings.filterwarnings("ignore")
17
 
 
18
  MODEL_ID = "allenai/olmOCR-2-7B-1025"
 
19
  processor = None
20
  model = None
21
 
 
25
  if processor is not None and model is not None:
26
  return
27
 
 
28
  processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
29
  model = AutoModelForVision2Seq.from_pretrained(
30
  MODEL_ID,
31
+ dtype=torch.float16, # Fixed deprecation
32
  device_map="auto",
33
  low_cpu_mem_usage=True,
34
  trust_remote_code=True,
35
  ).eval()
36
+ print("✅ Model loaded successfully!")
37
 
38
 
39
  def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
 
46
 
47
 
48
  def build_prompt(width: int, height: int) -> str:
 
49
  return (
50
  "Extract all readable text from this page image.\n"
51
  "Return ONLY the extracted text (no explanations, no markdown).\n"
 
56
  )
57
 
58
 
59
+ def ocr_image(img: Image.Image) -> tuple[str, str]:
60
  if img is None:
61
+ return "No image uploaded.", "0.0s"
62
 
63
+ start_time = time.perf_counter() # High-precision timer
64
  load_model()
65
 
66
  img = img.convert("RGB")
67
  img = _resize_max_side(img, max_side=896)
68
  w, h = img.size
69
 
 
70
  buf = BytesIO()
71
  img.save(buf, format="PNG")
72
  image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
 
96
  return_tensors="pt",
97
  )
98
 
99
+ # Move inputs to model device
100
  inputs = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in inputs.items()}
101
 
102
  with torch.inference_mode():
103
  output_ids = model.generate(
104
  **inputs,
105
  max_new_tokens=512,
106
+ do_sample=False,
107
  )
108
 
 
109
  prompt_len = inputs["input_ids"].shape[1]
110
  gen_ids = output_ids[:, prompt_len:]
111
  text_out = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
112
+ result = text_out[0].strip() if text_out else "No text extracted."
113
+
114
+ elapsed = time.perf_counter() - start_time
115
+ timing = f"{elapsed:.2f}s"
116
+
117
+ return result, timing
118
 
119
 
120
  with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
121
  gr.Markdown(
122
  "# BookReader OCR API (olmOCR2)\n"
123
+ "Upload an image → get extracted text + timing.\n\n"
124
  "**API endpoint:** `/ocr`"
125
  )
126
 
 
129
  image_input = gr.Image(type="pil", label="Upload image")
130
  run_btn = gr.Button("Run OCR", variant="primary")
131
  with gr.Column():
132
+ output = gr.Textbox(label="Extracted text", lines=15)
133
+ timing = gr.Textbox(label="Generation time", interactive=False)
134
 
135
  run_btn.click(
136
  fn=ocr_image,
137
  inputs=[image_input],
138
+ outputs=[output, timing],
139
  api_name="/ocr",
140
  )
141