# task/segmentation_task.py
from model.lightning_unet import LightningUNet
import torch


class SegmentationTask:
    """Load a trained ``LightningUNet`` from disk and run inference with it."""

    def __init__(self, model_path: str, map_location="cpu"):
        """Build a ``LightningUNet`` and load its weights from *model_path*.

        Args:
            model_path: Path to a file readable by ``torch.load``. It may hold
                either a plain ``state_dict`` or a dict that nests one under
                the ``"state_dict"`` key (the layout of Lightning ``.ckpt``
                files — NOTE(review): confirm against how checkpoints are
                actually saved in this project).
            map_location: Forwarded to ``torch.load``. Defaults to ``"cpu"``
                so that checkpoints saved from a GPU can be loaded on CPU-only
                machines. ``load_state_dict`` copies the loaded values into
                the model's own parameters, so they end up wherever the model
                lives regardless of this setting.

        Note:
            ``torch.load`` unpickles the file, which can execute arbitrary
            code. Only load checkpoints from trusted sources.
        """
        self.model = LightningUNet()
        checkpoint = torch.load(model_path, map_location=map_location)
        # Lightning checkpoints wrap the weights together with training
        # metadata. A plain state_dict can never contain a top-level
        # "state_dict" key (nn.Module refuses to register a parameter or
        # buffer with that name), so this unwrapping is unambiguous.
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]
        self.model.load_state_dict(checkpoint)
        # Put layers such as dropout / batch-norm into inference behaviour.
        self.model.eval()

    def predict(self, image):
        """Run a forward pass on *image* without tracking gradients.

        Args:
            image: Input passed unchanged to the model. Presumably a batched
                image tensor on the same device as the model — TODO confirm
                the expected shape against ``LightningUNet.forward``.

        Returns:
            The raw model output; no post-processing is applied.
        """
        with torch.no_grad():
            output = self.model(image)
        # Post-processing can be added here if needed.
        return output