import torch
from typing import Optional
from rstor.properties import LOSS_MSE


def compute_loss(
    predic: torch.Tensor,
    target: torch.Tensor,
    mode: Optional[str] = LOSS_MSE
) -> torch.Tensor:
    """
    Compute the loss between predicted and target values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        mode (Optional[str], optional): loss to compute. Only LOSS_MSE is
            currently supported. Defaults to LOSS_MSE.

    Returns:
        torch.Tensor: The computed loss as a scalar tensor (for LOSS_MSE,
            the squared error averaged over all N * C * H * W elements).

    Raises:
        ValueError: If mode is not supported.
    """
    if mode == LOSS_MSE:
        # Mean squared error, averaged over every element (reduction="mean").
        loss = torch.nn.functional.mse_loss(predic, target)
    else:
        # Raise explicitly rather than assert: asserts are stripped under
        # `python -O`, which would leave `loss` unbound.
        raise ValueError(f"Mode {mode} not supported")
    return loss
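# Usage sketch (illustrative, not part of the original module). It assumes
# only torch and calls torch.nn.functional.mse_loss directly, so it runs
# without rstor. It shows that, with the default reduction="mean", the MSE
# used by compute_loss averages the squared error over all N * C * H * W
# elements. `_mse_mean_example` is a hypothetical helper name.
import torch  # repeated so this sketch is self-contained


def _mse_mean_example() -> torch.Tensor:
    # [N, C, H, W] = [2, 1, 2, 2]: eight elements in total.
    predic = torch.zeros(2, 1, 2, 2)
    target = torch.arange(8, dtype=torch.float32).reshape(2, 1, 2, 2)
    # Squared errors are 0, 1, 4, ..., 49; they sum to 140, so the mean is 140 / 8.
    loss = torch.nn.functional.mse_loss(predic, target)
    # Same value as the explicit element-wise mean of squared differences.
    assert torch.allclose(loss, ((predic - target) ** 2).mean())
    return loss


if __name__ == "__main__":
    print(_mse_mean_example().item())  # 17.5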