"""Progress validation agent: compare cooking photo against target step.""" from __future__ import annotations import logging from typing import Optional import spaces import torch from PIL import Image from src import config from src.agents.mise_en_place import model, processor from src.agents.recipe_planner import _extract_json log = logging.getLogger(__name__) _VALIDATOR_PROMPT = (config.PROMPTS_DIR / "validator_prompt.txt").read_text(encoding="utf-8") @spaces.GPU(duration=45) def validate(image: Optional[Image.Image], step_instruction: str) -> dict: """Compare a cooking-progress photo to the target step description. Returns a dict with keys: verdict ('go'|'wait'|'fix'), feedback, tip. """ if image is None: return { "verdict": "wait", "feedback": "No image provided.", "tip": "Upload a photo of your cooking progress to get feedback.", } try: img = image.convert("RGB") prompt = _VALIDATOR_PROMPT.replace("{step_instruction}", step_instruction) messages = [{"role": "user", "content": [ {"type": "image", "image": img}, {"type": "text", "text": prompt}, ]}] inputs = processor.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", enable_thinking=False, processor_kwargs={"downsample_mode": "16x", "max_slice_nums": 9, "use_image_id": True}, ) device = model.device inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} for k, v in inputs.items(): if isinstance(v, torch.Tensor) and torch.is_floating_point(v): inputs[k] = v.to(dtype=torch.bfloat16) with torch.no_grad(): generated_ids = model.generate( **inputs, max_new_tokens=256, do_sample=False, downsample_mode="16x", ) trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)] raw = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] log.info("validate raw: %s", raw[:400]) data = _extract_json(raw) verdict = str(data.get("verdict", "wait")) if verdict not in ("go", "wait", "fix"): verdict = "wait" return { "verdict": verdict, "feedback": str(data.get("feedback", "")), "tip": str(data.get("tip", "")), } except Exception as exc: log.warning("validate failed: %s", exc) return { "verdict": "wait", "feedback": "Could not analyse the photo.", "tip": "Make sure the image is well-lit and in focus.", }