from __future__ import annotations from collections import defaultdict from app.schemas import AgentResult, MergedEvent def merge_overlapping_results(results: list[AgentResult], iou_gap: float = 2.0) -> list[MergedEvent]: grouped = defaultdict(list) for result in results: for finding in result.findings: grouped[finding.warning_type].append((result.agent_name, finding)) merged: list[MergedEvent] = [] for warning_type, items in grouped.items(): items.sort(key=lambda x: x[1].start_sec) current = None for agent_name, f in items: if current is None: current = { "warning_type": warning_type, "start_sec": f.start_sec, "end_sec": f.end_sec, "confidence_sum": f.confidence, "count": 1, "evidences": [f.evidence], "sources": [agent_name], "behavior_tags": set(f.behavior_tags), "clinical_note": f.clinical_note or "", } continue overlap = f.start_sec <= current["end_sec"] + iou_gap if overlap: current["end_sec"] = max(current["end_sec"], f.end_sec) current["confidence_sum"] += f.confidence current["count"] += 1 current["evidences"].append(f.evidence) current["sources"].append(agent_name) current["behavior_tags"].update(f.behavior_tags) if len(f.clinical_note) > len(current["clinical_note"]): current["clinical_note"] = f.clinical_note else: merged.append( MergedEvent( warning_type=current["warning_type"], start_sec=current["start_sec"], end_sec=current["end_sec"], confidence=min(1.0, current["confidence_sum"] / current["count"]), evidences=current["evidences"], sources=current["sources"], behavior_tags=sorted(current["behavior_tags"]), clinical_note=current["clinical_note"], ) ) current = { "warning_type": warning_type, "start_sec": f.start_sec, "end_sec": f.end_sec, "confidence_sum": f.confidence, "count": 1, "evidences": [f.evidence], "sources": [agent_name], "behavior_tags": set(f.behavior_tags), "clinical_note": f.clinical_note or "", } if current is not None: merged.append( MergedEvent( warning_type=current["warning_type"], start_sec=current["start_sec"], end_sec=current["end_sec"], confidence=min(1.0, current["confidence_sum"] / current["count"]), evidences=current["evidences"], sources=current["sources"], behavior_tags=sorted(current["behavior_tags"]), clinical_note=current["clinical_note"], ) ) merged.sort(key=lambda x: (x.start_sec, x.warning_type)) return merged