|
|
from dataclasses import dataclass, field
|
|
|
from typing import Any, Dict, Optional
|
|
|
|
|
|
@dataclass
class LossConfiguration:
    """Settings bundle for the training loss.

    Only ``num_classes`` has to be supplied; every other field carries a
    default.  The per-field notes below are inferred from the field names
    and should be confirmed against the code that consumes this object.
    """

    # Required: number of classes.
    num_classes: int

    # Weights of the individual loss terms (presumably cross-entropy and
    # Dice -- TODO confirm against the loss implementation).
    xent_weight: float = 1.0
    dice_weight: float = 1.0

    # Focal-loss switch and its gamma value.
    focal_loss: bool = False
    focal_loss_gamma: float = 2.0

    # Flags for auxiliary inputs.  NOTE: the spelling "frustrum" is part of
    # the public field name, so it must not be "corrected" here.
    requires_frustrum: bool = True
    requires_flood_mask: bool = False

    # Optional class weights; the annotation places no constraint on the
    # container type.
    class_weights: Optional[Any] = None

    # Label-smoothing amount.
    label_smoothing: float = 0.1
|
|
|
|
|
|
@dataclass
class BackboneConfigurationBase:
    """Fields shared by every backbone configuration.

    All three fields are required here; subclasses may redeclare them with
    defaults.
    """

    # Whether pretrained weights are requested.
    pretrained: bool

    # Whether the backbone is frozen.
    frozen: bool

    # Output dimension of the backbone.  Annotated as ``int`` (it was
    # mistakenly ``bool``) so that it agrees with subclasses that redeclare
    # it with an integer default.
    output_dim: int
|
|
|
|
|
|
@dataclass
class DINOConfiguration(BackboneConfigurationBase):
    """Backbone configuration for the DINO variant.

    Redeclares each inherited field with a default value, so an instance
    can be created without passing any argument.
    """

    # Defaults for the fields inherited from BackboneConfigurationBase.
    pretrained: bool = True
    frozen: bool = False
    output_dim: int = 128
|
|
|
|
|
|
@dataclass
class ResNetConfiguration(BackboneConfigurationBase):
    """Backbone configuration for the ResNet variant.

    Adds the fields below to those inherited from
    ``BackboneConfigurationBase``.  None of them has a default, so all must
    be supplied.  Field meanings are inferred from their names -- confirm
    against the encoder implementation.
    """

    # Input dimension.
    input_dim: int

    # Encoder identifier.
    encoder: str

    # Toggle named after removing the stride of the first convolution.
    remove_stride_from_first_conv: bool

    # Number of downsampling steps; ``None`` is permitted by the annotation.
    num_downsample: Optional[int]

    # Normalisation used by the decoder, given as a string identifier.
    decoder_norm: str

    # Average-pooling toggle.
    do_average_pooling: bool

    # Presumably enables activation checkpointing -- TODO confirm.
    checkpointed: bool
|
|
|
|
|
|
@dataclass
class ImageEncoderConfiguration:
    """Pairs an image-encoder name with the configuration of its backbone."""

    # Encoder identifier.
    name: str

    # Backbone settings.  Typed as ``Any``, so no particular configuration
    # class is enforced by this annotation.
    backbone: Any
|
|
|
|
|
|
@dataclass
class ModelConfiguration:
    """Top-level model configuration.

    Groups the sub-configurations (segmentation head, image encoder, loss)
    with the scalar settings of the model.  Every field up to and including
    ``loss`` is required; ``scale_range`` and ``z_min`` have defaults.
    Units and exact semantics of the scalar fields are not visible here --
    confirm against the model code before relying on them.
    """

    # --- sub-module configurations -------------------------------------
    # Free-form mapping of segmentation-head options.
    segmentation_head: Dict[str, Any]
    image_encoder: ImageEncoderConfiguration

    # --- required scalar settings --------------------------------------
    name: str
    num_classes: int
    latent_dim: int
    # Upper bounds along z and x (presumably the extent of the output
    # grid -- TODO confirm).
    z_max: int
    x_max: int

    pixel_per_meter: int
    num_scale_bins: int

    # --- loss ------------------------------------------------------------
    loss: LossConfiguration

    # --- optional settings -----------------------------------------------
    # A factory is used so that each instance receives its own list.
    scale_range: list[int] = field(default_factory=lambda: [0, 9])
    # Lower bound along z; left as ``None`` when not provided.
    z_min: Optional[int] = None