# Last commit: "Update Lib/Core.py" by blitzkrieg0000 (f7a2b96, verified)
import os
import sys
sys.path.append(os.getcwd())
import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from ultralytics import YOLO
from Lib.Consts import LABELS, COLOR_MAP, COLOR_MAP_RGB
from Tool.Core import DownloadHFModel
# Use the first CUDA device when one is available; otherwise run everything on the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Model weights are fetched at import time via Tool.Core.DownloadHFModel
# (presumably from the Hugging Face Hub repo below -- see that helper).
REPO_ID = "blitzkrieg0000/yolov9_pole-cable-detection"
MODEL_FILE = "yolov9c-cable-seg.pt"
DownloadHFModel(REPO_ID, MODEL_FILE)
class CablePoleSegmentation:
    """Cable/pole instance segmentation built on an ultralytics ``YOLO`` model.

    ``Process`` runs inference and returns filtered per-image detections,
    ``ParseResults`` does the filtering/mask rescaling, and ``DrawResults``
    renders one image's detections as a colored overlay.
    """

    def __init__(self, model_path=None, retina_mask=False):
        """Load the segmentation model.

        Args:
            model_path: Path to the YOLO weights; falls back to
                ``./Weight/yolov9c-cable-seg.pt`` when empty or None.
            retina_mask: Forwarded to the model as ``retina_masks`` in
                ``Process`` (ultralytics' high-resolution mask option).
        """
        if not model_path:
            model_path = "./Weight/yolov9c-cable-seg.pt"
        self._RetinaMask = retina_mask
        self.Model = None
        self.PrepareModel(model_path)

    def PrepareModel(self, model_path):
        """Create the ``YOLO`` model from ``model_path`` and call ``fuse()`` on it."""
        self.Model = YOLO(model_path)
        self.Model.fuse()

    def ScaleMasks(self, masks: torch.Tensor, shape: tuple) -> torch.Tensor:
        """Resize a stack of masks to ``shape`` using nearest-neighbour interpolation.

        Args:
            masks: Floating-point tensor of shape (N, H, W).
            shape: Target (height, width).

        Returns:
            Tensor of shape (N, height, width); an empty stack yields an empty
            (0, height, width) tensor instead of an interpolate error.
        """
        if masks.shape[0] == 0:
            return masks.new_zeros((0, *shape))
        # interpolate wants a batch dimension: (N, H, W) -> (1, N, H, W).
        resized = torch.nn.functional.interpolate(masks.unsqueeze(0), size=tuple(shape), mode="nearest")
        return resized.squeeze(0)

    @staticmethod
    def _StripLetterboxPadding(masks: torch.Tensor, shape: tuple) -> torch.Tensor:
        """Crop the centred letterbox padding from masks produced at model-input size.

        Uses the same gain/padding arithmetic as ultralytics' ``scale_image`` so
        the remaining region has the aspect ratio of ``shape`` (orig height, width).
        It is a no-op when no padding exists (e.g. ``retina_masks=True``).
        """
        mask_h, mask_w = masks.shape[-2:]
        orig_h, orig_w = shape
        gain = min(mask_h / orig_h, mask_w / orig_w)
        pad_h = (mask_h - orig_h * gain) / 2
        pad_w = (mask_w - orig_w * gain) / 2
        top, left = int(round(pad_h - 0.1)), int(round(pad_w - 0.1))
        bottom = mask_h - int(round(pad_h + 0.1))
        right = mask_w - int(round(pad_w + 0.1))
        return masks[..., top:bottom, left:right]

    def ParseResults(self, results, threshold=0.5, scale_masks=True):
        """Filter model results by confidence and collect detections per result.

        Args:
            results: Iterable of ultralytics ``Results`` (e.g. the generator
                returned by ``self.Model(..., stream=True)``).
            threshold: Detections whose confidence is not strictly greater than
                this value are dropped.
            scale_masks: If True, strip letterbox padding and resize masks to
                each result's ``orig_shape``.

        Returns:
            List of ``(scores, classes, masks, boxes)`` float32 tensors on
            ``DEVICE``, one tuple per result that kept at least one detection.
            Results without surviving detections are omitted, so list indices do
            not necessarily match input image indices.
        """
        batches = []
        with torch.no_grad():
            for result in results:
                # ultralytics leaves ``masks`` as None when nothing was detected.
                if result.masks is None:
                    continue
                keep = result.boxes.conf > threshold
                if not bool(keep.any()):
                    continue

                # Fresh tensors per result so detections from different images never mix.
                scores = result.boxes.conf[keep].to(device=DEVICE, dtype=torch.float32)
                classes = result.boxes.cls[keep].to(device=DEVICE, dtype=torch.float32)
                boxes = result.boxes.xyxy[keep].to(device=DEVICE, dtype=torch.float32)
                masks = result.masks.data[keep].to(device=DEVICE, dtype=torch.float32)

                if scale_masks:
                    original_shape = result.orig_shape[:2]
                    # Non-retina masks live on the padded model input; remove the
                    # padding before resizing so they line up with the original image.
                    masks = self._StripLetterboxPadding(masks, original_shape)
                    masks = self.ScaleMasks(masks, original_shape)

                batches.append((scores, classes, masks, boxes))
        return batches

    def DrawResults(self, image, scores: torch.Tensor, classes: torch.Tensor, masks: torch.Tensor, boxes: torch.Tensor, labels:dict=LABELS, class_filter:list=None):
        """Render one image's masks, boxes and labels.

        Args:
            image: BGR image (anything ``np.array`` accepts, e.g. ``cv2.imread`` output).
            scores, classes, masks, boxes: One tuple from ``ParseResults``; masks
                must already match the image size.
            labels: Maps class id to display name.
            class_filter: Optional collection of class ids to draw; falsy draws all.

        Returns:
            ``(canvas, maskCanvas)``: the RGB image blended with the overlay at
            0.5 weight, and the uint8 RGB overlay on its own.
        """
        _image = cv2.cvtColor(np.array(image), cv2.COLOR_BGR2RGB)
        maskCanvas = np.zeros_like(_image)
        with torch.no_grad():
            scores = scores.cpu().numpy()
            classes = classes.cpu().numpy().astype(np.int32)
            # cv2.cvtColor rejects float16, so normalise masks to float32.
            masks = masks.float().cpu().numpy()
            boxes = boxes.cpu().numpy()
        colors = list(COLOR_MAP_RGB.values())

        for score, cls, mask, box in zip(scores, classes, masks, boxes):
            cls = int(cls)
            # Filter before lookups so excluded classes never need a label/color entry.
            if class_filter and cls not in class_filter:
                continue
            label = labels[cls]
            _color = colors[cls]
            x1, y1, x2, y2 = (int(v) for v in box)

            coloredMask = (cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR) * _color).astype(np.uint8)
            maskCanvas = cv2.addWeighted(maskCanvas, 1.0, coloredMask, 1.0, 0)
            # Box outline and caption use the class color.
            maskCanvas = cv2.rectangle(maskCanvas, (x1, y1), (x2, y2), color=_color, thickness=5)
            maskCanvas = cv2.putText(maskCanvas, f"{label} : {score:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color=_color, thickness=2)

        canvas = cv2.addWeighted(_image, 1.0, maskCanvas.astype(np.uint8), 0.5, 0)
        return canvas, maskCanvas

    def Process(self, image, model_threshold=0.6, overall_threshold=0.6, iou=0.7, class_filter:list=None):
        """Run segmentation on ``image`` and return ``ParseResults`` output.

        Args:
            image: Model input (e.g. a BGR array from ``cv2.imread``).
            model_threshold: Confidence threshold applied by the model itself.
            overall_threshold: Second confidence filter applied in ``ParseResults``.
            iou: NMS IoU threshold passed to the model.
            class_filter: Optional list of class ids passed to the model as ``classes``.

        Returns:
            List of ``(scores, classes, masks, boxes)`` tuples (see ``ParseResults``).
        """
        with torch.no_grad():
            results = self.Model(
                image,
                save=False,
                show_boxes=False,
                project="./inference/",
                conf=model_threshold,
                iou=iou,
                retina_masks=self._RetinaMask,  # honour the constructor option
                stream=True,
                classes=class_filter,
                device=DEVICE,
            )
            # stream=True yields lazily, so inference actually runs while parsing.
            batches = self.ParseResults(results, threshold=overall_threshold, scale_masks=True)
        return batches
if __name__ == "__main__":
    # Demo: segment one sample image and plot input, overlay and blended result.
    test = "data/DJI_20240905091530_0003_W.JPG"
    image = cv2.imread(test)
    if image is None:
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        sys.exit(f"Could not read image: {test}")

    model = CablePoleSegmentation(retina_mask=False)
    batches = model.Process(image)
    if not batches:
        # No batch is returned when nothing passes the confidence thresholds.
        print("No detections above the threshold.")
        sys.exit(0)

    scores, classes, masks, boxes = batches[0]  # First (and only) image
    canvas, mask = model.DrawResults(image, scores, classes, masks, boxes, class_filter=None)
    print(canvas.shape)

    # Plot
    fig, axs = plt.subplots(1, 3, figsize=(27, 15))
    axs[0].imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    axs[0].set_title("Orijinal Görüntü")  # "Original image"
    axs[1].imshow(mask)
    axs[1].set_title("Segmentasyon Maskesi")  # "Segmentation mask"
    axs[2].imshow(canvas)
    axs[2].set_title("Sonuç")  # "Result"
    plt.tight_layout()
    plt.show()