Spaces:
Running
Running
| import torch | |
| from . import initialization as init | |
| from .hub_mixin import SMPHubMixin | |
| import torch.nn as nn | |
class SegmentationModel(torch.nn.Module, SMPHubMixin):
    """Base class wiring an encoder, a decoder and a segmentation head together.

    This class does not create any sub-modules itself. Its methods read the
    following attributes, which must be set elsewhere (e.g. by a subclass
    constructor) before they are called:

    * ``encoder`` -- called as ``encoder(x)``; its result is unpacked into the
      decoder, and its ``output_stride`` attribute is used for shape checks.
    * ``decoder`` -- called as ``decoder(*features)``.
    * ``segmentation_head`` -- applied to the decoder output.
    * ``classification_head`` -- may be ``None``; only touched by
      :meth:`initialize`.
    """

    def initialize(self) -> None:
        """Initialize the weights of the decoder and of the head(s).

        The classification head is initialized only when it is not ``None``.
        """
        init.initialize_decoder(self.decoder)
        init.initialize_head(self.segmentation_head)
        if self.classification_head is not None:
            init.initialize_head(self.classification_head)

    def check_input_shape(self, x) -> None:
        """Validate that the spatial size of ``x`` is compatible with the encoder.

        Args:
            x: tensor whose last two dimensions are read as (height, width).

        Raises:
            RuntimeError: if height or width is not divisible by
                ``self.encoder.output_stride``. The message suggests the
                nearest larger shape that satisfies the constraint.
        """
        h, w = x.shape[-2:]
        output_stride = self.encoder.output_stride
        if h % output_stride == 0 and w % output_stride == 0:
            return

        # Ceiling division: round each side up to the next multiple of the
        # stride (a side that is already a multiple is left unchanged).
        new_h = -(-h // output_stride) * output_stride
        new_w = -(-w // output_stride) * output_stride
        raise RuntimeError(
            f"Wrong input shape height={h}, width={w}. Expected image height and width "
            f"divisible by {output_stride}. Consider pad your images to shape ({new_h}, {new_w})."
        )

    def forward(self, x):
        """Sequentially pass ``x`` through the model's encoder, decoder and segmentation head.

        Args:
            x: input tensor; its last two dimensions must be divisible by the
                encoder's output stride (see :meth:`check_input_shape`).

        Returns:
            The output of ``segmentation_head`` applied to the decoder output.

        Raises:
            RuntimeError: propagated from :meth:`check_input_shape`.
        """
        self.check_input_shape(x)

        features = self.encoder(x)
        decoder_output = self.decoder(*features)

        # NOTE: the classification head is not applied here; only the
        # segmentation output is returned.
        return self.segmentation_head(decoder_output)

    def predict(self, x):
        """Inference method. Switch model to `eval` mode, call `.forward(x)` with `torch.no_grad()`

        The model is left in ``eval`` mode afterwards; the previous training
        mode is not restored.

        Args:
            x: 4D torch tensor with shape (batch_size, channels, height, width)

        Return:
            prediction: 4D torch tensor with shape (batch_size, classes, height, width)
        """
        if self.training:
            self.eval()

        # Disable autograd tracking as documented: without it, inference
        # builds a graph and the returned tensor keeps it alive.
        with torch.no_grad():
            return self.forward(x)