File size: 2,907 Bytes
75c5414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Progress validation agent: compare cooking photo against target step."""
from __future__ import annotations

import logging
from typing import Optional

import spaces
import torch
from PIL import Image

from src import config
from src.agents.mise_en_place import model, processor
from src.agents.recipe_planner import _extract_json

# Module-level logger for this agent.
log = logging.getLogger(__name__)

# Prompt template for the progress validator, read once at import time from
# config.PROMPTS_DIR. `validate` substitutes the literal "{step_instruction}"
# placeholder in this text via str.replace (not str.format).
_VALIDATOR_PROMPT = (config.PROMPTS_DIR / "validator_prompt.txt").read_text(encoding="utf-8")


@spaces.GPU(duration=45)
def validate(image: Optional[Image.Image], step_instruction: str) -> dict:
    """Compare a cooking-progress photo to the target step description.

    Args:
        image: Photo of the user's current cooking progress, or None if the
            user has not uploaded one yet.
        step_instruction: Text of the recipe step the photo should match; it is
            substituted into the validator prompt template.

    Returns:
        A dict with keys:
          - ``verdict``: one of ``'go'``, ``'wait'``, ``'fix'`` (defaults to
            ``'wait'`` when the model's answer is missing or unrecognised),
          - ``feedback``: model feedback text (may be empty),
          - ``tip``: model tip text (may be empty).

    Never raises: a missing image, or any failure during preprocessing,
    generation or parsing, yields a ``'wait'`` verdict with a canned message.
    """
    if image is None:
        return {
            "verdict": "wait",
            "feedback": "No image provided.",
            "tip": "Upload a photo of your cooking progress to get feedback.",
        }
    try:
        img = image.convert("RGB")
        # str.replace rather than str.format: the template may contain other
        # literal braces (e.g. a JSON example) that format() would choke on.
        prompt = _VALIDATOR_PROMPT.replace("{step_instruction}", step_instruction)

        messages = [{"role": "user", "content": [
            {"type": "image", "image": img},
            {"type": "text", "text": prompt},
        ]}]

        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            enable_thinking=False,
            processor_kwargs={"downsample_mode": "16x", "max_slice_nums": 9, "use_image_id": True},
        )

        # Move every tensor to the model's device in a single pass. Floating
        # tensors are also cast to bfloat16 (presumably the dtype the model was
        # loaded in — confirm in mise_en_place); integer tensors such as token
        # ids keep their dtype.
        device = model.device
        model_inputs = {}
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor):
                if torch.is_floating_point(value):
                    value = value.to(device=device, dtype=torch.bfloat16)
                else:
                    value = value.to(device)
            model_inputs[key] = value

        with torch.no_grad():
            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=256,
                do_sample=False,
                downsample_mode="16x",
            )

        # generate() returns prompt + completion; drop the prompt tokens so
        # only the newly generated text is decoded.
        trimmed = [out[len(inp):] for inp, out in zip(model_inputs["input_ids"], generated_ids)]
        raw = processor.batch_decode(trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        log.info("validate raw: %s", raw[:400])

        return _parse_response(raw)
    except Exception as exc:
        # Best-effort by design: the UI always gets a usable answer, but keep
        # the traceback in the logs so failures can actually be diagnosed.
        log.warning("validate failed: %s", exc, exc_info=True)
        return {
            "verdict": "wait",
            "feedback": "Could not analyse the photo.",
            "tip": "Make sure the image is well-lit and in focus.",
        }


# Verdicts the UI understands; anything else from the model maps to "wait".
_VERDICTS = ("go", "wait", "fix")


def _parse_response(raw: str) -> dict:
    """Convert the model's raw text reply into a verdict/feedback/tip dict.

    Raises:
        ValueError: if the extracted JSON is not an object.
    """
    data = _extract_json(raw)
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object from the model, got {type(data).__name__}")

    # Models vary case/whitespace ("Go", " fix\n"); normalise before matching
    # so a correct answer is not silently downgraded to "wait".
    verdict = str(data.get("verdict") or "wait").strip().lower()
    if verdict not in _VERDICTS:
        verdict = "wait"

    return {
        "verdict": verdict,
        # `or ""` so an explicit JSON null is not shown to the user as "None".
        "feedback": str(data.get("feedback") or ""),
        "tip": str(data.get("tip") or ""),
    }