| """Perspective fields decoder heads. |
| |
| Adapted from https://github.com/jinlinyi/PerspectiveFields |
| """ |
|
|
| import logging |
|
|
| from siclib.models import get_model |
| from siclib.models.base_model import BaseModel |
|
|
# Module-level logger, named after this module so it inherits the package's logging config.
logger = logging.getLogger(name=__name__)
|
|
| |
| |
|
|
|
|
class PerspectiveDecoder(BaseModel):
    """Wrap an optional up-vector head and an optional latitude head.

    Each head is instantiated through ``get_model`` from its own sub-config.
    Setting ``up_decoder`` or ``latitude_decoder`` to ``None`` disables that
    head. The forward pass merges the outputs of the enabled heads. ``loss``
    merges their losses/metrics and sums their per-head totals into
    ``"perspective_total"``.
    """

    default_conf = {
        "up_decoder": {"name": "decoders.up_decoder"},
        "latitude_decoder": {"name": "decoders.latitude_decoder"},
    }

    required_data_keys = ["features"]

    def _init(self, conf):
        """Build the enabled sub-heads from ``conf``.

        Args:
            conf: Decoder config with ``up_decoder`` and ``latitude_decoder``
                entries. Each is either ``None`` (head disabled) or a sub-config
                whose ``name`` is looked up with ``get_model``; the resulting
                constructor is called with that same sub-config.
        """
        logger.debug(f"Initializing PerspectiveDecoder with config: {conf}")
        up_conf = conf.up_decoder
        latitude_conf = conf.latitude_decoder

        # A head is active exactly when its sub-config was not set to None.
        self.use_up = up_conf is not None
        self.use_latitude = latitude_conf is not None

        # Heads are created (and hence registered) in the order up -> latitude.
        if self.use_up:
            self.up_head = get_model(up_conf.name)(up_conf)
        if self.use_latitude:
            self.latitude_head = get_model(latitude_conf.name)(latitude_conf)

    def _forward(self, data):
        """Run every enabled head on ``data`` and merge their outputs.

        A disabled head contributes an empty dict. The latitude outputs are
        merged last, so on a key collision the latitude head's value wins.
        """
        up_pred = {}
        latitude_pred = {}
        if self.use_up:
            up_pred = self.up_head(data)
        if self.use_latitude:
            latitude_pred = self.latitude_head(data)
        return up_pred | latitude_pred

    def loss(self, pred, data):
        """Collect the losses and metrics of every enabled head.

        Args:
            pred: Predictions, passed unchanged to each head's ``loss``.
            data: Input dict. It must contain ``"up_field"`` when the up head
                is enabled, otherwise ``"latitude_field"`` (even if both heads
                are disabled). That entry only serves as the template for the
                zero-initialised total.

        Returns:
            ``(losses, metrics)``: the merged per-head dicts (later heads
            overwrite shared keys). ``losses["perspective_total"]`` holds the
            zero vector plus the ``"up_total"`` / ``"latitude_total"`` entries
            of the enabled heads, with a missing entry counting as 0.
        """
        # NOTE(review): the reference is presumably a torch tensor, so
        # ``new_zeros`` matches its dtype/device; dim 0 looks like the batch
        # dimension — confirm against the data pipeline.
        ref_key = "up_field" if self.use_up else "latitude_field"
        reference = data[ref_key]
        total = reference.new_zeros(reference.shape[0])

        # (head, name of its total-loss entry), processed in the order up -> latitude.
        active_heads = []
        if self.use_up:
            active_heads.append((self.up_head, "up_total"))
        if self.use_latitude:
            active_heads.append((self.latitude_head, "latitude_total"))

        losses = {}
        metrics = {}
        for head, total_key in active_heads:
            head_losses, head_metrics = head.loss(pred, data)
            losses |= head_losses
            metrics |= head_metrics
            # The total is read from the merged dict; an absent key adds nothing.
            total = total + losses.get(total_key, 0)

        losses["perspective_total"] = total
        return losses, metrics
|
|