# ISR/models/isr/assessor.py
# Zhen Ye
# feat: improve ISR mission planner and assessor prompts
# 0ebc456
import logging
import numpy as np
from baml_client.sync_client import b as baml
from baml_client.types import DetectionInfo
from baml_py import Image
from models.isr.utils import encode_frame
# Module-level logger; ISRAssessor logs its BAML inputs, per-track verdicts and failures here.
logger = logging.getLogger(__name__)
class ISRAssessor:
    """Assess tracked detections against a mission objective with a BAML vision call.

    Each batch sends the current frame (JPEG via ``encode_frame``) plus a compact
    per-track summary to ``baml.AssessDetections`` and returns the per-track
    verdicts. The backing vision model (reportedly GPT-4o-mini) is configured on
    the BAML side, not in this class.
    """

    def __init__(self, mission: str):
        # Free-text mission objective, forwarded verbatim to the BAML function.
        self.mission = mission

    @staticmethod
    def _normalized_center(bbox, width, height) -> tuple[float, float]:
        """Return the bbox centre as fractions of frame width/height, rounded to 2 dp.

        ``bbox`` is read as ``[x1, y1, x2, y2]`` in pixels. A missing (``None``)
        or short bbox yields ``(0.0, 0.0)``; a non-positive frame dimension
        yields 0.0 on that axis.
        """
        # `is None` / len() rather than truthiness: bbox may be a numpy array.
        if bbox is None or len(bbox) < 4:
            return 0.0, 0.0
        cx = ((bbox[0] + bbox[2]) / 2) / width if width > 0 else 0.0
        cy = ((bbox[1] + bbox[3]) / 2) / height if height > 0 else 0.0
        return round(cx, 2), round(cy, 2)

    @staticmethod
    def _direction_label(speed: float, direction: str) -> str:
        """Return the heading string to report for a track.

        A known heading is passed through unchanged. Without one, a track with
        zero speed is "stationary" and a moving track is "unknown" -- a moving
        target must not be described to the model as stationary.
        """
        if direction:
            return direction
        return "stationary" if speed == 0.0 else "unknown"

    def _build_detections(self, tracks: list[dict], width, height) -> list:
        """Convert raw track dicts into ``DetectionInfo`` records for the BAML call."""
        detections = []
        for t in tracks:
            cx, cy = self._normalized_center(t.get("bbox"), width, height)
            speed = round(t.get("speed_kph", 0) or 0, 1)
            direction = self._direction_label(speed, t.get("direction_clock", "") or "")
            detections.append(DetectionInfo(
                track_id=str(t.get("track_id", "?")),
                # `or` also covers a key that is present but None/empty.
                class_label=t.get("label") or "unknown",
                confidence=round(t.get("score", 0) or 0, 2),
                center_x=cx,
                center_y=cy,
                speed_kph=speed,
                direction=direction,
            ))
        return detections

    def assess_batch_sync(self, tracks: list[dict], frame: np.ndarray) -> dict:
        """
        Assess a batch of tracks against the mission via BAML.

        Args:
            tracks: Track dicts. Keys read: ``bbox``, ``direction_clock``,
                ``speed_kph``, ``track_id``, ``label``, ``score``; each is
                optional and falls back to a neutral default.
            frame: Image array the tracks were detected in; ``shape[:2]`` is
                taken as (height, width) to normalise bbox centres.

        Returns:
            Dict mapping track_id (string, as returned by the model) ->
            {mission_relevant, satisfies, reason, features, assessment_status}.
            Empty when ``tracks`` is empty, the frame is missing or cannot be
            encoded, or the BAML call fails (failures are logged, not raised).
        """
        if not tracks:
            return {}
        if frame is None:
            logger.warning("No frame supplied for BAML assessment")
            return {}

        h, w = frame.shape[:2]
        detections = self._build_detections(tracks, w, h)

        frame_b64 = encode_frame(frame, max_dim=1024, quality=60)
        if not frame_b64:
            logger.warning("Failed to encode frame for BAML assessment")
            return {}
        frame_image = Image.from_base64("image/jpeg", frame_b64)

        logger.info(
            "[ISR Input] mission=%r detections=%s",
            self.mission,
            [{"track_id": d.track_id, "class": d.class_label, "speed": d.speed_kph, "dir": d.direction} for d in detections],
        )

        try:
            verdicts = baml.AssessDetections(
                mission=self.mission,
                detections=detections,
                frame_image=frame_image,
            )
            result = {}
            for v in verdicts:
                tid = v.track_id
                if not tid:
                    continue  # cannot key a verdict without an id
                features = dict(v.features) if v.features else {}
                result[tid] = {
                    "mission_relevant": v.mission_relevant,
                    "satisfies": v.satisfies,
                    "reason": v.reason or "",
                    "features": features,
                    "assessment_status": "ASSESSED",
                }
                logger.info(
                    "[ISR Verdict] track=%s relevant=%s satisfies=%s reason=%r features=%s",
                    tid, v.mission_relevant, v.satisfies, v.reason, features,
                )
            return result
        except Exception:
            # Best-effort boundary: a failed model call must not crash the pipeline.
            logger.exception("BAML ISR assessment failed")
            return {}