|
|
import json |
|
|
import re |
|
|
import numpy as np |
|
|
from typing import Tuple |
|
|
from collections import defaultdict |
|
|
|
|
|
'''
Expected schema of each record in the input results JSON:
dict_keys(['struc_info', 'metadata', 'qa_type', 'question', 'answer', 'gnd'])
'''
|
|
def extract_boxes(raw_output):
    """Extract integer boxes in ``[[x1, y1, x2, y2]]`` format from raw text.

    Used as a fallback parser when the prediction has no timestamped entries.

    Args:
        raw_output: Raw model output string.

    Returns:
        List of 4-element integer lists. Matches that are non-numeric or do
        not contain exactly four values are skipped.
    """
    # Digits, whitespace and commas between double square brackets.
    pattern = re.compile(r"\[\[([\d\s,]+)\]\]")
    matches = pattern.findall(raw_output)

    boxes = []
    for match in matches:
        try:
            box = [int(x.strip()) for x in match.split(',')]
        except ValueError:
            # Malformed fragment (e.g. trailing comma); skip this match only.
            continue
        if len(box) == 4:
            boxes.append(box)
    return boxes
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def post_process_pred(raw_output):
    """Parse STG-style prediction text into {time_key: [x1, y1, x2, y2]}.

    Supports float second keys like '8.0 seconds: [x1, y1, x2, y2]'; keys
    are normalized to ``str(float(k))`` (e.g. ``'8.0'``).

    If no timestamped entries are found, falls back to extract_boxes(),
    which returns a plain list of boxes instead of a dict.

    A malformed box is replaced by the most recent valid box (carried
    forward), so every parsed timestamp still gets an entry when possible.
    """
    pattern = r"(\d+(?:\.\d+)?)\s+seconds:\s*\[([^\]]+)\]"
    matches = re.findall(pattern, raw_output)

    if not matches:
        # No "<t> seconds: [...]" lines; fall back to bare-box extraction.
        return extract_boxes(raw_output)

    parsed_prediction = {}
    last_valid_box = None
    for k, v in matches:
        try:
            nums = []
            for num in v.split(','):
                # Tolerate stray brackets clinging to individual numbers.
                num_clean = num.strip().lstrip('[').rstrip(']')
                nums.append(float(num_clean))
            if len(nums) != 4:
                raise ValueError("Box should have 4 values.")
            parsed_prediction[str(float(k))] = nums
            last_valid_box = nums
        except ValueError:
            print(f"[Outlier] Failed to parse entry at time {k}: {v}")
            print(f"Raw output line: {k} seconds: [{v}]")
            print("---")
            if last_valid_box is not None:
                # Carry the previous valid box forward for this timestamp.
                parsed_prediction[str(float(k))] = last_valid_box
            else:
                print(f"[Warning] No valid box available to copy for time {k}")

    return parsed_prediction
|
|
|
|
|
def is_valid_box(box):
    """Return True if `box` is a list of exactly four numeric coordinates."""
    if not isinstance(box, list):
        return False
    if len(box) != 4:
        return False
    return all(isinstance(coord, (int, float)) for coord in box)
|
|
|
|
|
|
|
|
def np_box_area(boxes: np.array) -> np.array:
    """Compute the area of each bounding box in an (N, 4) array.

    Args:
        boxes: (N, 4) array of boxes in (x1, y1, x2, y2) format with
            ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Returns:
        (N,) array with the area of each box.
    """
    assert boxes.ndim == 2 and boxes.shape[-1] == 4
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights
|
|
|
|
|
|
|
|
def _box_inter_union(boxes1: np.array, boxes2: np.array) -> Tuple[np.array, np.array]: |
|
|
area1 = np_box_area(boxes1) |
|
|
area2 = np_box_area(boxes2) |
|
|
|
|
|
lt = np.maximum(boxes1[:, None, :2], boxes2[:, :2]) |
|
|
rb = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) |
|
|
|
|
|
wh = (rb - lt).clip(min=0) |
|
|
inter = wh[:, :, 0] * wh[:, :, 1] |
|
|
|
|
|
union = area1[:, None] + area2 - inter |
|
|
|
|
|
return inter, union |
|
|
|
|
|
|
|
|
def np_box_iou(boxes1: np.array, boxes2: np.array) -> np.array:
    """Return pairwise intersection-over-union (Jaccard index) of two box sets.

    Both sets must be in ``(x1, y1, x2, y2)`` format with
    ``0 <= x1 < x2`` and ``0 <= y1 < y2``.

    Args:
        boxes1: (N, 4) array of boxes.
        boxes2: (M, 4) array of boxes.

    Returns:
        (N, M) matrix of pairwise IoU values between boxes1 and boxes2.
    """
    intersection, union = _box_inter_union(boxes1, boxes2)
    return intersection / union
