| # task/segmentation_task.py | |
| from model.lightning_unet import LightningUNet | |
| import torch | |
class SegmentationTask:
    """Wrap a ``LightningUNet`` loaded from saved weights for gradient-free prediction."""

    def __init__(self, model_path: str, *, map_location="cpu"):
        """Construct the model, load its weights from *model_path*, and set eval mode.

        Args:
            model_path: Path (or file-like object accepted by ``torch.load``) of
                the saved weights. May hold a bare ``state_dict`` or a checkpoint
                dict that stores it under a ``"state_dict"`` key.
            map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"`` so a
                checkpoint saved on a GPU can be read on a CPU-only machine.
                ``load_state_dict`` copies the loaded values into the model's
                existing parameters, so this does not change the model's device.

        Raises:
            RuntimeError: from ``load_state_dict`` when the loaded keys or
                shapes do not match the model.
        """
        self.model = LightningUNet()
        self.model.load_state_dict(self._load_state_dict(model_path, map_location))
        # Switch layers such as dropout/batch-norm to inference behaviour.
        self.model.eval()

    @staticmethod
    def _load_state_dict(model_path, map_location="cpu"):
        """Read weights from *model_path*, unwrapping a ``"state_dict"`` entry if present."""
        # NOTE(security): torch.load may unpickle the file (depending on the
        # installed torch version's weights_only default), which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        loaded = torch.load(model_path, map_location=map_location)
        # Training checkpoints (e.g. Lightning .ckpt files) store the weights
        # next to other metadata; a bare state_dict is returned unchanged.
        if isinstance(loaded, dict) and "state_dict" in loaded:
            return loaded["state_dict"]
        return loaded

    def predict(self, image):
        """Run a forward pass of the model on *image* with gradient tracking disabled.

        Args:
            image: Model input, passed straight to ``self.model``. Presumably a
                batched image tensor (N, C, H, W) on the model's device; TODO
                confirm against callers.

        Returns:
            The raw model output. No post-processing is applied.
        """
        with torch.no_grad():
            output = self.model(image)
        # Post-processing can be added here if needed.
        return output