| from transformers import UperNetForSemanticSegmentation | |
| import torch | |
def load_model(
    pretrained_model: str,
    num_classes: int,
    device: "torch.device | str",
) -> torch.nn.Module:
    """Load a pretrained UperNet model with a custom number of output classes.

    Args:
        pretrained_model: Hub model id or local path passed to
            ``UperNetForSemanticSegmentation.from_pretrained``.
        num_classes: Number of segmentation classes for the new head.
            Must be at least 1.
        device: Target device, either a ``torch.device`` or a device string
            such as ``"cpu"`` or ``"cuda:0"`` (``nn.Module.to`` accepts both).

    Returns:
        The loaded model, moved to ``device``.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    # Fail fast: a non-positive class count would otherwise build a
    # degenerate segmentation head whose error surfaces much later.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    model = UperNetForSemanticSegmentation.from_pretrained(
        pretrained_model,
        num_labels=num_classes,
        # The checkpoint's classification head may have been trained for a
        # different number of labels; this lets loading proceed and
        # re-initializes the weights whose shapes do not match.
        ignore_mismatched_sizes=True,
    )
    return model.to(device)