Spaces:
Running on Zero
Running on Zero
| """Progress validation agent: compare cooking photo against target step.""" | |
| from __future__ import annotations | |
| import logging | |
| from typing import Optional | |
| import spaces | |
| import torch | |
| from PIL import Image | |
| from src import config | |
| from src.agents.mise_en_place import model, processor | |
| from src.agents.recipe_planner import _extract_json | |
# Module-level logger, named after this module.
log = logging.getLogger(name=__name__)
# Prompt template, read once at import time from the configured prompts
# directory. validate() substitutes the literal "{step_instruction}"
# placeholder in it on every call.
_VALIDATOR_PROMPT = config.PROMPTS_DIR.joinpath("validator_prompt.txt").read_text(encoding="utf-8")
| def validate(image: Optional[Image.Image], step_instruction: str) -> dict: | |
| """Compare a cooking-progress photo to the target step description. | |
| Returns a dict with keys: verdict ('go'|'wait'|'fix'), feedback, tip. | |
| """ | |
| if image is None: | |
| return { | |
| "verdict": "wait", | |
| "feedback": "No image provided.", | |
| "tip": "Upload a photo of your cooking progress to get feedback.", | |
| } | |
| try: | |
| img = image.convert("RGB") | |
| prompt = _VALIDATOR_PROMPT.replace("{step_instruction}", step_instruction) | |
| messages = [{"role": "user", "content": [ | |
| {"type": "image", "image": img}, | |
| {"type": "text", "text": prompt}, | |
| ]}] | |
| inputs = processor.apply_chat_template( | |
| messages, | |
| add_generation_prompt=True, | |
| tokenize=True, | |
| return_dict=True, | |
| return_tensors="pt", | |
| enable_thinking=False, | |
| processor_kwargs={"downsample_mode": "16x", "max_slice_nums": 9, "use_image_id": True}, | |
| ) | |
| device = model.device | |
| inputs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()} | |
| for k, v in inputs.items(): | |
| if isinstance(v, torch.Tensor) and torch.is_floating_point(v): | |
| inputs[k] = v.to(dtype=torch.bfloat16) | |
| with torch.no_grad(): | |
| generated_ids = model.generate( | |
| **inputs, | |
| max_new_tokens=256, | |
| do_sample=False, | |
| downsample_mode="16x", | |
| ) | |
| trimmed = [out[len(inp):] for inp, out in zip(inputs["input_ids"], generated_ids)] | |
| raw = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] | |
| log.info("validate raw: %s", raw[:400]) | |
| data = _extract_json(raw) | |
| verdict = str(data.get("verdict", "wait")) | |
| if verdict not in ("go", "wait", "fix"): | |
| verdict = "wait" | |
| return { | |
| "verdict": verdict, | |
| "feedback": str(data.get("feedback", "")), | |
| "tip": str(data.get("tip", "")), | |
| } | |
| except Exception as exc: | |
| log.warning("validate failed: %s", exc) | |
| return { | |
| "verdict": "wait", | |
| "feedback": "Could not analyse the photo.", | |
| "tip": "Make sure the image is well-lit and in focus.", | |
| } | |