chitrark committed on
Commit
c70f99d
·
verified ·
1 Parent(s): c8aa7db

Updated with a different change for the OOM issue

Browse files
Files changed (1) hide show
  1. app.py +59 -51
app.py CHANGED
@@ -5,61 +5,73 @@ import torch
5
  from PIL import Image
6
  import gradio as gr
7
 
8
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
9
-
10
- # ----- Model & processor -----
11
 
 
12
  MODEL_NAME = "allenai/olmOCR-2-7B-1025"
13
- PROCESSOR_NAME = "Qwen/Qwen2.5-VL-7B-Instruct"
 
 
 
 
14
 
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- print("Loading model on", device)
17
 
18
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
19
- MODEL_NAME,
20
- dtype=torch.bfloat16, # use bfloat16 on GPU; change to torch.float16 if needed
21
- ).to(device).eval()
 
22
 
23
- processor = AutoProcessor.from_pretrained(PROCESSOR_NAME)
24
 
 
 
 
 
 
 
 
25
 
26
- # ----- OCR logic (image -> text) -----
27
 
28
  def build_image_prompt(width: int, height: int) -> str:
29
  """
30
- Simple document-style prompt for a single image page.
31
- You can tweak wording; this keeps it close to olmOCR's 'document' framing.
32
  """
33
  return (
34
- "You are an OCR engine. Read the document page shown in the image and "
35
- "return the plain text exactly as it appears, in natural reading order. "
36
- "Do not add extra commentary or formatting.\n"
37
  "RAW_TEXT_START\n"
38
  f"Page dimensions: {width:.1f}x{height:.1f} [Image 0x0 to {width:.1f}x{height:.1f}]\n"
39
  "RAW_TEXT_END"
40
  )
41
 
42
 
43
- def ocr_image(image: Image.Image):
44
- if image is None:
 
 
 
 
 
 
 
 
 
 
45
  return "No image uploaded."
46
 
47
- # Ensure RGB
48
- img = image.convert("RGB")
49
 
50
- # Resize to keep longest side <= 1024 for efficiency
51
- max_side = 1024
52
  w, h = img.size
53
- scale = min(max_side / max(w, h), 1.0)
54
- if scale < 1.0:
55
- img = img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)
56
- w, h = img.size
57
 
58
- # Encode to base64 to match usual VLM 'image_url' style usage
59
  buf = BytesIO()
60
  img.save(buf, format="PNG")
61
- image_bytes = buf.getvalue()
62
- image_b64 = base64.b64encode(image_bytes).decode("utf-8")
63
 
64
  prompt = build_image_prompt(w, h)
65
 
@@ -76,45 +88,42 @@ def ocr_image(image: Image.Image):
76
  }
77
  ]
78
 
79
- # Apply chat template and preprocess
80
- text = processor.apply_chat_template(
81
  messages,
82
  tokenize=False,
83
  add_generation_prompt=True,
84
  )
85
 
86
  inputs = processor(
87
- text=[text],
88
  images=[img],
89
  padding=True,
90
  return_tensors="pt",
91
  )
92
- inputs = {k: v.to(device) for k, v in inputs.items()}
 
93
 
94
- with torch.no_grad():
95
- output = model.generate(
96
  **inputs,
97
- temperature=0.6,
98
  max_new_tokens=512,
99
- num_return_sequences=1,
100
- do_sample=True,
101
  )
102
 
 
103
  prompt_len = inputs["input_ids"].shape[1]
104
- new_tokens = output[:, prompt_len:]
105
- text_output = processor.tokenizer.batch_decode(
106
- new_tokens, skip_special_tokens=True
107
- )
108
-
109
- return text_output[0].strip() if text_output else "No text extracted."
110
 
 
 
111
 
112
- # ----- Gradio UI -----
113
 
114
- with gr.Blocks(title="olmOCR‑2 Image OCR") as demo:
115
  gr.Markdown(
116
- "# olmOCR‑2 Image OCR\n"
117
- "Upload an image and get extracted text using the olmOCR‑2‑7B model."
 
118
  )
119
 
120
  image_input = gr.Image(type="pil", label="Upload image")
@@ -127,5 +136,4 @@ with gr.Blocks(title="olmOCR‑2 Image OCR") as demo:
127
  api_name="/ocr",
128
  )
129
 
130
- if __name__ == "__main__":
131
- demo.queue().launch()
 
5
  from PIL import Image
6
  import gradio as gr
7
 
8
+ from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
 
 
9
 
10
# Checkpoint identifiers: olmOCR-2 weights, paired with the Qwen2-VL processor.
MODEL_NAME = "allenai/olmOCR-2-7B-1025"
PROCESSOR_NAME = "Qwen/Qwen2-VL-7B-Instruct"

