airzy1 commited on
Commit
d4de43b
·
verified ·
1 Parent(s): 537537a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -45
app.py CHANGED
@@ -1,11 +1,22 @@
1
  import os
2
  import json
3
  import re
 
4
 
 
 
 
 
 
 
 
 
 
 
5
  os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
6
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64"
7
 
8
- # Writable cache
9
  os.environ["HF_HOME"] = "/tmp/hf"
10
  os.environ["HF_HUB_CACHE"] = "/tmp/hf/hub"
11
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
@@ -13,65 +24,91 @@ os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
13
  os.makedirs("/tmp/hf/hub", exist_ok=True)
14
  os.makedirs("/tmp/hf/transformers", exist_ok=True)
15
 
16
- import spaces
17
- import torch
18
- import gradio as gr
19
- from PIL import Image
20
- from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
21
 
22
  HF_TOKEN = os.environ.get("HF_TOKEN", "")
23
 
24
- # SMALLER MODEL
25
- MODEL_ID = "Qwen/Qwen2.5-VL-1.8B-Instruct"
 
 
 
 
26
 
27
  processor = None
28
  model = None
29
 
30
 
31
- def load_model():
32
  global processor, model
33
- if model is not None:
34
  return
35
 
36
  print("Loading processor...")
37
  processor = AutoProcessor.from_pretrained(
38
  MODEL_ID,
39
  token=HF_TOKEN if HF_TOKEN else None,
40
- min_pixels=256 * 28 * 28,
41
- max_pixels=768 * 28 * 28, # 🔥 lower for memory safety
42
  )
43
 
44
- print("Loading model:", MODEL_ID)
45
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
46
  MODEL_ID,
47
  token=HF_TOKEN if HF_TOKEN else None,
48
  device_map="auto",
49
- torch_dtype=torch.float16, # 🔥 force lower memory
 
50
  )
51
 
52
  model.eval()
53
  print("Model ready")
54
 
55
 
56
- def extract_json(text: str):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  text = (text or "").strip()
58
 
 
 
 
 
59
  try:
60
  return json.loads(text)
61
- except:
62
  pass
63
 
 
64
  match = re.search(r"\{.*\}", text, flags=re.S)
65
  if match:
66
  try:
67
  return json.loads(match.group(0))
68
- except:
69
  pass
70
 
71
  return {"raw_output": text}
72
 
73
 
74
- PROMPT = """Analyze this pantry image.
75
 
76
  Return ONLY valid JSON with this schema:
77
  {
@@ -87,25 +124,23 @@ Return ONLY valid JSON with this schema:
87
  "uncertain_items": []
88
  }
89
 
90
- Focus on:
91
- - canned goods
92
- - labels
93
- - jars
94
- - boxes
95
- - spices
96
-
97
- Be precise. Do NOT hallucinate.
98
  """
99
 
100
 
101
- @spaces.GPU(size="large", duration=120) # 🔥 smaller GPU = faster queue
102
  def analyze_pantry(image: Image.Image):
103
  if image is None:
104
- return {"error": "Upload an image"}
105
 
106
  load_model()
107
 
108
- image = image.convert("RGB")
109
 
