Spaces:
Running on Zero
Running on Zero
File size: 2,907 Bytes
75c5414 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 | """Progress validation agent: compare cooking photo against target step."""
from __future__ import annotations
import logging
from typing import Optional
import spaces
import torch
from PIL import Image
from src import config
from src.agents.mise_en_place import model, processor
from src.agents.recipe_planner import _extract_json
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Prompt template text, read once at import time from ``validator_prompt.txt``
# inside ``config.PROMPTS_DIR``. validate() replaces the literal text
# ``{step_instruction}`` in it with the step being checked.
_VALIDATOR_PROMPT = (config.PROMPTS_DIR / "validator_prompt.txt").read_text(encoding="utf-8")
@spaces.GPU(duration=45)
def validate(image: Optional[Image.Image], step_instruction: str) -> dict:
    """Compare a cooking-progress photo to the target step description.

    Args:
        image: Photo of the current cooking progress, or ``None`` when no
            photo was supplied.
        step_instruction: Text of the recipe step; substituted for the
            ``{step_instruction}`` placeholder in the validator prompt.

    Returns:
        A dict with keys ``verdict`` ('go' | 'wait' | 'fix'), ``feedback``
        and ``tip`` (both strings). A missing image, an unrecognised verdict
        or any ``Exception`` raised while analysing the photo yields
        ``verdict == 'wait'``; exceptions are logged, not propagated.
    """
    if image is None:
        return {
            "verdict": "wait",
            "feedback": "No image provided.",
            "tip": "Upload a photo of your cooking progress to get feedback.",
        }

    try:
        img = image.convert("RGB")
        prompt = _VALIDATOR_PROMPT.replace("{step_instruction}", step_instruction)
        messages = [{"role": "user", "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": prompt},
        ]}]
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            enable_thinking=False,
            processor_kwargs={"downsample_mode": "16x", "max_slice_nums": 9, "use_image_id": True},
        )

        # Move every tensor to the model's device in a single pass; floating
        # point tensors are additionally cast to bfloat16 (presumably the
        # dtype of the model weights -- confirm against the model loader).
        device = model.device
        prepared = {}
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                value = value.to(device)
                if torch.is_floating_point(value):
                    value = value.to(dtype=torch.bfloat16)
            prepared[key] = value
        inputs = prepared

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,
                # Same value as the processor's downsample_mode above.
                downsample_mode="16x",
            )

        # Drop the first len(prompt) tokens of each sequence so that only the
        # newly generated tokens are decoded.
        trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)]
        raw = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        log.info("validate raw: %s", raw[:400])

        data = _extract_json(raw)

        # Normalise case/whitespace so answers like "Go" or " fix " are
        # accepted instead of being silently downgraded to "wait".
        verdict = str(data.get("verdict") or "wait").strip().lower()
        if verdict not in ("go", "wait", "fix"):
            verdict = "wait"

        # A JSON null must not reach the user as the literal text "None".
        feedback = data.get("feedback")
        tip = data.get("tip")
        return {
            "verdict": verdict,
            "feedback": "" if feedback is None else str(feedback),
            "tip": "" if tip is None else str(tip),
        }
    except Exception as exc:
        # Best-effort boundary: never crash the caller, but keep the
        # traceback in the log so failures remain diagnosable.
        log.warning("validate failed: %s", exc, exc_info=True)
        return {
            "verdict": "wait",
            "feedback": "Could not analyse the photo.",
            "tip": "Make sure the image is well-lit and in focus.",
        }
|