# tree-counter / utils.py
# Author: Mawube
# Commit c08305c (verified): "Increase box widths"
import torch
from torchvision.io import read_image
from torchvision import tv_tensors
from torchvision.transforms import v2 as T
from torchvision import transforms
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import io
def preprocess_image(image_path):
    """Load an image file and prepare it as model input.

    The file is decoded with ``torchvision.io.read_image``. The result is
    wrapped as a ``tv_tensors.Image`` so the v2 transforms treat it as an
    image, then passed through the pipeline from ``get_transform()``.

    Args:
        image_path (str): Path of the image file to read.

    Returns:
        list[torch.Tensor]: A one-element list containing the transformed
        image tensor (float, scaled by ``get_transform``). The list wrapper
        presumably matches a model that consumes a list of images; TODO
        confirm against the caller.
    """
    # NOTE(review): read_image's default mode keeps the file's own channel
    # count (e.g. 4 channels for RGBA PNGs, 1 for grayscale). Confirm the
    # downstream model only ever receives 3-channel images.
    raw = read_image(image_path)

    # Wrapping in tv_tensors.Image makes the v2 transforms (ToDtype in
    # particular) recognise the tensor as an image explicitly.
    wrapped = tv_tensors.Image(raw)

    transformed = get_transform()(wrapped)
    return [transformed]
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    The pipeline converts the image to ``torch.float`` with ``scale=True``,
    which rescales integer pixel values into ``[0, 1]``. It then unwraps any
    ``tv_tensors`` wrapper so the result is a plain ``torch.Tensor``.

    Returns:
        torchvision.transforms.v2.Compose: The composed transformation.
    """
    # Named ``steps`` rather than ``transforms`` so the module-level
    # ``torchvision.transforms`` import is not shadowed inside this function.
    steps = [
        # Convert to float and rescale pixel values (scale=True).
        T.ToDtype(torch.float, scale=True),
        # Strip the tv_tensors.Image wrapper, leaving a plain Tensor.
        T.ToPureTensor(),
    ]
    return T.Compose(steps)
def _tensor_to_pil(image):
    """Convert a (C, H, W) image tensor into a PIL image suitable for drawing.

    Float tensors are expected to hold values in [0, 1]. They are clipped to
    that range and scaled to 8-bit. uint8 tensors are used as-is.
    Single-channel images are promoted to RGB so coloured annotations stay
    visible.
    """
    # detach() allows tensors that still track gradients (numpy() refuses them).
    pixels = image.detach().cpu()
    if pixels.dtype == torch.uint8:
        # Already 8-bit (e.g. straight from read_image): no rescaling needed.
        array = np.ascontiguousarray(pixels.permute(1, 2, 0).numpy())
    else:
        # clip() keeps values slightly outside [0, 1] from overflowing when
        # cast to uint8; in-range values are unchanged.
        scaled = pixels.permute(1, 2, 0).numpy().clip(0.0, 1.0) * 255
        array = scaled.astype(np.uint8)
    if array.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array: drop the channel
        # axis and convert the grayscale result to RGB.
        return Image.fromarray(array[:, :, 0]).convert('RGB')
    return Image.fromarray(array)


def draw_boxes_on_image(image, target, confidence_threshold=0.7):
    """Draw bounding boxes and their class labels directly on an image.

    Args:
        image (torch.Tensor): Image tensor laid out as (C, H, W). Float
            tensors are expected in [0, 1] (out-of-range values are clipped);
            uint8 tensors are used unchanged.
        target (dict): Detection dictionary with ``'boxes'``, ``'labels'``,
            and optionally ``'scores'``:
            - ``'boxes'``: per-box ``(xmin, ymin, xmax, ymax)`` pixel
              coordinates.
            - ``'labels'``: integer class ids.
            - ``'scores'``: per-box confidences.
        confidence_threshold (float): Boxes whose score is below this value
            are skipped. When ``'scores'`` is missing, every box is drawn.

    Returns:
        io.BytesIO: An in-memory PNG of the annotated image, rewound to
        position 0 so it can be read immediately.
    """
    # Colour for label id i is colors[(i - 1) % len(colors)]: id 1 -> red,
    # id 2 -> blue. Larger ids wrap around instead of raising IndexError.
    colors = ['red', 'blue']

    img = _tensor_to_pil(image)
    draw = ImageDraw.Draw(img)

    boxes = target['boxes'].detach().cpu()
    labels = target['labels'].detach().cpu()
    scores = target.get('scores')
    # Without scores, every box counts as fully confident (score 1).
    scores = torch.ones_like(labels) if scores is None else scores.detach().cpu()

    for box, label, score in zip(boxes, labels, scores):
        if score < confidence_threshold:
            continue
        label_id = int(label.item())
        xmin, ymin, xmax, ymax = box.tolist()
        color = colors[(label_id - 1) % len(colors)]
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=6)
        # Class id written at the box's top-left corner.
        draw.text((xmin, ymin), f"{label_id}", fill=color)

    buf = io.BytesIO()
    img.save(buf, format='PNG')
    buf.seek(0)
    return buf