# Populated on first use by load_model(); keeping them None at import time
# lets the Space boot without paying the model-load cost up front.
processor = None
model = None
17
 
 
 
18
 
19
def load_model():
    """Lazily initialize the global processor and model (idempotent).

    Safe to call on every request: after the first successful load the
    function returns immediately. fp16 weights plus device_map="auto"
    keep the 7B checkpoint within a T4's VRAM budget (per the original
    deployment notes).
    """
    global processor, model

    already_loaded = processor is not None and model is not None
    if already_loaded:
        return

    processor = AutoProcessor.from_pretrained(PROCESSOR_NAME)

    # float16 halves the weight footprint vs fp32; device_map="auto"
    # lets transformers place/offload shards instead of loading the
    # whole model onto one device at once.
    loaded = Qwen2VLForConditionalGeneration.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        low_cpu_mem_usage=True,
    )
    model = loaded.eval()
34
 
 
35
 
36
def build_image_prompt(width: int, height: int) -> str:
    """Return the short OCR instruction prompt for a single page image.

    The wording is deliberately terse to hold down prompt-token count and
    to discourage the model from inventing text. Width/height are embedded
    with one decimal place between the RAW_TEXT markers.
    """
    instructions = (
        "Extract all readable text from this page image.\n"
        "Return ONLY the extracted text (no explanations, no markdown).\n"
        "Do not hallucinate.\n"
    )
    dims = (
        f"Page dimensions: {width:.1f}x{height:.1f} "
        f"[Image 0x0 to {width:.1f}x{height:.1f}]\n"
    )
    return instructions + "RAW_TEXT_START\n" + dims + "RAW_TEXT_END"
49
 
50
 
51
def _resize_max_side(img: Image.Image, max_side: int = 896) -> Image.Image:
    """Downscale *img* so its longest side is at most *max_side* pixels.

    Images already within the limit are returned unchanged (same object,
    no copy). Aspect ratio is preserved. Each output dimension is clamped
    to at least 1 px: with extreme aspect ratios (e.g. 4000x1) the
    original int() truncation could produce a zero-sized dimension,
    which PIL rejects.
    """
    w, h = img.size
    longest = max(w, h)
    if longest <= max_side:
        return img
    scale = max_side / longest
    # Clamp to >= 1 so truncation never yields an empty image.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    # LANCZOS is a high-quality downsampling filter; it helps keep small
    # glyphs legible, which matters for OCR input.
    return img.resize(new_size, Image.LANCZOS)
59
+
60
+
61
+ def ocr_image(img: Image.Image):
62
+ if img is None:
63
  return "No image uploaded."
64
 
65
+ load_model()
 
66
 
67
+ img = img.convert("RGB")
68
+ img = _resize_max_side(img, max_side=896)
69
  w, h = img.size
 
 
 
 
70
 
71
+ # Encode to base64 for image_url-style messages
72
  buf = BytesIO()
73
  img.save(buf, format="PNG")
74
+ image_b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
 
75
 
76
  prompt = build_image_prompt(w, h)
77
 
 
88
  }
89
  ]
90
 
91
+ # Build chat text
92
+ chat_text = processor.apply_chat_template(
93
  messages,
94
  tokenize=False,
95
  add_generation_prompt=True,
96
  )
97
 
98
  inputs = processor(
99
+ text=[chat_text],
100
  images=[img],
101
  padding=True,
102
  return_tensors="pt",
103
  )
104
+ # NOTE: DO NOT .to("cuda") here when using device_map="auto"
105
+ # transformers will handle placement.
106
 
107
+ with torch.inference_mode():
108
+ output_ids = model.generate(
109
  **inputs,
 
110
  max_new_tokens=512,
111
+ do_sample=False, # OCR should be deterministic
 
112
  )
113
 
114
+ # Remove the prompt tokens to keep only the generated part
115
  prompt_len = inputs["input_ids"].shape[1]
116
+ gen_ids = output_ids[:, prompt_len:]
 
 
 
 
 
117
 
118
+ text_out = processor.tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
119
+ return text_out[0].strip() if text_out else "No text extracted."
120
 
 
121
 
122
+ with gr.Blocks(title="BookReader OCR API (olmOCR2)") as demo:
123
  gr.Markdown(
124
+ "# BookReader OCR API (olmOCR2)\n"
125
+ "Upload an image get extracted text.\n\n"
126
+ "**API endpoint:** `/ocr`"
127
  )
128
 
129
  image_input = gr.Image(type="pil", label="Upload image")
 
136
  api_name="/ocr",
137
  )
138
 
139
+ demo.queue().launch()