110
  messages = [
111
  {
@@ -127,50 +162,46 @@ def analyze_pantry(image: Image.Image):
127
  add_generation_prompt=True,
128
  )
129
 
 
 
130
  inputs = processor(
131
  text=[text],
132
- images=[image],
 
133
  padding=True,
134
  return_tensors="pt",
135
  )
136
 
137
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
138
 
139
  with torch.inference_mode():
140
  output_ids = model.generate(
141
  **inputs,
142
- max_new_tokens=500, # 🔥 reduced
143
  do_sample=False,
144
  )
145
 
146
  prompt_len = inputs["input_ids"].shape[-1]
147
-
148
  generated_text = processor.batch_decode(
149
  [output_ids[0][prompt_len:]],
150
  skip_special_tokens=True,
151
  )[0].strip()
152
 
153
- print("OUTPUT:", generated_text)
154
-
155
  parsed = extract_json(generated_text)
156
-
157
  if isinstance(parsed, dict) and "raw_output" not in parsed:
158
  parsed["_raw_output"] = generated_text
159
 
160
  return parsed
161
 
162
 
163
- @spaces.GPU(size="small", duration=1)
164
- def cloud():
165
- return None
166
-
167
-
168
  with gr.Blocks() as demo:
169
- gr.Markdown("# Pantry Analyzer (ZeroGPU Optimized)")
170
 
171
- image_input = gr.Image(type="pil")
172
- analyze_btn = gr.Button("Analyze")
173
- output_json = gr.JSON()
 
 
174
 
175
  analyze_btn.click(analyze_pantry, inputs=image_input, outputs=output_json)
176
 
 
1
import os
import json
import re
from typing import Any, Dict

# ---------------------------
# Environment / cache setup
# ---------------------------
# NOTE: these must be exported BEFORE torch / transformers are imported.
# Both libraries resolve CUDA- and cache-related environment variables at
# import time, so setting them after the imports silently has no effect.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:64"

# Writable cache for Spaces (the default HF cache location is read-only there).
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["HF_HUB_CACHE"] = "/tmp/hf/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"

os.makedirs("/tmp/hf/hub", exist_ok=True)
os.makedirs("/tmp/hf/transformers", exist_ok=True)

import torch
import gradio as gr
import spaces
from PIL import Image, ImageFilter, ImageOps
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
from qwen_vl_utils import process_vision_info

# Allow TF32 matmuls on Ampere+ GPUs; no effect on fp16/bf16 code paths.
torch.set_float32_matmul_precision("high")

HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Heaviest practical choice for a ZeroGPU Space
MODEL_ID = "Qwen/Qwen2.5-VL-72B-Instruct-AWQ"

# Aggressive visual token budget for tiny labels / ingredients
MIN_PIXELS = 1024 * 28 * 28
MAX_PIXELS = 4096 * 28 * 28

# Lazily initialised by load_model(); module-level so repeated GPU calls reuse them.
processor = None
model = None
40
 
41
 
42
def load_model() -> None:
    """Lazily initialise the global processor/model pair (singleton pattern).

    Safe to call repeatedly: once both globals are populated, subsequent
    calls return immediately without touching the network.
    """
    global processor, model
    if processor is not None and model is not None:
        # Already initialised by an earlier call — nothing to do.
        return

    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN if HF_TOKEN else None,
        min_pixels=MIN_PIXELS,
        max_pixels=MAX_PIXELS,
    )

    print("Loading model...")
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN if HF_TOKEN else None,
        device_map="auto",
        torch_dtype="auto",
        low_cpu_mem_usage=True,
    )
    model.eval()
    print("Model ready")
 
67
 
68
def prepare_image(image: Image.Image, target_long_side: int = 2400) -> Image.Image:
    """Normalise a pantry photo so tiny label text survives VLM preprocessing.

    Applies the EXIF orientation tag, converts to RGB, upscales small images
    until the long side reaches ``target_long_side`` pixels, then boosts
    contrast and sharpens.

    Args:
        image: Input photo in any PIL mode.
        target_long_side: Minimum long-side size (px) to upscale to; images
            already at least this large keep their native resolution.
            Default preserves the previous hard-coded behaviour.

    Returns:
        A new RGB ``PIL.Image.Image``; the input image is not modified.
    """
    # Respect camera orientation before measuring dimensions.
    image = ImageOps.exif_transpose(image).convert("RGB")

    # Upscale small images so tiny labels have a better chance of being read.
    long_side = max(image.size)
    if long_side < target_long_side:
        scale = target_long_side / long_side
        new_size = (
            max(1, int(round(image.width * scale))),
            max(1, int(round(image.height * scale))),
        )
        # LANCZOS keeps printed-text edges crisp when upscaling.
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    image = ImageOps.autocontrast(image)
    image = image.filter(ImageFilter.SHARPEN)
    return image
