| import colorsys | |
| from PIL import Image, ImageDraw, ImageFont, ImageOps | |
# Human-readable class names, looked up by the integer label id in
# draw_overlay(). 80 entries, ids 0-79, grouped ten per comment row.
# NOTE(review): assumes the detector emits contiguous 0-79 ids in exactly this
# order (not sparse/1-based COCO category ids) — confirm against the model.
COCO_CLASSES = [
    # 0-9
    "person", "bicycle", "car", "motorcycle", "airplane",
    "bus", "train", "truck", "boat", "traffic light",
    # 10-19
    "fire hydrant", "stop sign", "parking meter", "bench", "bird",
    "cat", "dog", "horse", "sheep", "cow",
    # 20-29
    "elephant", "bear", "zebra", "giraffe", "backpack",
    "umbrella", "handbag", "tie", "suitcase", "frisbee",
    # 30-39
    "skis", "snowboard", "sports ball", "kite", "baseball bat",
    "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle",
    # 40-49
    "wine glass", "cup", "fork", "knife", "spoon",
    "bowl", "banana", "apple", "sandwich", "orange",
    # 50-59
    "broccoli", "carrot", "hot dog", "pizza", "donut",
    "cake", "chair", "couch", "potted plant", "bed",
    # 60-69
    "dining table", "toilet", "tv", "laptop", "mouse",
    "remote", "keyboard", "cell phone", "microwave", "oven",
    # 70-79
    "toaster", "sink", "refrigerator", "book", "clock",
    "vase", "scissors", "teddy bear", "hair drier", "toothbrush",
]
def draw_overlay(im, threshold, labels, boxes, scores):
    """Draw labelled bounding boxes for confident detections onto *im* in place.

    Only the first element of each of *labels*, *boxes* and *scores* is
    iterated (``labels[0]`` etc.).
    NOTE(review): this presumably strips a leading batch dimension of size 1;
    confirm against the caller.

    Args:
        im: PIL image to draw on; it is modified in place.
        threshold: A detection is drawn only when ``score > threshold``.
        labels: Class ids indexing ``COCO_CLASSES``; each element must be
            accepted by ``int()`` (plain ints or array/tensor scalars).
        boxes: Boxes as ``[x0, y0, x1, y1]`` in image pixel coordinates; each
            coordinate must be accepted by ``float()``.
        scores: Confidence scores; each must be accepted by ``float()``.

    Returns:
        None. The drawing happens on *im*.
    """
    draw = ImageDraw.Draw(im)
    font = ImageFont.load_default()
    num_classes = len(COCO_CLASSES)

    for label, box, score in zip(labels[0], boxes[0], scores[0]):
        score = float(score)
        # Written as "not >" (rather than "<=") so NaN scores are still skipped.
        if not score > threshold:
            continue

        # Normalise array/tensor scalars to plain Python numbers so list
        # indexing, the colour hash and Pillow's coordinate parsing behave.
        label = int(label)
        x0, y0, x1, y1 = (float(v) for v in box)
        # Pillow's rectangle() requires x1 >= x0 and y1 >= y0; reorder
        # swapped corners instead of raising ValueError.
        x0, x1 = min(x0, x1), max(x0, x1)
        y0, y1 = min(y0, y1), max(y0, y1)

        color = get_color_by_label(label)
        draw.rectangle((x0, y0, x1, y1), outline=color, width=2)

        # A negative id would silently wrap to the wrong name and a too-large
        # id would raise IndexError; show the raw id for anything unknown.
        if 0 <= label < num_classes:
            class_name = COCO_CLASSES[label]
        else:
            class_name = f"class {label}"
        text = f"{class_name}: {score:.2f}"
        # Label text sits just inside the box's top-left corner.
        draw.text((x0 + 4, y0 + 2), text, fill=color, font=font)
def get_color_by_label(label):
    """Return a fully saturated, full-brightness RGB colour for a class id.

    The id is scaled by 17 and reduced modulo ``len(COCO_CLASSES)`` to pick a
    hue slot, so neighbouring ids land far apart on the hue wheel. With the
    80-entry list, gcd(17, 80) == 1, so ids 0-79 map to 80 distinct hues.

    Args:
        label: Integer class id.

    Returns:
        An ``(r, g, b)`` tuple of ints in 0-255.
    """
    slots = len(COCO_CLASSES)
    slot = (label * 17) % slots
    red, green, blue = colorsys.hsv_to_rgb(slot / slots, 1.0, 1.0)
    return (round(red * 255), round(green * 255), round(blue * 255))