File size: 3,928 Bytes
043bbad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from ultralytics.utils.plotting import Annotator, colors
import numpy as np
import torch
from copy import deepcopy

def custom_plot(
    self,
    conf: bool = True,
    line_width: float | None = None,
    font_size: float | None = None,
    font: str = "Arial.ttf",
    pil: bool = False,
    img: np.ndarray | None = None,
    im_gpu: torch.Tensor | None = None,
    kpt_radius: int = 5,
    kpt_line: bool = True,
    labels: bool = True,
    boxes: bool = True,
    masks: bool = True,
    probs: bool = True,
    show: bool = False,
    save: bool = False,
    filename: str | None = None,
    color_mode: str = "class",
    txt_color: tuple[int, int, int] = (255, 255, 255),
    barcode_texts: list[str | None] | None = None,
) -> np.ndarray:
    """Plot detection boxes of a results object, optionally labelled with per-box texts.

    Variant of Ultralytics ``Results.plot`` meant to receive a ``Results``-like
    object as ``self``. Only box / oriented-box predictions are drawn; masks,
    keypoints and classification probabilities are NOT rendered here. Several
    parameters (``im_gpu``, ``kpt_radius``, ``kpt_line``, ``masks``,
    ``txt_color``) are accepted only for signature compatibility and have no
    effect. ``color_mode`` is validated but does not influence colours.

    Box colours: a box whose label text is non-empty is drawn with palette
    colour 6, a box without text with palette colour 0.

    Args:
        conf (bool): Whether to append the confidence score to default labels.
        line_width (float | None): Line width of bounding boxes. If None, chosen by ``Annotator``.
        font_size (float | None): Font size for text. If None, chosen by ``Annotator``.
        font (str): Font to use for text.
        pil (bool): Whether to return the image as a PIL Image.
        img (np.ndarray | None): Image to plot on. If None, uses ``self.orig_img``.
        im_gpu (torch.Tensor | None): Unused; kept for compatibility with ``Results.plot``.
        kpt_radius (int): Unused; kept for compatibility with ``Results.plot``.
        kpt_line (bool): Unused; kept for compatibility with ``Results.plot``.
        labels (bool): Whether to draw the default ``"[id:N ]<class>[ <conf>]"`` labels.
            Not consulted when ``barcode_texts`` is given.
        boxes (bool): Whether to plot bounding boxes at all.
        masks (bool): Unused; kept for compatibility with ``Results.plot``.
        probs (bool): When classification probabilities are present and this is True,
            the ``Annotator`` is created in PIL mode (matching upstream behaviour).
        show (bool): Whether to display the annotated image (via ``Annotator.show``).
        save (bool): Whether to save the annotated image (via ``Annotator.save``).
        filename (str | None): Output path used when ``save`` is True. Defaults to
            ``"results_<input file name>"``.
        color_mode (str): Must be 'instance' or 'class'; otherwise an AssertionError is raised.
        txt_color (tuple[int, int, int]): Unused; kept for compatibility with ``Results.plot``.
        barcode_texts (list[str | None] | None): Per-box label texts, index-aligned with
            the predicted boxes (``self.obb`` if present, else ``self.boxes``). When given,
            these replace the default labels. A falsy entry (``None``/``""``) or a missing
            trailing entry means the box is drawn without text.

    Returns:
        (np.ndarray | PIL.Image.Image): The result of ``Annotator.result(pil)``; a NumPy
            array, or a PIL image when ``pil=True``.

    Examples:
        >>> results = model("image.jpg")
        >>> for result in results:
        ...     texts = [None] * len(result.boxes)  # one entry per detected box
        ...     im = custom_plot(result, barcode_texts=texts)
    """
    assert color_mode in {"instance", "class"}, f"Expected color_mode='instance' or 'class', not {color_mode}."
    if img is None and isinstance(self.orig_img, torch.Tensor):
        # Takes the first image of the batch, CHW -> HWC, scaled by 255 to uint8
        # (presumably the tensor holds floats in [0, 1] — confirm against callers).
        img = (self.orig_img[0].detach().permute(1, 2, 0).contiguous() * 255).byte().cpu().numpy()

    names = self.names
    is_obb = self.obb is not None
    pred_boxes = self.obb if is_obb else self.boxes
    pred_probs = self.probs
    annotator = Annotator(
        deepcopy(self.orig_img if img is None else img),
        line_width,
        font_size,
        font,
        pil or (pred_probs is not None and probs),  # Classify tasks default to pil=True
        example=names,
    )

    if pred_boxes is not None and boxes:
        # Draw from the last box to the first so earlier boxes end up on top
        # (same order as upstream Results.plot).
        for j in reversed(range(len(pred_boxes))):
            d = pred_boxes[j]
            if barcode_texts is None:
                track_id = int(d.id.item()) if d.is_track else None
                name = ("" if track_id is None else f"id:{track_id} ") + names[int(d.cls)]
                label = (f"{name} {float(d.conf):.2f}" if conf else name) if labels else None
            else:
                # Tolerate a texts list shorter than the number of boxes.
                label = barcode_texts[j] if j < len(barcode_texts) else None
            box = d.xyxyxyxy.squeeze() if is_obb else d.xyxy.squeeze()
            annotator.box_label(box, label, color=colors(6 if label else 0))

    if show:
        annotator.show(self.path)
    if save:
        from pathlib import Path

        annotator.save(filename or f"results_{Path(self.path).name}")
    return annotator.result(pil)