File size: 468 Bytes
d218927
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# task/segmentation_task.py
from model.lightning_unet import LightningUNet
import torch

class SegmentationTask:
    """Inference wrapper around a trained ``LightningUNet``.

    The network is built and its weights are loaded once, at construction
    time. The model is then switched to evaluation mode, so ``predict`` can
    be called repeatedly.
    """

    def __init__(self, model_path: str):
        """Build the network and load its weights from ``model_path``.

        Args:
            model_path: Path to a file written with ``torch.save``. It may be
                either a plain ``state_dict`` mapping or a checkpoint dict
                that stores the weights under a ``"state_dict"`` key (the
                layout of PyTorch Lightning ``.ckpt`` files).
        """
        self.model = LightningUNet()
        # map_location="cpu" lets GPU-saved weights be deserialized on a
        # CPU-only host. load_state_dict copies the values into the freshly
        # built parameters either way, so the model's device is unaffected.
        # NOTE(review): torch.load unpickles the file -- only load trusted paths.
        checkpoint = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(self._extract_state_dict(checkpoint))
        # Evaluation mode changes the behaviour of layers such as dropout and
        # batch norm (if the network contains any) for inference.
        self.model.eval()

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the parameter mapping held by a loaded checkpoint object.

        A dict with a ``"state_dict"`` entry is treated as a full checkpoint
        and unwrapped. Anything else is returned unchanged and assumed to
        already be a state dict. A genuine module state dict cannot contain a
        top-level ``"state_dict"`` key, because that name is taken by the
        ``nn.Module.state_dict`` method.
        """
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            return checkpoint["state_dict"]
        return checkpoint

    def predict(self, image):
        """Run a forward pass on ``image`` without recording gradients.

        Args:
            image: Input passed unchanged to the model. It presumably is a
                batched image tensor on the model's device -- TODO confirm
                the expected shape against ``LightningUNet``.

        Returns:
            The raw model output; no post-processing is applied.
        """
        with torch.no_grad():
            output = self.model(image)
        # TODO: add post-processing here if needed.
        return output