86
+
87
+
88
def extract_json(text: str) -> Dict[str, Any]:
    """Best-effort extraction of a JSON object from raw model output.

    Tries, in order:
      1. the whole string, after stripping markdown code fences;
      2. the greedy ``{...}`` span (first ``{`` to last ``}``);
      3. ``JSONDecoder.raw_decode`` from the first ``{``, which tolerates
         trailing prose or a second concatenated object.

    Always returns a dict: falls back to ``{"raw_output": text}`` when
    nothing parses, so callers never need to handle exceptions.
    """
    text = (text or "").strip()

    # Strip common markdown fences (```json ... ```).
    text = re.sub(r"^\s*```(?:json)?\s*", "", text, flags=re.I)
    text = re.sub(r"\s*```\s*$", "", text, flags=re.I)

    try:
        return json.loads(text)
    except Exception:
        pass

    # Greedy span: works when the model wrapped one object in prose.
    match = re.search(r"\{.*\}", text, flags=re.S)
    if match:
        try:
            return json.loads(match.group(0))
        except Exception:
            pass

    # Tolerant fallback: decode the first complete object and ignore whatever
    # follows it (e.g. '{...} Hope this helps!' or two concatenated objects).
    start = text.find("{")
    if start != -1:
        try:
            obj, _ = json.JSONDecoder().raw_decode(text, start)
            if isinstance(obj, dict):
                return obj
        except Exception:
            pass

    return {"raw_output": text}
109
 
110
 
111
+ PROMPT = """Analyze this pantry image carefully.
112
 
113
  Return ONLY valid JSON with this schema:
114
  {
 
124
  "uncertain_items": []
125
  }
126
 
127
+ Rules:
128
+ - Focus on tiny labels, ingredient names, canned goods, jars, boxes, spices, and packaging text.
129
+ - Prefer exact visible text over guesses.
130
+ - If a brand or quantity is unclear, leave it empty or put it in uncertain_items.
131
+ - Do not hallucinate.
132
+ - Return JSON only. No markdown, no explanation, no code fences.
 
 
133
  """
134
 
135
 
136
+ @spaces.GPU(size="xlarge", duration=120)
137
  def analyze_pantry(image: Image.Image):
138
  if image is None:
139
+ return {"error": "Upload an image first."}
140
 
141
  load_model()
142
 
143
+ image = prepare_image(image)
144
 
145
  messages = [
146
  {
 
162
  add_generation_prompt=True,
163
  )
164
 
165
+ image_inputs, video_inputs = process_vision_info(messages)
166
+
167
  inputs = processor(
168
  text=[text],
169
+ images=image_inputs,
170
+ videos=video_inputs,
171
  padding=True,
172
  return_tensors="pt",
173
  )
174
 
175
+ inputs = inputs.to(model.device)
176
 
177
  with torch.inference_mode():
178
  output_ids = model.generate(
179
  **inputs,
180
+ max_new_tokens=700,
181
  do_sample=False,
182
  )
183
 
184
  prompt_len = inputs["input_ids"].shape[-1]
 
185
  generated_text = processor.batch_decode(
186
  [output_ids[0][prompt_len:]],
187
  skip_special_tokens=True,
188
  )[0].strip()
189
 
 
 
190
  parsed = extract_json(generated_text)
 
191
  if isinstance(parsed, dict) and "raw_output" not in parsed:
192
  parsed["_raw_output"] = generated_text
193
 
194
  return parsed
195
 
196
 
 
 
 
 
 
197
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("# Pantry Scanner")

    # Input row: the photo to analyse.
    with gr.Row():
        image_input = gr.Image(type="pil", label="Pantry image")

    # Action row: trigger button plus the structured result view.
    with gr.Row():
        analyze_btn = gr.Button("Analyze", variant="primary")
        output_json = gr.JSON(label="Detected items")

    # Wire the button to the GPU-backed analysis function.
    analyze_btn.click(analyze_pantry, inputs=image_input, outputs=output_json)
207