Spaces:
Paused
Paused
| import logging | |
| import numpy as np | |
| from baml_client.sync_client import b as baml | |
| from baml_client.types import DetectionInfo | |
| from baml_py import Image | |
| from models.isr.utils import encode_frame | |
logger = logging.getLogger(__name__)


class ISRAssessor:
    """Assesses tracked detections against a mission objective using BAML + GPT-4o-mini vision."""

    def __init__(self, mission: str):
        """Store the mission objective every batch is assessed against.

        Args:
            mission: Free-text objective, passed verbatim as ``mission`` to
                ``baml.AssessDetections``.
        """
        self.mission = mission

    @staticmethod
    def _track_fields(track: dict, width: int, height: int) -> dict:
        """Return the ``DetectionInfo`` keyword arguments for one track dict as plain Python values."""
        bbox = track.get("bbox")
        # A missing/None/truncated bbox would otherwise raise here (outside the
        # best-effort try block) and abort the whole batch. Use ``is None`` and
        # ``len`` rather than truthiness so numpy-array bboxes keep working.
        if bbox is None or len(bbox) < 4:
            bbox = (0, 0, 0, 0)
        # Box centre normalized by frame size; bbox is read as (x1, y1, x2, y2)
        # in the frame's pixel coordinates -- presumably; TODO confirm with the tracker.
        cx = ((bbox[0] + bbox[2]) / 2) / width if width > 0 else 0.0
        cy = ((bbox[1] + bbox[3]) / 2) / height if height > 0 else 0.0

        # float()/str() coercion: tracker values may be numpy scalars, which the
        # typed DetectionInfo model may not accept.
        speed = round(float(track.get("speed_kph", 0) or 0), 1)
        direction = str(track.get("direction_clock", "") or "")
        if not direction:
            # Only a (rounded) zero-speed track is "stationary"; a moving track
            # with no heading must not be reported to the model as stationary.
            direction = "stationary" if speed == 0.0 else "unknown"

        return {
            "track_id": str(track.get("track_id", "?")),
            "class_label": track.get("label") or "unknown",
            "confidence": round(float(track.get("score", 0) or 0), 2),
            "center_x": round(float(cx), 2),
            "center_y": round(float(cy), 2),
            "speed_kph": speed,
            "direction": direction,
        }

    @staticmethod
    def _verdicts_to_result(verdicts) -> dict:
        """Convert BAML verdict objects into a track_id -> verdict-dict mapping, logging each one."""
        result = {}
        for v in verdicts:
            tid = v.track_id
            if not tid:
                # A verdict without a track id cannot be matched to any track.
                continue
            features = dict(v.features) if v.features else {}
            # A later verdict for the same track id overwrites an earlier one.
            result[tid] = {
                "mission_relevant": v.mission_relevant,
                "satisfies": v.satisfies,
                "reason": v.reason or "",
                "features": features,
                "assessment_status": "ASSESSED",
            }
            logger.info(
                "[ISR Verdict] track=%s relevant=%s satisfies=%s reason=%r features=%s",
                tid, v.mission_relevant, v.satisfies, v.reason, features,
            )
        return result

    def assess_batch_sync(self, tracks: list[dict], frame: np.ndarray) -> dict:
        """
        Assess a batch of tracks against the mission via BAML.

        Args:
            tracks: Track dicts; the keys "bbox", "track_id", "label", "score",
                "speed_kph" and "direction_clock" are read, all optional.
            frame: Image the tracks were detected in; ``frame.shape[:2]`` is
                taken as (height, width) to normalize box centres.

        Returns:
            Dict mapping track_id (as returned in each verdict) ->
            {mission_relevant, satisfies, reason, features, assessment_status}.
            Empty when there are no tracks, the frame is missing or cannot be
            encoded, or the BAML call fails. Tracks the model returns no
            verdict for are absent.
        """
        if not tracks:
            return {}
        if frame is None:
            # Degrade like an encode failure instead of raising AttributeError.
            logger.warning("No frame supplied for BAML assessment")
            return {}

        h, w = frame.shape[:2]
        detections = [DetectionInfo(**self._track_fields(t, w, h)) for t in tracks]

        frame_b64 = encode_frame(frame, max_dim=1024, quality=60)
        if not frame_b64:
            logger.warning("Failed to encode frame for BAML assessment")
            return {}
        frame_image = Image.from_base64("image/jpeg", frame_b64)

        logger.info(
            "[ISR Input] mission=%r detections=%s",
            self.mission,
            [{"track_id": d.track_id, "class": d.class_label, "speed": d.speed_kph, "dir": d.direction} for d in detections],
        )

        try:
            verdicts = baml.AssessDetections(
                mission=self.mission,
                detections=detections,
                frame_image=frame_image,
            )
            return self._verdicts_to_result(verdicts)
        except Exception:
            # Best-effort by design: any BAML/LLM failure yields "no verdicts".
            logger.exception("BAML ISR assessment failed")
            return {}