File size: 5,140 Bytes
a2ca6f9
 
 
624478a
 
 
a2ca6f9
624478a
a2ca6f9
 
 
 
624478a
a2ca6f9
624478a
a2ca6f9
bb6e650
 
a2ca6f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
624478a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6e650
624478a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55e372a
624478a
 
 
 
 
 
 
 
 
 
bb6e650
 
624478a
 
 
 
 
 
 
 
 
 
 
 
 
 
bb6e650
624478a
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""
Object relevance evaluation — deterministic gate between detection and GPT assessment.

Public functions:
  evaluate_relevance(detection, criteria) -> RelevanceDecision  (deterministic)
  evaluate_relevance_llm(detected_labels, mission_text) -> set[str]  (LLM post-filter)

INVARIANT INV-13 enforcement: evaluate_relevance() accepts RelevanceCriteria, NOT
MissionSpecification. It cannot see context_phrases, stripped_modifiers, or any
LLM-derived field. This is structural, not by convention.
"""

import json
import logging
from typing import Any, Dict, List, NamedTuple, Set

from utils.openai_client import chat_completion, extract_content, get_api_key, OpenAIAPIError

from coco_classes import canonicalize_coco_name
from utils.schemas import RelevanceCriteria

logger = logging.getLogger(__name__)


# Verdict of the deterministic relevance gate for a single detection.
# `reason` is one of: "ok" | "label_not_in_required_classes" | "below_confidence".
RelevanceDecision = NamedTuple(
    "RelevanceDecision",
    [("relevant", bool), ("reason", str)],
)


def evaluate_relevance(
    detection: Dict[str, Any],
    criteria: RelevanceCriteria,
) -> RelevanceDecision:
    """Evaluate whether a detection is relevant to the mission.

    Pure deterministic predicate — no LLM involvement. Per INV-13, this
    function sees only RelevanceCriteria, never any LLM-derived field.

    A detection is relevant when its label matches a required class
    directly, via COCO canonicalization of the label, or because the label
    and some required class canonicalize to the same COCO class — and its
    confidence meets criteria.min_confidence.

    Args:
        detection: Detection dict with at least 'label' and 'score' keys.
            Missing or None values are tolerated ('' label, 0.0 score).
        criteria: RelevanceCriteria with required_classes and min_confidence.

    Returns:
        RelevanceDecision(relevant=bool, reason=str) with reason in
        {"ok", "label_not_in_required_classes", "below_confidence"}.
    """
    label = (detection.get("label") or "").lower().strip()
    # `or 0.0` guards an explicit None score (mirrors the label guard above);
    # with .get(..., 0.0) alone, a present-but-None score would make the
    # `confidence < min_confidence` comparison raise TypeError.
    confidence = detection.get("score") or 0.0

    if not label:
        return RelevanceDecision(False, "label_not_in_required_classes")

    # Lowercase set of required classes for case-insensitive comparison.
    required_lower = {c.lower() for c in criteria.required_classes}

    def _gate() -> RelevanceDecision:
        # Single confidence gate applied once a label match is established.
        if confidence < criteria.min_confidence:
            return RelevanceDecision(False, "below_confidence")
        return RelevanceDecision(True, "ok")

    # 1) Direct match.
    if label in required_lower:
        return _gate()

    # 2) Synonym match via COCO canonicalization of the label.
    canonical = canonicalize_coco_name(label)
    if canonical and canonical.lower() in required_lower:
        return _gate()

    # 3) Label and some required class canonicalize to the same COCO class.
    if canonical:
        canonical_lower = canonical.lower()
        for req in criteria.required_classes:
            req_canonical = canonicalize_coco_name(req)
            if req_canonical and req_canonical.lower() == canonical_lower:
                return _gate()

    return RelevanceDecision(False, "label_not_in_required_classes")


def evaluate_relevance_llm(
    detected_labels: List[str],
    mission_text: str,
) -> Set[str]:
    """Ask GPT which detected labels are relevant to the mission.

    Called ONCE on frame 0 with the unique labels found by the detector.
    Returns a set of relevant label strings (lowercased).

    Fail-open policy: on missing API key, API failure, or an unparseable /
    malformed response, ALL detected labels are returned (and the fallback
    is logged) so infrastructure errors never silently drop detections.
    """
    if not detected_labels:
        return set()

    if not get_api_key():
        logger.warning(
            "OPENAI_API_KEY not set — LLM relevance filter falling back to accept-all"
        )
        return set(detected_labels)

    prompt = (
        f"Given this mission: \"{mission_text}\"\n\n"
        f"Which of these detected object classes are relevant to the mission?\n"
        f"{json.dumps(detected_labels)}\n\n"
        "Return JSON: {\"relevant_labels\": [...]}\n"
        "Only include labels from the provided list that are relevant to "
        "accomplishing the mission. Be inclusive — if in doubt, include it."
    )

    payload = {
        "model": "gpt-4o-mini",
        "temperature": 0.0,  # deterministic filtering
        "max_tokens": 200,
        "response_format": {"type": "json_object"},  # force parseable JSON
        "messages": [
            {"role": "system", "content": "You are a mission relevance filter. Return only JSON."},
            {"role": "user", "content": prompt},
        ],
    }

    try:
        resp_data = chat_completion(payload)
        content, _refusal = extract_content(resp_data)
        if not content:
            logger.warning("GPT returned empty content for relevance filter — accept-all")
            return set(detected_labels)

        result = json.loads(content)
        relevant = result.get("relevant_labels", detected_labels)
        # str() guards against non-string entries from the model; previously an
        # int entry raised an uncaught AttributeError on .lower(), bypassing the
        # fail-open handlers below.
        relevant_set = {str(label).lower() for label in relevant}

        logger.info(
            "LLM relevance filter: mission=%r detected=%s relevant=%s",
            mission_text, detected_labels, relevant_set,
        )
        return relevant_set

    except OpenAIAPIError as e:
        logger.warning("LLM relevance API call failed: %s — accept-all fallback", e)
        return set(detected_labels)
    # AttributeError added: a non-iterable or otherwise malformed
    # "relevant_labels" value must also fall back to accept-all.
    except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e:
        logger.warning("LLM relevance response parse failed: %s — accept-all fallback", e)
        return set(detected_labels)