# detection-heads/heads/__init__.py
# Last commit d5bbf99 (verified) by phanerozoic:
#   "dim20: 3.9 mAP from 22K params. 768->20->80 bottleneck,
#    SVD-initialized from pruned prototypes."
"""Detection head registry. Maps names to classes."""
from .baseline_fcos.head import BaselineFCOS
from .slim_fcos.head import SlimFCOS
from .centernet.head import CenterNet
from .hook_fcos.head import HookFCOS
from .sparse_query.head import SparseQueries
from .prototype.head import PrototypeContrastive
from .depth_fusion.head import DepthFusion
from .wavelet.head import Wavelet
from .adaptive_query.head import AdaptiveQuery
from .patch_assembly.head import PatchAssembly
from .mutual_attention.head import MutualAttention
from .feature_graph.head import FeatureGraph
from .scale_classify.head import ScaleClassify
from .relational_corners.head import RelationalCorners
from .cascade_pool.head import CascadePool
from .cofiber_linear.head import CofiberLinear
from .threshold_prototype.head import ThresholdPrototype
from .cofiber_threshold.head import CofiberThreshold
from .cofiber_threshold.box32_92k.head import CofiberThresholdV2
from .cofiber_centernet.head import CofiberCenterNet
from .cofiber_5scale.head import Cofiber5Scale
from .cofiber_adaptive.head import CofiberAdaptive
from .optimal_transport.head import OptimalTransport
from .tropical.head import TropicalDetection
from .compression.head import CompressionDetection
from .curvature.head import CurvatureDetection
from .cofiber_threshold.dim20_20k.head import CofiberThresholdDim20
# Name -> head class. These keys are the names accepted by get_head().
# Dict insertion order is preserved, so ALL_NAMES follows the order below.
REGISTRY = {
    "baseline_fcos": BaselineFCOS,
    "slim_fcos": SlimFCOS,
    "centernet": CenterNet,
    "hook_fcos": HookFCOS,
    "sparse_query": SparseQueries,
    "prototype": PrototypeContrastive,
    "depth_fusion": DepthFusion,
    "wavelet": Wavelet,
    "adaptive_query": AdaptiveQuery,
    "patch_assembly": PatchAssembly,
    "mutual_attention": MutualAttention,
    "feature_graph": FeatureGraph,
    "scale_classify": ScaleClassify,
    "relational_corners": RelationalCorners,
    "cascade_pool": CascadePool,
    "cofiber_linear": CofiberLinear,
    "threshold_prototype": ThresholdPrototype,
    "cofiber_threshold": CofiberThreshold,
    "cofiber_centernet": CofiberCenterNet,
    "cofiber_5scale": Cofiber5Scale,
    "cofiber_adaptive": CofiberAdaptive,
    "optimal_transport": OptimalTransport,
    "tropical": TropicalDetection,
    "compression": CompressionDetection,
    "curvature": CurvatureDetection,
    # Variants of the cofiber_threshold head, imported from its
    # box32_92k/ and dim20_20k/ subpackages.
    "cofiber_threshold_box32": CofiberThresholdV2,
    "cofiber_threshold_dim20": CofiberThresholdDim20,
}

# Every registered head name, captured once at import time in REGISTRY order.
ALL_NAMES = list(REGISTRY)
def get_head(name, *args, **kwargs):
    """Instantiate the detection head registered under ``name``.

    Args:
        name: A key of ``REGISTRY``, e.g. ``"cofiber_threshold_dim20"``.
        *args: Positional arguments forwarded unchanged to the head's
            constructor. If none are given, the class is constructed with no
            arguments, exactly as in earlier versions of this function.
        **kwargs: Keyword arguments forwarded unchanged to the constructor.

    Returns:
        A new instance of ``REGISTRY[name]``.

    Raises:
        ValueError: If ``name`` is not a registered head. The message lists
            the currently registered names.
    """
    # Only the lookup is inside the try, so a KeyError raised by a head's own
    # constructor is not misreported as an unknown name.
    try:
        head_cls = REGISTRY[name]
    except KeyError:
        # List the live registry keys so heads added after import are shown.
        # `from None` hides the internal KeyError from the traceback.
        raise ValueError(
            f"Unknown head: {name}. Available: {list(REGISTRY)}"
        ) from None
    return head_cls(*args, **kwargs)