airzy1 commited on
Commit
b9bf728
·
verified ·
1 Parent(s): 09738fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -210
app.py CHANGED
@@ -1,15 +1,18 @@
1
-
2
- import json
3
  import os
 
4
  import re
5
- from typing import Any, Dict, List, Optional, Tuple
6
 
 
7
  import gradio as gr
8
  import spaces
9
- import torch
10
- from PIL import Image, ImageDraw, ImageFilter, ImageFont, ImageOps
11
- from qwen_vl_utils import process_vision_info
12
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
 
 
 
13
 
14
  # ---------------------------
15
  # Environment / cache setup
@@ -29,18 +32,8 @@ torch.set_float32_matmul_precision("high")
29
 
30
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
31
 
32
- # Strong 7B-class OCR/VLM choice with official benchmark evidence.
33
- # You can swap this later if you decide to test Qwen3-VL on a newer stack.
34
- MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
35
-
36
- # Visual token budget: high enough for label reading, but not absurd for ZeroGPU.
37
- # Official docs show min_pixels/max_pixels as the supported way to control resolution.
38
- MIN_PIXELS = 256 * 28 * 28
39
- MAX_PIXELS = 2048 * 28 * 28
40
-
41
- # Image prep knobs.
42
- FULL_LONG_SIDE = 2200
43
- TILE_LONG_SIDE = 1600
44
 
45
  processor = None
46
  model = None
@@ -55,136 +48,25 @@ def load_model() -> None:
55
  processor = AutoProcessor.from_pretrained(
56
  MODEL_ID,
57
  token=HF_TOKEN if HF_TOKEN else None,
58
- min_pixels=MIN_PIXELS,
59
- max_pixels=MAX_PIXELS,
60
  )
61
 
62
  print("Loading model...")
63
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
64
  MODEL_ID,
65
  token=HF_TOKEN if HF_TOKEN else None,
66
  device_map="auto",
67
- torch_dtype="auto",
68
  low_cpu_mem_usage=True,
69
  )
70
 
 
71
  model.eval()
72
  print("Model ready")
73
 
74
 
