Spaces:
Sleeping
Sleeping
File size: 2,482 Bytes
b8de002 b3dc051 b8de002 423a536 b8de002 b412519 b8de002 a43c7ca b8de002 c08305c b8de002 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import torch
from torchvision.io import read_image
from torchvision import tv_tensors
from torchvision.transforms import v2 as T
from torchvision import transforms
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import io
def preprocess_image(image_path):
    """Load an image file and run it through the preprocessing pipeline.

    The file is decoded with ``read_image``, wrapped as a
    ``tv_tensors.Image`` and passed to the pipeline returned by
    ``get_transform()``.

    Args:
        image_path (str): Path of the image file to load.

    Returns:
        list: A single-element list holding the transformed image tensor.
    """
    decoded = read_image(image_path)
    wrapped = tv_tensors.Image(decoded)
    pipeline = get_transform()
    return [pipeline(wrapped)]
def get_transform():
    """Build the preprocessing pipeline applied to an input image.

    The pipeline converts the image to ``torch.float`` with ``scale=True``
    (values are rescaled according to the input dtype's range, e.g. uint8
    0-255 -> 0.0-1.0). It then strips the ``tv_tensors`` wrapper so the
    result is a plain ``torch.Tensor``.

    Returns:
        torchvision.transforms.v2.Compose: The composed transformation.
    """
    # Named ``steps`` rather than ``transforms`` so the local does not shadow
    # the module-level ``from torchvision import transforms`` import.
    steps = [
        # Convert to float and scale pixel values into [0, 1].
        T.ToDtype(torch.float, scale=True),
        # Drop the tv_tensors.Image wrapper and return a plain tensor.
        T.ToPureTensor(),
    ]
    return T.Compose(steps)
def draw_boxes_on_image(image, target, confidence_threshold=0.7):
    """Render bounding boxes and their class labels onto an image.

    Args:
        image (torch.Tensor): Image tensor laid out channels-first as
            (C, H, W). Values are expected in [0, 1]; anything outside that
            range is clamped before conversion to 8-bit pixels.
        target (dict): Detection dictionary with ``'boxes'`` (one row of
            ``xmin, ymin, xmax, ymax`` pixel coordinates per box),
            ``'labels'`` (one integer per box) and, optionally, ``'scores'``
            (one confidence per box).
        confidence_threshold (float): Boxes whose score is below this value
            are skipped. When ``'scores'`` is missing (or ``None``) every box
            is drawn.

    Returns:
        io.BytesIO: A buffer holding the annotated image encoded as PNG,
        rewound to position 0 so it can be read immediately.
    """
    palette = ['red', 'blue']

    # (C, H, W) float tensor -> (H, W, C) uint8 array. detach()/cpu() make
    # .numpy() safe for grad-tracking or GPU tensors; clamping stops
    # out-of-range values from wrapping around in the uint8 cast.
    pixels = image.detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    pixels = (pixels * 255).astype(np.uint8)
    if pixels.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array: drop the channel
        # axis and promote to RGB so the coloured outlines remain visible.
        img = Image.fromarray(pixels[:, :, 0]).convert('RGB')
    else:
        img = Image.fromarray(pixels)
    draw = ImageDraw.Draw(img)

    boxes = target['boxes'].detach().cpu()
    labels = target['labels'].detach().cpu()
    scores = target.get('scores')
    if scores is None:
        # No scores supplied: treat every box as fully confident.
        scores = torch.ones_like(labels)
    else:
        scores = scores.detach().cpu()

    for box, label, score in zip(boxes, labels, scores):
        if score < confidence_threshold:
            continue
        xmin, ymin, xmax, ymax = box.tolist()
        class_id = label.item()
        # Label 1 -> palette[0], label 2 -> palette[1] (unchanged). The modulo
        # cycles the palette so larger label ids no longer raise IndexError.
        color = palette[(class_id - 1) % len(palette)]
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=6)
        draw.text((xmin, ymin), f"{class_id}", fill=color)

    # Encode as PNG into an in-memory buffer and rewind it for the caller.
    buf = io.BytesIO()
    img.save(buf, format='PNG')
    buf.seek(0)
    return buf
|