airzy1 commited on
Commit
2f43bdf
·
verified ·
1 Parent(s): ffbf7f2

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import json
import re

# NOTE(review): these cache/allocator environment variables are set BEFORE the
# `transformers` / `torch` imports below — HF cache locations are typically
# read at import time, so the ordering appears deliberate; confirm before
# reordering or regrouping the imports.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["HF_HUB_CACHE"] = "/tmp/hf/hub"
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"

import spaces
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, AutoModelForImageTextToText

# Hugging Face Hub model id loaded lazily by load_model().
MODEL_ID = "Qwen/Qwen3.5-397B-A17B"

# Lazily populated by load_model(); both remain None until the first request.
processor = None
model = None
21
+
22
+
23
def load_model():
    """Lazily initialize the module-level ``processor``/``model`` pair.

    Idempotent: a second call is a no-op once both globals are populated.
    The model is placed automatically across available devices and put in
    eval mode.
    """
    global processor, model

    already_loaded = processor is not None and model is not None
    if already_loaded:
        return

    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID,
        device_map="auto",
        torch_dtype="auto",
    )
    model.eval()
35
+
36
+
37
def extract_json(text: str):
    """Best-effort JSON recovery from raw model output.

    Attempts, in order: the whole stripped string, then the widest
    ``{...}`` span found in it (greedy, dot-matches-newline). If neither
    parses, returns ``{"raw_output": <stripped text>}`` so callers always
    get a JSON-serializable value.
    """
    cleaned = (text or "").strip()

    # Candidate strings to attempt parsing, in priority order.
    candidates = [cleaned]
    braced = re.search(r"\{.*\}", cleaned, flags=re.S)
    if braced:
        candidates.append(braced.group(0))

    for candidate in candidates:
        try:
            return json.loads(candidate)
        except Exception:
            continue

    return {"raw_output": cleaned}
53
+
54
+
55
# Instruction prompt paired with every uploaded image. The JSON schema spelled
# out here is what extract_json() is expected to recover from the model reply.
PROMPT = """Analyze this pantry image.
Return ONLY valid JSON with this schema:
{
"items": [
{
"name": "",
"brand": "",
"category": "",
"package_type": "",
"estimated_quantity": "",
"evidence": "",
"confidence": 0.0
}
],
"summary": "",
"uncertain_items": []
}

Rules:
- List visible pantry foods, ingredients, drinks, and packaged items.
- Use the smallest sensible item name.
- Do not invent hidden ingredients.
- If a brand is unclear, leave brand empty.
- If uncertain, lower confidence.
- Do not include markdown, code fences, or commentary.
"""
81
+
82
+
83
@spaces.GPU(size="large", duration=60)
def analyze_pantry(image: Image.Image):
    """Extract a pantry inventory from a photo via the vision-language model.

    Returns a dict: the parsed JSON on success (with the raw model text kept
    under ``"_raw_output"``), ``{"raw_output": ...}`` when parsing fails, or
    ``{"error": ...}`` when no image was supplied.
    """
    if image is None:
        return {"error": "Please upload a pantry image."}

    load_model()

    system_turn = {
        "role": "system",
        "content": [
            {"type": "text", "text": "You extract pantry items from photos and respond with JSON only."}
        ],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": image.convert("RGB")},
            {"type": "text", "text": PROMPT},
        ],
    }

    batch = processor.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Move only tensor-like entries onto the model's device.
    batch = {
        key: value.to(model.device) if hasattr(value, "to") else value
        for key, value in batch.items()
    }

    # Greedy decoding keeps the JSON output deterministic.
    with torch.inference_mode():
        generated = model.generate(
            **batch,
            max_new_tokens=1200,
            do_sample=False,
        )

    # Decode only the newly generated tail, skipping the echoed prompt tokens.
    prompt_len = batch["input_ids"].shape[-1]
    reply = processor.decode(
        generated[0][prompt_len:],
        skip_special_tokens=True,
    ).strip()

    result = extract_json(reply)
    if isinstance(result, dict) and "raw_output" not in result:
        result["_raw_output"] = reply
    return result
133
+
134
+
135
@spaces.GPU(size="large", duration=1)
def cloud():
    """No-op GPU endpoint; always returns None.

    NOTE(review): presumably exists to briefly request/warm a large GPU
    allocation via the ZeroGPU decorator — confirm intended use.
    """
    return None
138
+
139
+
140
# ---- Gradio UI wiring ------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# Pantry ingredient / item extractor")

    # Input widget, action buttons, and the JSON result panel.
    pantry_image = gr.Image(type="pil", label="Pantry image")
    run_button = gr.Button("Analyze")
    warm_button = gr.Button("Cloud")
    result_view = gr.JSON(label="Output")

    run_button.click(
        analyze_pantry,
        inputs=[pantry_image],
        outputs=[result_view],
        api_name="analyze",
    )
    warm_button.click(cloud, inputs=[], outputs=[], api_name="cloud")


demo.queue(max_size=16)
demo.launch(ssr_mode=False)