|
|
import torch
|
|
|
|
|
|
from .metrics import PixelAccuracy, MeanObservableIOU, MeanUnobservableIOU, ObservableIOU, UnobservableIOU, mAP
|
|
|
|
|
|
from .loss import EnhancedLoss
|
|
|
|
|
|
from .segmentation_head import SegmentationHead
|
|
|
|
|
|
from . import get_model
|
|
|
from .base import BaseModel
|
|
|
from .bev_projection import CartesianProjection, PolarProjectionDepth
|
|
|
from .schema import ModelConfiguration
|
|
|
|
|
|
class MapPerceptionNet(BaseModel):
    """Image-to-BEV segmentation network.

    The forward pass runs these steps in order:

    - encode the input with ``image_encoder``;
    - score scale bins for every feature vector with ``scale_classifier``;
    - lift the image features into a polar grid (``PolarProjectionDepth``);
    - resample that grid into a Cartesian bird's-eye view
      (``CartesianProjection``);
    - decode per-class logits with ``SegmentationHead``.

    A sigmoid is applied element-wise, so the class probabilities are
    independent of each other rather than softmax-normalised.
    """

    def _init(self, conf: ModelConfiguration):
        """Build every submodule from ``conf``.

        NOTE(review): presumably invoked by ``BaseModel``'s constructor —
        confirm against ``.base``.

        Args:
            conf: model configuration. The fields read here are
                ``image_encoder.name``, ``image_encoder.backbone``,
                ``latent_dim``, ``num_classes``, ``pixel_per_meter``,
                ``z_max``, ``z_min``, ``x_max``, ``scale_range``,
                ``num_scale_bins`` and ``loss``.
        """
        # Keep the construction order fixed. Reordering changes module
        # registration order (and thus state_dict key order). It also changes
        # the order in which random parameter initialisation consumes the RNG.
        encoder_cls = get_model(conf.image_encoder.name)
        self.image_encoder = encoder_cls(conf.image_encoder.backbone)

        self.decoder = SegmentationHead(
            in_channels=conf.latent_dim,
            n_classes=conf.num_classes,
        )

        px_per_meter = conf.pixel_per_meter
        self.projection_polar = PolarProjectionDepth(
            conf.z_max, px_per_meter, conf.scale_range, conf.z_min
        )
        self.projection_bev = CartesianProjection(
            conf.z_max, conf.x_max, px_per_meter, conf.z_min
        )

        # Maps each latent_dim-sized feature vector to num_scale_bins scores.
        self.scale_classifier = torch.nn.Linear(
            conf.latent_dim, conf.num_scale_bins
        )

        self.num_classes = conf.num_classes
        self.loss_fn = EnhancedLoss(conf.loss)

    def _forward(self, data) -> dict:
        """Run the image-to-BEV pipeline on one batch.

        Args:
            data: batch handed unchanged to ``image_encoder``, which must
                return a ``(image_features, camera)`` pair.

        Returns:
            dict with these keys:

            - ``"output"``: sigmoid of the logits.
            - ``"logits"``: raw decoder output.
            - ``"scales"``: scale-bin scores.
            - ``"features_image"``: encoder features.
            - ``"features_bev"``: the full BEV features, i.e. before the
              slice that is fed to the decoder.
            - ``"valid_bev"``: the BEV validity output with ``squeeze(1)``
              applied.
        """
        image_feats, camera = self.image_encoder(data)

        # nn.Linear acts on the last axis, so axis 1 is moved there first.
        # Assumes image_feats is channels-first (B, C, ...) — TODO confirm.
        scale_scores = self.scale_classifier(image_feats.moveaxis(1, -1))
        polar_feats = self.projection_polar(image_feats, scale_scores, camera)

        # The BEV resampling is fed explicit float32 inputs (.float() casts).
        bev_feats, bev_valid, _ = self.projection_bev(
            polar_feats.float(), None, camera.float()
        )

        # NOTE(review): the last element along the final axis is dropped
        # before decoding; the reason is not visible here — confirm.
        logits = self.decoder(bev_feats[..., :-1])

        return {
            "output": torch.sigmoid(logits),
            "logits": logits,
            "scales": scale_scores,
            "features_image": image_feats,
            "features_bev": bev_feats,
            "valid_bev": bev_valid.squeeze(1),
        }

    def loss(self, pred, data):
        """Delegate to the configured ``EnhancedLoss``.

        Args:
            pred: model predictions (presumably the dict from ``_forward``).
            data: the batch the predictions were computed from.

        Returns:
            Whatever ``self.loss_fn(pred, data)`` returns.
        """
        return self.loss_fn(pred, data)

    def metrics(self) -> dict:
        """Create a fresh set of evaluation metrics.

        Returns:
            dict mapping a metric name to a new metric instance.

            - Aggregate metrics come first: pixel accuracy, mAP, and mean
              observable / non-observable IoU.
            - Then one observable-IoU entry per class.
            - Then one non-observable-IoU entry per class.
        """
        n_cls = self.num_classes
        # Dict displays evaluate left to right, so the construction order and
        # key order are: aggregate metrics, then per-class observable IoU,
        # then per-class non-observable IoU.
        return {
            "pix_acc": PixelAccuracy(),
            "map": mAP(n_cls),
            "miou_observable": MeanObservableIOU(n_cls),
            "miou_non_observable": MeanUnobservableIOU(n_cls),
            **{
                f"IoU_observable_class_{cls_idx}": ObservableIOU(
                    cls_idx, num_classes=n_cls
                )
                for cls_idx in range(n_cls)
            },
            **{
                f"IoU_non_observable_{cls_idx}": UnobservableIOU(
                    cls_idx, num_classes=n_cls
                )
                for cls_idx in range(n_cls)
            },
        }
|
|
|
|