75
- def _resize_long_side(image: Image.Image, target_long_side: int) -> Image.Image:
76
- """Resize only when needed, preserving aspect ratio."""
77
- long_side = max(image.size)
78
- if long_side <= target_long_side:
79
- return image
80
-
81
- scale = target_long_side / long_side
82
- new_size = (
83
- max(1, int(round(image.width * scale))),
84
- max(1, int(round(image.height * scale))),
85
- )
86
- return image.resize(new_size, Image.Resampling.LANCZOS)
87
-
88
-
89
- def prepare_image(image: Image.Image, target_long_side: int = FULL_LONG_SIDE) -> Image.Image:
90
- """Upscale/sharpen for tiny pantry text and ingredient panels."""
91
- image = ImageOps.exif_transpose(image).convert("RGB")
92
- image = _resize_long_side(image, target_long_side)
93
- image = ImageOps.autocontrast(image)
94
- image = image.filter(ImageFilter.SHARPEN)
95
- image = image.filter(ImageFilter.DETAIL)
96
- return image
97
-
98
-
99
- def crop_with_padding(
100
- image: Image.Image,
101
- box: Tuple[int, int, int, int],
102
- pad_frac: float = 0.06,
103
- target_long_side: int = TILE_LONG_SIDE,
104
- ) -> Image.Image:
105
- """Crop a region with some padding, then upscale it for OCR."""
106
- w, h = image.size
107
- x0, y0, x1, y1 = box
108
- pad_x = int(round((x1 - x0) * pad_frac))
109
- pad_y = int(round((y1 - y0) * pad_frac))
110
-
111
- x0 = max(0, x0 - pad_x)
112
- y0 = max(0, y0 - pad_y)
113
- x1 = min(w, x1 + pad_x)
114
- y1 = min(h, y1 + pad_y)
115
-
116
- crop = image.crop((x0, y0, x1, y1))
117
- crop = prepare_image(crop, target_long_side=target_long_side)
118
- return crop
119
-
120
-
121
- def build_panels(image: Image.Image) -> List[Tuple[str, Image.Image]]:
122
- """Create a small set of zoom panels to help the VLM read tiny labels."""
123
- image = prepare_image(image, target_long_side=FULL_LONG_SIDE)
124
- w, h = image.size
125
- panels: List[Tuple[str, Image.Image]] = [("full", image)]
126
-
127
- # For larger pantry shots, quadrants usually capture labels better than one huge scene.
128
- if max(w, h) >= 1200:
129
- mid_x = w // 2
130
- mid_y = h // 2
131
- overlap_x = int(round(w * 0.10))
132
- overlap_y = int(round(h * 0.10))
133
-
134
- quads = {
135
- "top_left": (0, 0, mid_x + overlap_x, mid_y + overlap_y),
136
- "top_right": (mid_x - overlap_x, 0, w, mid_y + overlap_y),
137
- "bottom_left": (0, mid_y - overlap_y, mid_x + overlap_x, h),
138
- "bottom_right": (mid_x - overlap_x, mid_y - overlap_y, w, h),
139
- }
140
- for label, box in quads.items():
141
- panels.append((label, crop_with_padding(image, box, pad_frac=0.05)))
142
- else:
143
- # For smaller images, a centered zoom is often more useful than tiling.
144
- cx0 = int(w * 0.15)
145
- cy0 = int(h * 0.15)
146
- cx1 = int(w * 0.85)
147
- cy1 = int(h * 0.85)
148
- if cx1 > cx0 and cy1 > cy0:
149
- panels.append(("center_zoom", crop_with_padding(image, (cx0, cy0, cx1, cy1), pad_frac=0.03)))
150
-
151
- return panels[:5]
152
-
153
-
154
- def make_contact_sheet(panels: List[Tuple[str, Image.Image]]) -> Image.Image:
155
- """Build a single preview image so the user can see what the model saw."""
156
- cols = 2
157
- tile_w = 720
158
- tile_h = 520
159
- gap = 16
160
- label_h = 28
161
-
162
- rows = (len(panels) + cols - 1) // cols
163
- sheet_w = cols * tile_w + (cols + 1) * gap
164
- sheet_h = rows * (tile_h + label_h) + (rows + 1) * gap
165
-
166
- canvas = Image.new("RGB", (sheet_w, sheet_h), (245, 245, 245))
167
- draw = ImageDraw.Draw(canvas)
168
- font = ImageFont.load_default()
169
-
170
- for idx, (label, img) in enumerate(panels):
171
- row = idx // cols
172
- col = idx % cols
173
- x = gap + col * (tile_w + gap)
174
- y = gap + row * (tile_h + label_h + gap)
175
-
176
- tile = ImageOps.contain(img, (tile_w, tile_h))
177
- tile_bg = Image.new("RGB", (tile_w, tile_h), (255, 255, 255))
178
- offset = ((tile_w - tile.width) // 2, (tile_h - tile.height) // 2)
179
- tile_bg.paste(tile, offset)
180
- canvas.paste(tile_bg, (x, y + label_h))
181
-
182
- draw.rectangle([x, y, x + tile_w, y + label_h], fill=(230, 230, 230))
183
- draw.text((x + 8, y + 6), label, fill=(20, 20, 20), font=font)
184
-
185
- draw.rectangle([x, y + label_h, x + tile_w, y + label_h + tile_h], outline=(200, 200, 200), width=1)
186
-
187
- return canvas
188
 
189
 
190
  def extract_json(text: str) -> Dict[str, Any]:
@@ -210,95 +92,62 @@ def extract_json(text: str) -> Dict[str, Any]:
210
  return {"raw_output": text}
211
 
212
 
213
- PROMPT = """
214
- Analyze all provided panels from the same pantry photo.
215
-
216
- Goal:
217
- - Read visible brand names, product names, ingredients, and tiny printed text.
218
- - Use the full panel and the zoom panels together.
219
- - Do not guess. If text is unreadable, say "unreadable".
220
- - Merge duplicates across panels.
221
- - Prefer exact visible text over paraphrase.
222
-
223
- Return strict JSON only with this shape:
224
- {
225
- "items": [
226
- {
227
- "brand": "",
228
- "product_name": "",
229
- "visible_text": "",
230
- "ingredients": [""],
231
- "tiny_text_quality": "clear|partial|unreadable",
232
- "confidence": 0.0,
233
- "evidence_panels": ["full", "top_left", "top_right", "bottom_left", "bottom_right", "center_zoom"]
234
- }
235
- ],
236
- "warnings": [""],
237
- "notes": ""
238
- }
239
- """.strip()
240
-
241
-
242
- @spaces.GPU(size="large", duration=90)
243
- def analyze_pantry(image: Image.Image) -> Tuple[Optional[Image.Image], Dict[str, Any]]:
244
  if image is None:
245
  return None, {"error": "Upload an image first."}
246
 
247
  load_model()
248
 
249
- panels = build_panels(image)
250
- contact_sheet = make_contact_sheet(panels)
251
 
252
- # Qwen chat format: the model receives multiple images plus one instruction block.
253
  messages = [
254
  {
255
  "role": "system",
256
  "content": [
257
- {
258
- "type": "text",
259
- "text": "You are a careful OCR and pantry-label extraction assistant. Return valid JSON only.",
260
- }
261
  ],
262
  },
263
  {
264
  "role": "user",
265
  "content": [
266
- {
267
- "type": "text",
268
- "text": (
269
- "Panel order: full, top_left, top_right, bottom_left, bottom_right, center_zoom. "
270
- f"{PROMPT}"
271
- ),
272
- },
273
- *[{"type": "image", "image": panel_img} for _, panel_img in panels],
274
  ],
275
  },
276
  ]
277
 
278
- text = processor.apply_chat_template(
 
279
  messages,
280
- tokenize=False,
281
  add_generation_prompt=True,
282
- )
283
-
284
- image_inputs, video_inputs = process_vision_info(messages)
285
-
286
- inputs = processor(
287
- text=[text],
288
- images=image_inputs,
289
- videos=video_inputs,
290
- padding=True,
291
  return_tensors="pt",
292
  )
293
 
294
- # Some model/processor versions include token_type_ids, some do not.
295
- inputs.pop("token_type_ids", None)
296
  inputs = inputs.to(model.device)
297
 
298
  with torch.inference_mode():
299
  output_ids = model.generate(
300
  **inputs,
301
- max_new_tokens=700,
302
  do_sample=False,
303
  )
304
 
@@ -313,24 +162,13 @@ def analyze_pantry(image: Image.Image) -> Tuple[Optional[Image.Image], Dict[str,
313
  if isinstance(parsed, dict) and "raw_output" not in parsed:
314
  parsed["_raw_output"] = generated_text
315
 
316
- return contact_sheet, parsed
317
-
318
-
319
- # Simple helper tests for local sanity checks.
320
- def _self_test() -> None:
321
- blank = Image.new("RGB", (900, 700), "white")
322
- panels = build_panels(blank)
323
- assert len(panels) >= 2
324
- sheet = make_contact_sheet(panels)
325
- assert sheet.size[0] > 0 and sheet.size[1] > 0
326
- assert extract_json('{"a": 1}') == {"a": 1}
327
- assert "raw_output" in extract_json("not json")
328
 
329
 
330
  with gr.Blocks() as demo:
331
  gr.Markdown("# Pantry Scanner")
332
  gr.Markdown(
333
- "Upload a pantry photo. The app sends the full image plus zoomed panels to help the model read tiny labels and ingredients."
334
  )
335
 
336
  with gr.Row():
@@ -340,7 +178,7 @@ with gr.Blocks() as demo:
340
  analyze_btn = gr.Button("Analyze", variant="primary")
341
 
342
  with gr.Row():
343
- prepared_output = gr.Image(type="pil", label="Panels sent to the model")
344
  output_json = gr.JSON(label="Detected items")
345
 
346
  analyze_btn.click(
 
 
 
1
  import os
2
+ import json
3
  import re
4
+ from typing import Any, Dict, Tuple
5
 
6
+ import torch
7
  import gradio as gr
8
  import spaces
9
+
10
+ from PIL import Image, ImageOps
11
+
12
+ # Qwen3-VL requires the latest Transformers from source.
13
+ # In your Space requirements, use:
14
+ # pip install git+https://github.com/huggingface/transformers
15
+ from transformers import AutoProcessor, Qwen3VLForConditionalGeneration
16
 
17
  # ---------------------------
18
  # Environment / cache setup
 
32
 
33
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
34
 
35
+ # Qwen3-VL upgrade path
36
+ MODEL_ID = "Qwen/Qwen3-VL-8B-Instruct"
 
 
 
 
 
 
 
 
 
 
37
 
38
  processor = None
39
  model = None
 
48
  processor = AutoProcessor.from_pretrained(
49
  MODEL_ID,
50
  token=HF_TOKEN if HF_TOKEN else None,
 
 
51
  )
52
 
53
  print("Loading model...")
54
+ model = Qwen3VLForConditionalGeneration.from_pretrained(
55
  MODEL_ID,
56
  token=HF_TOKEN if HF_TOKEN else None,
57
  device_map="auto",
58
+ torch_dtype=torch.bfloat16,
59
  low_cpu_mem_usage=True,
60
  )
61
 
62
+ print("Setting eval mode...")
63
  model.eval()
64
  print("Model ready")
65
 
66
 
67
def normalize_image(image: Image.Image) -> Image.Image:
    """Return the image upright (per EXIF orientation) and in RGB mode.

    Deliberately minimal: no resizing, cropping, tiling, or enhancement —
    the model receives the photo essentially as uploaded.
    """
    upright = ImageOps.exif_transpose(image)
    return upright.convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
 
72
  def extract_json(text: str) -> Dict[str, Any]:
 
92
  return {"raw_output": text}
93
 
94
 
95
# JSON schema the model is asked to emit, kept separate so the prompt prose
# reads cleanly. Keys: brand/product_name/packaging_notes are nullable,
# ingredients/visible_text/raw_ocr are string lists, confidence is per-field.
_SCHEMA = (
    '{"brand": string|null, '
    '"product_name": string|null, '
    '"ingredients": [string], '
    '"visible_text": [string], '
    '"packaging_notes": string|null, '
    '"confidence": {"brand": number, "product_name": number, "ingredients": number}, '
    '"raw_ocr": [string]}'
)

# Single-image instruction: OCR/brand extraction, strict-JSON-only reply,
# with an explicit "do not guess unreadable text" constraint.
PROMPT = (
    "Inspect this single pantry image and return only JSON. "
    "Identify the visible brand name, product name, ingredients, "
    "and any other clearly readable package text. "
    "Do not guess tiny text you cannot read. "
    f"Use this schema: {_SCHEMA}."
)
109
+
110
+
111
+ @spaces.GPU(size="large", duration=60)
112
+ def analyze_pantry(image: Image.Image) -> Tuple[Image.Image, Dict[str, Any]]:
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  if image is None:
114
  return None, {"error": "Upload an image first."}
115
 
116
  load_model()
117
 
118
+ prepared = normalize_image(image)
 
119
 
 
120
  messages = [
121
  {
122
  "role": "system",
123
  "content": [
124
+ {"type": "text", "text": "You are a precise visual OCR assistant. Return JSON only."}
 
 
 
125
  ],
126
  },
127
  {
128
  "role": "user",
129
  "content": [
130
+ {"type": "image", "image": prepared},
131
+ {"type": "text", "text": PROMPT},
 
 
 
 
 
 
132
  ],
133
  },
134
  ]
135
 
136
+ # Qwen3-VL official Transformers usage.
137
+ inputs = processor.apply_chat_template(
138
  messages,
139
+ tokenize=True,
140
  add_generation_prompt=True,
141
+ return_dict=True,
 
 
 
 
 
 
 
 
142
  return_tensors="pt",
143
  )
144
 
 
 
145
  inputs = inputs.to(model.device)
146
 
147
  with torch.inference_mode():
148
  output_ids = model.generate(
149
  **inputs,
150
+ max_new_tokens=512,
151
  do_sample=False,
152
  )
153
 
 
162
  if isinstance(parsed, dict) and "raw_output" not in parsed:
163
  parsed["_raw_output"] = generated_text
164
 
165
+ return prepared, parsed
 
 
 
 
 
 
 
 
 
 
 
166
 
167
 
168
  with gr.Blocks() as demo:
169
  gr.Markdown("# Pantry Scanner")
170
  gr.Markdown(
171
+ "Single-image Qwen3-VL OCR/brand reader. No tiling, no crop pipeline, no manual sharpening."
172
  )
173
 
174
  with gr.Row():
 
178
  analyze_btn = gr.Button("Analyze", variant="primary")
179
 
180
  with gr.Row():
181
+ prepared_output = gr.Image(type="pil", label="Feeding image")
182
  output_json = gr.JSON(label="Detected items")
183
 
184
  analyze_btn.click(