File size: 468 Bytes
d218927 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
# task/segmentation_task.py
from typing import Callable, Optional

import torch

from model.lightning_unet import LightningUNet
class SegmentationTask:
    """Inference wrapper around a segmentation network restored from disk.

    The network is built, its weights are loaded from ``model_path`` and it is
    switched to evaluation mode once, at construction time.
    """

    def __init__(
        self,
        model_path: str,
        *,
        model_factory: Optional[Callable[[], torch.nn.Module]] = None,
    ):
        """Build the model and load its weights.

        Args:
            model_path: Path to a file written by ``torch.save``. Either a bare
                ``state_dict`` or a checkpoint dict that stores the weights
                under a ``"state_dict"`` key (the layout PyTorch Lightning uses
                for ``.ckpt`` files) is accepted.
            model_factory: Zero-argument callable returning the module to load
                the weights into. Defaults to ``LightningUNet``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
            RuntimeError: If the stored keys/shapes do not match the model
                (raised by ``load_state_dict``).
        """
        factory = LightningUNet if model_factory is None else model_factory
        self.model = factory()
        # Deserialize onto the CPU: a freshly constructed module is on the CPU
        # by default and load_state_dict copies values into its existing
        # parameters, so this only avoids failures on hosts that lack the GPU
        # the checkpoint was saved from.
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        checkpoint = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(self._extract_state_dict(checkpoint))
        self.model.eval()

    @staticmethod
    def _extract_state_dict(checkpoint):
        """Return the weight mapping from a loaded checkpoint object.

        Lightning checkpoints nest the weights under ``"state_dict"``; a bare
        state_dict is returned unchanged.
        """
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            return checkpoint["state_dict"]
        return checkpoint

    def predict(self, image):
        """Run a forward pass without tracking gradients.

        Args:
            image: Model input, passed unchanged to the model.
                NOTE(review): presumably an (N, C, H, W) float tensor on the
                same device as the model — confirm against LightningUNet.

        Returns:
            The raw model output; no post-processing is applied.
        """
        with torch.no_grad():
            output = self.model(image)
        # Post-processing (e.g. thresholding or argmax) can be added here if needed.
        return output
|