# cheapsake / app.py
# (Hugging Face Space page header residue removed: "airzy1's picture / Update app.py / 87c7a79 verified")
import os
import json
import re

# NOTE: these environment variables are set BEFORE importing torch /
# transformers below so the libraries read them at import time.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Reduce CUDA memory fragmentation: expandable segments + capped split size.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"
# Writable cache path for Spaces without persistent storage
os.environ["HF_HOME"] = "/tmp/hf"
os.environ["HF_HUB_CACHE"] = "/tmp/hf/hub"
# NOTE(review): TRANSFORMERS_CACHE is deprecated in newer transformers
# releases (HF_HOME is preferred) — kept here, presumably for compatibility
# with the pinned version; confirm against the installed transformers.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf/transformers"
os.makedirs("/tmp/hf/hub", exist_ok=True)
os.makedirs("/tmp/hf/transformers", exist_ok=True)

import spaces
import torch
import gradio as gr
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# Optional Hub access token; empty string when unset (mapped to None downstream).
HF_TOKEN = os.environ.get("HF_TOKEN", "")
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# Lazily initialized in load_model(); None until the first request.
processor = None
model = None
def load_model():
    """Lazily initialize the module-level processor and model.

    Idempotent: subsequent calls return immediately once both globals are set.
    Downloads weights from the Hub on first use (token passed only when set).
    """
    global processor, model
    if processor is not None and model is not None:
        return

    print("loading processor")
    processor = AutoProcessor.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN or None,  # empty string -> None (anonymous access)
        min_pixels=256 * 28 * 28,
        max_pixels=1024 * 28 * 28,
    )

    print("loading model:", MODEL_ID)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        token=HF_TOKEN or None,
        device_map="auto",
        torch_dtype="auto",
    )
    print("model.eval started")
    model.eval()
def extract_json(text: str):
    """Best-effort parse of a model response into a Python object.

    Tries, in order: (1) stripping a markdown code fence (```json ... ```)
    that chat models commonly wrap output in, (2) parsing the whole text,
    (3) parsing the first {...} span found in the text.

    Args:
        text: Raw model output; None is treated as empty.

    Returns:
        The parsed JSON value on success, otherwise {"raw_output": <text>}.
    """
    text = (text or "").strip()

    # Robustness fix: unwrap a markdown code fence before parsing — the
    # original parser failed on fenced output like ```json\n{...}\n```.
    fence = re.search(r"```(?:json)?\s*(.*?)```", text, flags=re.S)
    if fence:
        text = fence.group(1).strip()

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass

    # Fall back to the first brace-delimited span (greedy, so nested
    # objects inside the outermost {...} are kept).
    match = re.search(r"\{.*\}", text, flags=re.S)
    if match:
        try:
            return json.loads(match.group(0))
        except json.JSONDecodeError:
            pass

    return {"raw_output": text}
# Instruction sent with every uploaded image; asks for schema-conforming JSON.
# NOTE(review): the trailing "find all the unique recipes in detail" line is
# not represented in the JSON schema above — confirm whether recipe output is
# intended; it may push the model to emit text outside the schema.
PROMPT = """Analyze this pantry image.
Return ONLY valid JSON with this schema:
{
"items": [
{
"name": "",
"brand": "",
"quantity": "",
"confidence": 0.0
}
],
"summary": "",
"uncertain_items": []
}
find all the unique recipes in detail
"""
@spaces.GPU(size="xlarge", duration=160)
def analyze_pantry(image: Image.Image):
    """Extract pantry items from a photo with the Qwen2.5-VL model.

    Args:
        image: Uploaded pantry photo (any PIL mode; converted to RGB),
            or None when the user submitted without an image.

    Returns:
        dict: Parsed model JSON. When parsing succeeds the raw model text
        is attached under "_raw_output"; on parse failure the dict is
        {"raw_output": <model text>}. Returns {"error": ...} if no image
        was supplied.
    """
    if image is None:
        return {"error": "Please upload a pantry image."}
    load_model()

    # Fix: convert to RGB once and reuse — the original converted the image
    # twice (once for the chat message, once for the processor call).
    rgb = image.convert("RGB")

    messages = [
        {
            "role": "system",
            "content": [
                {"type": "text", "text": "You extract pantry items from photos and respond with JSON only."}
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": rgb},
                {"type": "text", "text": PROMPT},
            ],
        },
    ]

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = processor(
        text=[text],
        images=[rgb],
        padding=True,
        return_tensors="pt",
    )
    # Move tensor entries to the model's device; leave non-tensors untouched.
    inputs = {k: v.to(model.device) if hasattr(v, "to") else v for k, v in inputs.items()}

    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=800,
            do_sample=False,  # greedy decoding: deterministic, JSON-friendly
        )

    # Decode only the newly generated tokens (drop the prompt prefix).
    prompt_len = inputs["input_ids"].shape[-1]
    generated_text = processor.batch_decode(
        [output_ids[0][prompt_len:]],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0].strip()
    print("generated_text:", generated_text)

    parsed = extract_json(generated_text)
    if isinstance(parsed, dict) and "raw_output" not in parsed:
        # Keep the raw model text alongside the parsed fields for debugging.
        parsed["_raw_output"] = generated_text
    return parsed
@spaces.GPU(size="large", duration=1)
def cloud():
    """No-op endpoint run under the GPU decorator; always returns None."""
    return None
# --- Gradio UI wiring (module level: runs at import time on Spaces) ---
with gr.Blocks() as demo:
    gr.Markdown("# Pantry ingredient / item extractor")
    image_input = gr.Image(type="pil", label="Pantry image")
    analyze_btn = gr.Button("Analyze")
    cloud_btn = gr.Button("Cloud")
    output_json = gr.JSON(label="Output")
    # api_name exposes these handlers as named API endpoints (/analyze, /cloud).
    analyze_btn.click(analyze_pantry, inputs=[image_input], outputs=[output_json], api_name="analyze")
    cloud_btn.click(cloud, inputs=[], outputs=[], api_name="cloud")

# Bound the request queue; ssr_mode=False disables Gradio's server-side rendering.
demo.queue(max_size=16)
demo.launch(ssr_mode=False)