|
|
|
|
|
def validate_prediction_and_gt(pred_dict, gt_dict):
    """Check that prediction and GT dicts share keys and all boxes are valid.

    Prints a diagnostic and returns False on the first problem found;
    returns True when keys match and every box passes is_valid_box().
    """
    pred_keys = set(pred_dict)
    gt_keys = set(gt_dict)

    if pred_keys != gt_keys:
        print("Key mismatch:")
        only_in_gt = gt_keys - pred_keys
        only_in_pred = pred_keys - gt_keys
        if only_in_gt:
            print(" - Missing in prediction:", only_in_gt)
        if only_in_pred:
            print(" - Missing in ground truth:", only_in_pred)
        return False

    for key in pred_keys:
        if not is_valid_box(pred_dict[key]):
            print(f"Invalid prediction box for key {key}: {pred_dict[key]}")
            return False
        if not is_valid_box(gt_dict[key]):
            print(f"Invalid ground truth box for key {key}: {gt_dict[key]}")
            return False

    return True
|
|
|
|
|
|
|
|
def compute_iou_batch(boxes1, boxes2):
    """Elementwise IoU between paired boxes (row i of boxes1 vs row i of boxes2).

    boxes1, boxes2: (N, 4) arrays where each row is [x1, y1, x2, y2]

    Returns an (N,) array of IoU values; a tiny epsilon guards against
    division by zero for degenerate (zero-area) unions.
    """
    inter_x1 = np.maximum(boxes1[:, 0], boxes2[:, 0])
    inter_y1 = np.maximum(boxes1[:, 1], boxes2[:, 1])
    inter_x2 = np.minimum(boxes1[:, 2], boxes2[:, 2])
    inter_y2 = np.minimum(boxes1[:, 3], boxes2[:, 3])

    # Non-overlapping pairs produce negative extents; clamp to zero.
    inter_area = np.clip(inter_x2 - inter_x1, 0, None) * np.clip(inter_y2 - inter_y1, 0, None)

    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    union_area = area1 + area2 - inter_area

    return inter_area / np.clip(union_area, 1e-6, None)
|
|
|
|
|
|
|
|
def group_records_by_dataset(data):
    """Group STG records by dataset for per-dataset evaluation."""
    grouped = defaultdict(list)

    for record in data.values():
        # Only keep spatial-temporal grounding records.
        if 'stg' not in record.get('qa_type', '').lower():
            continue

        metadata = record.get('metadata', {})
        # Dataset name may live under several different keys depending on
        # how the record was produced.
        dataset = record.get(
            'data_source',
            record.get('dataset', record.get('dataset_name', metadata.get('dataset', 'Unknown'))),
        )
        video_id = record.get('video_id', metadata.get('video_id', ''))

        if dataset == 'Unknown' and video_id:
            vid_lower = str(video_id).lower()
            # Heuristics to recover the dataset from the video-id format
            # (assumed id conventions — confirm against the source data).
            if len(video_id) == 11 and any(ch.isalpha() for ch in video_id):
                dataset = "AVOS"
            elif "_part" in vid_lower:
                dataset = "CoPESD"
            elif "video" in vid_lower:
                dataset = "CholecT50"

        grouped[dataset].append(record)

    return dict(grouped)
