File size: 2,482 Bytes
b8de002
 
 
 
 
 
b3dc051
 
 
b8de002
 
 
 
 
 
 
 
 
 
 
 
423a536
b8de002
 
 
 
 
b412519
 
b8de002
a43c7ca
b8de002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c08305c
b8de002
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92

import torch
from torchvision.io import read_image
from torchvision import tv_tensors
from torchvision.transforms import v2 as T
from torchvision import transforms
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import io

def preprocess_image(image_path):
    """Load an image from disk and run it through the preprocessing pipeline.

    Args:
        image_path (str): Path of the image file to read.

    Returns:
        list[torch.Tensor]: A single-element list holding the transformed
            image as a plain float tensor (see ``get_transform``). The list
            wrapper is part of the return contract; presumably it lets the
            result be passed straight to a model that takes a list of
            images — confirm with callers before changing it.
    """
    # Decode the file into a (C, H, W) tensor.
    img = read_image(image_path)

    # Wrap it as a tv_tensors.Image so the v2 transforms treat it as an image.
    img = tv_tensors.Image(img)

    # Apply the shared preprocessing pipeline.
    transform = get_transform()
    img = transform(img)

    return [img]


def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    The pipeline casts the image to ``torch.float`` with ``scale=True``. That
    rescales the values for the new dtype; for uint8 input this maps
    [0, 255] to [0, 1], per torchvision's ``ToDtype``. It then strips the
    ``tv_tensors`` wrapper so a plain ``torch.Tensor`` comes out.

    Returns:
        T.Compose: A torchvision ``transforms.v2`` composition of the steps.
    """
    # Named ``steps`` rather than ``transforms`` so the local list does not
    # shadow the module-level ``torchvision.transforms`` import.
    steps = [
        # Convert the image to a float tensor and rescale its values.
        T.ToDtype(torch.float, scale=True),
        # Unwrap the tv_tensors.Image into a plain tensor.
        T.ToPureTensor(),
    ]
    return T.Compose(steps)


def draw_boxes_on_image(image, target, confidence_threshold=0.7, colors=None):
    """
    Draw bounding boxes and their labels directly on the image.

    Args:
        image (torch.Tensor): The image tensor (C, H, W). Float values are
            expected in [0, 1]; anything outside that range is clamped before
            the cast to uint8. C may be 1 (grayscale, drawn as RGB), 3 or 4.
        target (dict): The target dictionary containing 'boxes' (one
            ``xmin, ymin, xmax, ymax`` row per object), 'labels', and
            optionally 'scores'.
        confidence_threshold (float): Boxes whose score is below this value
            are skipped. Without 'scores' every box is treated as score 1.
        colors (sequence of str, optional): Colors for outline and text,
            picked by ``(label - 1) % len(colors)`` so any label id maps to a
            color. Defaults to ``['red', 'blue']`` (label 1 red, label 2 blue).

    Returns:
        io.BytesIO: In-memory PNG of the annotated image, rewound to the start.
    """
    palette = list(colors) if colors else ['red', 'blue']

    # Detach so tensors that require grad (e.g. raw model outputs) can be
    # turned into NumPy arrays. Clamp so out-of-range values cannot wrap
    # around in the uint8 cast.
    pixels = image.detach().cpu().clamp(0, 1)
    array = (pixels.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    if array.shape[2] == 1:
        # PIL rejects (H, W, 1) arrays. Drop the channel axis, then convert
        # to RGB so the colored outlines remain visible.
        img = Image.fromarray(array[:, :, 0]).convert('RGB')
    else:
        img = Image.fromarray(array)
    draw = ImageDraw.Draw(img)

    boxes = target['boxes'].detach().cpu()
    labels = target['labels'].detach().cpu()
    scores = target.get('scores')
    if scores is None:
        # No scores (e.g. ground truth): treat every box as fully confident.
        scores = torch.ones_like(labels)
    else:
        scores = scores.detach().cpu()

    for box, label, score in zip(boxes, labels, scores):
        if score < confidence_threshold:
            continue

        label_id = label.item()
        # Modulo keeps unexpected label ids (0, or more classes than colors)
        # from raising IndexError.
        color = palette[(int(label_id) - 1) % len(palette)]
        xmin, ymin, xmax, ymax = box.tolist()

        # Draw rectangle
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=6)

        # Draw label
        draw.text((xmin, ymin), f"{label_id}", fill=color)

    # Encode as PNG into an in-memory buffer, rewound for the caller to read.
    buf = io.BytesIO()
    img.save(buf, format='PNG')
    buf.seek(0)

    return buf