| |
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| from typing import Dict, List, Optional |
|
|
| import torch |
| from torch import nn |
|
|
| from scripts.matcher import HungarianMatcher |
| from scripts.set_criterion_countgd import SetCriterion |
|
|
|
|
@dataclass
class CriterionDetectConfig:
    """Plain container for detection-criterion hyper-parameters.

    Field names and default values match the keyword arguments of
    ``build_criterion_detect`` in this module (everything except
    ``tokenizer``). The visible code in this file does not itself consume
    this dataclass; presumably callers unpack it into the builder — verify
    at the call sites.
    """

    # Forwarded to SetCriterion as ``num_classes``.
    num_classes: int = 1

    # Hungarian-matcher costs (cost_class / cost_bbox / cost_giou).
    class_cost: float = 1.0
    bbox_cost: float = 5.0
    giou_cost: float = 2.0

    # Loss weights, i.e. the "loss_ce" / "loss_bbox" / "loss_giou" entries
    # of the criterion's weight_dict.
    lambda_cls: float = 1.0
    lambda_bbox: float = 5.0
    lambda_giou: float = 2.0

    # Extra options passed straight through to SetCriterion.
    eos_coef: float = 0.1
    focal_alpha: float = 0.25
|
|
|
|
def build_criterion_detect(
    tokenizer,
    num_classes: int = 1,
    class_cost: float = 1.0,
    bbox_cost: float = 5.0,
    giou_cost: float = 2.0,
    lambda_cls: float = 1.0,
    lambda_bbox: float = 5.0,
    lambda_giou: float = 2.0,
    eos_coef: float = 0.1,
    focal_alpha: float = 0.25,
) -> SetCriterion:
    """Build the CountGD-style detection criterion.

    A ``HungarianMatcher`` is configured with a class/token cost, an L1 box
    cost and a GIoU cost. It is wrapped in a ``SetCriterion`` that computes
    the "labels", "boxes" and "cardinality" losses. The original author notes
    that SetCriterion maps labels to text tokens via
    ``create_positive_map_exemplar``; that helper is not visible here.

    Args:
        tokenizer: Forwarded unchanged to ``SetCriterion``. Presumably used
            for the label/token mapping above; confirm in
            ``scripts.set_criterion_countgd``.
        num_classes: Forwarded to ``SetCriterion``.
        class_cost: Matcher ``cost_class``.
        bbox_cost: Matcher ``cost_bbox``.
        giou_cost: Matcher ``cost_giou``.
        lambda_cls: Weight for ``"loss_ce"``, coerced with ``float()``.
        lambda_bbox: Weight for ``"loss_bbox"``, coerced with ``float()``.
        lambda_giou: Weight for ``"loss_giou"``, coerced with ``float()``.
        eos_coef: Forwarded to ``SetCriterion``.
        focal_alpha: Forwarded to ``SetCriterion`` only (see note below).

    Returns:
        The configured ``SetCriterion`` instance.
    """
    # Assignment between predictions and targets is driven by these costs.
    # NOTE(review): focal_alpha is not passed to the matcher. If
    # HungarianMatcher computes a focal-style class cost with its own alpha,
    # matching and loss may disagree when focal_alpha != its default.
    # Confirm against scripts.matcher.
    hungarian = HungarianMatcher(
        cost_class=class_cost,
        cost_bbox=bbox_cost,
        cost_giou=giou_cost,
    )

    # Only the three weighted terms get an entry; "cardinality" has no weight.
    loss_weights = dict(
        loss_ce=float(lambda_cls),
        loss_bbox=float(lambda_bbox),
        loss_giou=float(lambda_giou),
    )

    return SetCriterion(
        num_classes=num_classes,
        matcher=hungarian,
        weight_dict=loss_weights,
        eos_coef=eos_coef,
        losses=["labels", "boxes", "cardinality"],
        focal_alpha=focal_alpha,
        tokenizer=tokenizer,
    )
|
|