xDesCO / segmentation /model.py
Nguyễn Thành Đạt
update code
036e7c4
from ultralytics import YOLO
import cv2
import torch
# Checkpoint loaded by create_model() when no explicit weight path is given.
# NOTE(review): a './'-relative path resolves against the current working
# directory, not against this file — callers must run from the repo root.
BEST_WEIGHT = './segmentation/weights/oai_s_best4.pt'
# Inference device: the current CUDA device when CUDA is usable, else the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def create_model(weight_path=None):
    """Instantiate a YOLO model.

    Args:
        weight_path: path to a weights file. Any falsy value (``None``, ``""``)
            falls back to the module default ``BEST_WEIGHT``.

    Returns:
        The ``YOLO`` instance built from the chosen weights.
    """
    chosen_weights = weight_path or BEST_WEIGHT
    return YOLO(chosen_weights)
class Segmenter():
    """Femur/tibia segmenter wrapping a YOLO segmentation model.

    ``segment`` returns a label mask in which femur pixels are 1, tibia
    pixels are 2 and all other pixels are 0.
    """

    # Model class ids. The label written into the mask is ``class id + 1``,
    # so class 0 -> 1 (femur) and class 1 -> 2 (tibia).
    FEMUR_CLS = 0
    TIBIA_CLS = 1

    def __init__(self, weight_path=None):
        """Load the model (``weight_path`` or the default weights) onto DEVICE."""
        self.model = create_model(weight_path).to(DEVICE)

    def segment(self, img):
        """Segment femur and tibia in an image.

        Args:
            img: image array, either (H, W, C) in BGR channel order or already
                single-channel as (H, W) / (H, W, 1). ``cv2.equalizeHist``
                requires 8-bit data, so this presumably must be uint8.

        Returns:
            numpy label mask with femur = 1, tibia = 2, background = 0, with
            the same (float) dtype as the model's mask tensor.
            NOTE(review): the mask has the resolution of ultralytics'
            ``masks.data``, which (unless ``retina_masks=True``) is the
            letterboxed inference size rather than the original (H, W) —
            confirm against callers.

        Raises:
            ValueError: if the model does not predict both a femur and a
                tibia mask.
        """
        # Accept single-channel input instead of failing inside cvtColor.
        if img.ndim == 3 and img.shape[2] == 1:
            img = img[:, :, 0]
        gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Global histogram equalization followed by CLAHE; presumably this
        # mirrors the training-time preprocessing, so keep the two in sync.
        eimg = cv2.equalizeHist(gray)
        # Created per call: a shared CLAHE object keeps scratch buffers and is
        # not safe to use from several threads at once.
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        eimg = clahe.apply(eimg)
        # The model is fed a 3-channel image; replicate the gray channel.
        eimg = cv2.cvtColor(eimg, cv2.COLOR_GRAY2RGB)

        res = self.model(eimg, verbose=False)
        return self._label_mask(res[0]).cpu().numpy()

    @classmethod
    def _label_mask(cls, result):
        """Build a femur/tibia label tensor from one ultralytics result."""
        masks = result.masks
        if masks is None or len(masks.data) == 0:
            raise ValueError("Segmenter: the model predicted no femur/tibia masks")

        class_ids = [int(c) for c in result.boxes.cls.tolist()]
        confs = result.boxes.conf.tolist()

        # Keep only the most confident detection of each class, so duplicate
        # detections of one bone cannot displace the other bone.
        best = {}
        for idx, (cid, conf) in enumerate(zip(class_ids, confs)):
            if cid not in best or conf > confs[best[cid]]:
                best[cid] = idx

        wanted = (cls.FEMUR_CLS, cls.TIBIA_CLS)
        missing = [cid for cid in wanted if cid not in best]
        if missing:
            raise ValueError(f"Segmenter: no detection for class id(s) {missing}")

        label = torch.zeros_like(masks.data[0])
        # Paint the less confident mask first so that, where the two masks
        # overlap, the more confident class wins (instead of summing to 3).
        for cid in sorted(wanted, key=lambda c: confs[best[c]]):
            label[masks.data[best[cid]] > 0.5] = cid + 1
        return label