import torch from torchvision.io import read_image from torchvision import tv_tensors from torchvision.transforms import v2 as T from torchvision import transforms import numpy as np from PIL import Image, ImageDraw, ImageFont import io def preprocess_image(image_path): """Returns the transformed image. Args: image_path (str): Image file path Returns: img (Tensor): Transformed image. """ # Read the image img = read_image(image_path) # Wrap image into torchvision tv_tensors img = tv_tensors.Image(img) # Apply transformations transform = get_transform() img = transform(img) return [img] def get_transform(): """Returns the transformations to be applied to an image. Returns: transforms.Compose: The transformations to be applied to the images. """ transforms = [] # Convert the image to a float tensor and scale image transforms.append(T.ToDtype(torch.float, scale=True)) # Convert the image to a pure tensor transforms.append(T.ToPureTensor()) return T.Compose(transforms) def draw_boxes_on_image(image, target, confidence_threshold=0.7): """ Draw bounding boxes directly on the image. Args: image (torch.Tensor): The image tensor (C, H, W). target (dict): The target dictionary containing boxes, labels, and optionally scores. confidence_threshold (float): Threshold for displaying predictions (if scores are provided). Returns: bytes: The image with boxes drawn, as a bytes object. """ colors = ['red', 'blue'] # Convert image from tensor to PIL Image img = Image.fromarray((image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)) draw = ImageDraw.Draw(img) boxes = target['boxes'].cpu() labels = target['labels'].cpu() scores = target.get('scores', torch.ones_like(labels)) # Use all 1s if no scores for box, label, score in zip(boxes, labels, scores): if score < confidence_threshold: continue xmin, ymin, xmax, ymax = box.tolist() color = colors[label.item() - 1] # Draw rectangle draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=6) # Draw label draw.text((xmin, ymin), f"{label.item()}", fill=color) # Convert PIL Image to bytes buf = io.BytesIO() img.save(buf, format='PNG') buf.seek(0) return buf