# ASD / app/aggregator.py
# Uploaded by Nx-Neuralon ("Upload 64 files", commit b6d0232, verified).
from __future__ import annotations
from collections import defaultdict
from app.schemas import AgentResult, MergedEvent
def merge_overlapping_results(results: list[AgentResult], iou_gap: float = 2.0) -> list[MergedEvent]:
    """Merge per-agent findings into consolidated events.

    Findings from all agents are grouped by ``warning_type``. Within a group
    they are ordered by ``start_sec`` and chained into a single event for as
    long as each next finding starts no later than ``iou_gap`` after the
    running ``end_sec`` of the current event. Otherwise a new event starts.

    Args:
        results: Agent outputs. Each contributes ``(agent_name, finding)``
            pairs taken from its ``findings``.
        iou_gap: Maximum allowed gap between the end of the current event and
            the start of the next finding for the two to be merged, in the
            same units as ``start_sec``/``end_sec``. Despite the name, this is
            a time gap, not an intersection-over-union ratio.

    Returns:
        Merged events sorted by ``(start_sec, warning_type)``. For each event:

        - ``confidence`` is the mean of its findings' confidences, capped at 1.0.
        - ``evidences`` and ``sources`` are parallel lists in merge order.
        - ``behavior_tags`` is the sorted union of all tags.
        - ``clinical_note`` is the longest note seen; the earliest wins ties,
          and a missing note is treated as ``""``.
    """
    grouped = defaultdict(list)
    for result in results:
        for finding in result.findings:
            grouped[finding.warning_type].append((result.agent_name, finding))

    merged: list[MergedEvent] = []
    for warning_type, items in grouped.items():
        # list.sort is stable: findings with equal start_sec keep agent order.
        items.sort(key=lambda pair: pair[1].start_sec)
        cluster = None
        for agent_name, finding in items:
            if cluster is not None and finding.start_sec <= cluster["end_sec"] + iou_gap:
                _absorb_finding(cluster, agent_name, finding)
                continue
            # Gap too large (or first finding): close the open cluster, start anew.
            if cluster is not None:
                merged.append(_cluster_to_event(cluster))
            cluster = _start_cluster(warning_type, agent_name, finding)
        # Flush the last open cluster of this warning type.
        if cluster is not None:
            merged.append(_cluster_to_event(cluster))

    merged.sort(key=lambda event: (event.start_sec, event.warning_type))
    return merged


def _start_cluster(warning_type, agent_name, finding) -> dict:
    """Open a new merge cluster seeded with a single finding."""
    return {
        "warning_type": warning_type,
        "start_sec": finding.start_sec,
        "end_sec": finding.end_sec,
        "confidence_sum": finding.confidence,
        "count": 1,
        "evidences": [finding.evidence],
        "sources": [agent_name],
        "behavior_tags": set(finding.behavior_tags),
        "clinical_note": finding.clinical_note or "",
    }


def _absorb_finding(cluster: dict, agent_name, finding) -> None:
    """Fold an overlapping finding into an open cluster, in place."""
    cluster["end_sec"] = max(cluster["end_sec"], finding.end_sec)
    cluster["confidence_sum"] += finding.confidence
    cluster["count"] += 1
    cluster["evidences"].append(finding.evidence)
    cluster["sources"].append(agent_name)
    cluster["behavior_tags"].update(finding.behavior_tags)
    # Normalise a missing note exactly as _start_cluster does; calling len()
    # on None would raise TypeError.
    note = finding.clinical_note or ""
    if len(note) > len(cluster["clinical_note"]):
        cluster["clinical_note"] = note


def _cluster_to_event(cluster: dict) -> MergedEvent:
    """Build a MergedEvent from a finished cluster."""
    return MergedEvent(
        warning_type=cluster["warning_type"],
        start_sec=cluster["start_sec"],
        end_sec=cluster["end_sec"],
        confidence=min(1.0, cluster["confidence_sum"] / cluster["count"]),
        evidences=cluster["evidences"],
        sources=cluster["sources"],
        behavior_tags=sorted(cluster["behavior_tags"]),
        clinical_note=cluster["clinical_note"],
    )