| import torch | |
| from models.reveal_head import compute_task_metrics_from_fields | |
def _metrics(access_level: float, disturbance_level: float) -> dict[str, torch.Tensor]:
    """Run ``compute_task_metrics_from_fields`` on spatially uniform inputs.

    Every field is a tensor of shape ``(1, C, 4, 4)`` filled with a single
    constant (presumably laid out as batch, channels, height, width — TODO
    confirm against ``compute_task_metrics_from_fields``).

    Only two inputs vary between calls:

    * ``access_level`` fills both the 3-channel access field and the
      2-channel clearance field.
    * ``disturbance_level`` fills the 1-channel disturbance field.

    The remaining fields are pinned at fixed values (persistence 0.6,
    reocclusion 0.1, visibility 0.6, support stability 0.7, uncertainty 0.2).

    Args:
        access_level: Fill value for the access and clearance fields.
        disturbance_level: Fill value for the disturbance field.

    Returns:
        Whatever ``compute_task_metrics_from_fields`` returns; annotated here
        as a mapping from metric name to tensor.
    """

    def uniform(channels: int, value: float) -> torch.Tensor:
        # Constant tensor with `channels` channels on a 1x4x4 grid.
        return torch.full((1, channels, 4, 4), value)

    fields = {
        "access_field": uniform(3, access_level),
        "persistence_field": uniform(1, 0.6),
        "disturbance_field": uniform(1, disturbance_level),
        "reocclusion_field": uniform(1, 0.1),
        "visibility_field": uniform(1, 0.6),
        # Clearance tracks access so that "more open" raises both together.
        "clearance_field": uniform(2, access_level),
        "support_stability_field": uniform(1, 0.7),
        "uncertainty_field": uniform(1, 0.2),
    }
    return compute_task_metrics_from_fields(**fields)
def test_task_metric_monotonicity():
    """Task metrics must move in the expected direction as inputs change.

    With disturbance held at 0.2, raising ``access_level`` from 0.1 to 0.9
    must increase both ``mouth_aperture`` and ``actor_feasibility_score``.
    With access held at 0.9, raising ``disturbance_level`` from 0.2 to 0.8
    must decrease ``fold_preservation``.
    """
    baseline_disturbance = 0.2
    narrow = _metrics(access_level=0.1, disturbance_level=baseline_disturbance)
    wide = _metrics(access_level=0.9, disturbance_level=baseline_disturbance)
    disturbed = _metrics(access_level=0.9, disturbance_level=0.8)

    # Opening up access should improve both access-driven metrics.
    for metric_name in ("mouth_aperture", "actor_feasibility_score"):
        assert narrow[metric_name] < wide[metric_name]

    # Excess disturbance at the same access level should hurt fold preservation.
    assert disturbed["fold_preservation"] < wide["fold_preservation"]