|
|
|
|
|
|
|
|
def evaluate_dataset_stg(dataset_name, records):
    """Evaluate STG for a specific dataset.

    For each record: parse the predicted boxes from 'answer', extract the
    ground-truth boxes from 'struc_info', align them per timestamp, and
    accumulate one mean IoU per record, bucketed by fps. Returns a dict of
    per-fps threshold accuracies plus an 'overall' summary.
    """
    print(f"\nEvaluating {dataset_name} ({len(records)} records)...")

    results_by_fps = defaultdict(list)
    all_ious = []

    for record in records:
        # fps may live at the top level or inside metadata; defaults to 1.0.
        fps = record.get('fps', record.get('metadata', {}).get('fps', 1.0))
        if isinstance(fps, str):
            fps = float(fps)

        raw_answer = record.get('answer', '')
        processed_pred = post_process_pred(raw_answer)

        # Ground truth lives in 'struc_info', either as a list (first item
        # used) or a dict, optionally nested under 'bbox_dict'.
        struc_info = record.get('struc_info', {})
        if isinstance(struc_info, list) and len(struc_info) > 0:
            struc_item = struc_info[0]
            if isinstance(struc_item, dict) and 'bbox_dict' in struc_item:
                gt_dict = struc_item['bbox_dict']
            else:
                gt_dict = struc_item
        elif isinstance(struc_info, dict):
            if 'bbox_dict' in struc_info:
                gt_dict = struc_info['bbox_dict']
            else:
                gt_dict = struc_info
        else:
            gt_dict = {}

        # The extract_boxes() fallback returns a list; align it to the GT
        # keys positionally (assumes GT key order matches box order — TODO
        # confirm this holds for the fallback path).
        if isinstance(processed_pred, list):
            key_list = list(gt_dict.keys())
            processed_pred = {key: box for key, box in zip(key_list[:len(processed_pred)], processed_pred)}

        pred_boxes = []
        gt_boxes = []

        for i, key in enumerate(gt_dict.keys()):
            gt_boxes.append(gt_dict[key])
            # NOTE(review): post_process_pred keys use str(float(k)) (e.g.
            # '8.25'), while this lookup formats to one decimal ('8.2') —
            # keys with >1 decimal place would never match. Confirm GT
            # timestamps are always one-decimal.
            key_str = f"{float(key):.1f}"
            pred_box = processed_pred.get(key_str, [0, 0, 0, 0]) if isinstance(processed_pred, dict) else [0, 0, 0, 0]
            # Missing timestamp: carry the previous prediction forward.
            if pred_box == [0, 0, 0, 0] and i > 0:
                pred_box = pred_boxes[i - 1]
            pred_boxes.append(pred_box)

        # Keep only pairs where both boxes are well-formed 4-element lists.
        valid_pred_boxes = []
        valid_gt_boxes = []
        for pred_box, gt_box in zip(pred_boxes, gt_boxes):
            if is_valid_box(pred_box) and is_valid_box(gt_box):
                valid_pred_boxes.append(pred_box)
                valid_gt_boxes.append(gt_box)

        if valid_pred_boxes and valid_gt_boxes:
            pred_boxes_array = np.array(valid_pred_boxes)
            gt_boxes_array = np.array(valid_gt_boxes)
            iou = compute_iou_batch(pred_boxes_array, gt_boxes_array)

            if len(iou) > 0:
                # One scalar per record: the mean IoU across its timestamps.
                mean_iou = iou.mean()
                results_by_fps[fps].append(mean_iou)
                all_ious.append(mean_iou)

    # Per-fps aggregation: accuracy at IoU thresholds 0.3/0.5/0.7 plus mIoU.
    aggregated = {}
    for fps, iou_list in results_by_fps.items():
        if iou_list:
            aggregated[f'fps_{fps}'] = {
                'iou@0.3': np.mean([1 if x >= 0.3 else 0 for x in iou_list]),
                'iou@0.5': np.mean([1 if x >= 0.5 else 0 for x in iou_list]),
                'iou@0.7': np.mean([1 if x >= 0.7 else 0 for x in iou_list]),
                'mIoU': np.mean(iou_list),
                'count': len(iou_list)
            }

    # Overall summary across all fps buckets.
    if all_ious:
        overall_mean_iou = np.mean(all_ious)
        aggregated['overall'] = {
            'mean_iou': overall_mean_iou,
            'valid_records': len(all_ious),
            'total_records': len(records)
        }

    return aggregated
|
|
|
|
|
|
|
|
def main():
    """Main evaluation function for STG."""
    import sys

    # The results JSON path is the single required CLI argument.
    if len(sys.argv) < 2:
        print("Usage: python eval_stg.py <results_json_file>")
        print("Example: python eval_stg.py results/model_results.json")
        sys.exit(1)

    output_file = sys.argv[1]
    print(f"Loading results from: {output_file}")

    with open(output_file, "r") as f:
        infer_output = json.load(f)

    dataset_records = group_records_by_dataset(infer_output)

    print(f"\nFound datasets: {list(dataset_records.keys())}")
    for dataset, records in dataset_records.items():
        print(f"  {dataset}: {len(records)} STG records")

    if not any(dataset_records.values()):
        print("No STG records found!")
        return

    # Evaluate every non-empty dataset group.
    all_results = {
        name: evaluate_dataset_stg(name, recs)
        for name, recs in dataset_records.items()
        if recs
    }

    print(f"\n{'='*80}")
    print("STG EVALUATION SUMMARY")
    print(f"{'='*80}")

    for dataset_name, fps_results in all_results.items():
        if not fps_results:
            continue
        print(f"\n{dataset_name}:")
        for fps_key, metrics in sorted(fps_results.items()):
            print(f"  {fps_key}:")
            for metric_name, value in metrics.items():
                if metric_name == 'count':
                    print(f"    samples: {value}")
                else:
                    print(f"    {metric_name}: {value:.4f}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # CLI entry point when executed as a script.
    main()
|
|
|