Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision.io import read_image | |
| from torchvision import tv_tensors | |
| from torchvision.transforms import v2 as T | |
| from torchvision import transforms | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import io | |
def preprocess_image(image_path):
    """Load an image file and run it through the preprocessing pipeline.

    Args:
        image_path (str): Path of the image file to read.

    Returns:
        list: A one-element list holding the image after the pipeline from
        ``get_transform()`` has been applied to it.
    """
    # Decode the file and mark the result explicitly as an image so the v2
    # transforms treat it as one.
    raw_image = tv_tensors.Image(read_image(image_path))
    pipeline = get_transform()
    return [pipeline(raw_image)]
def get_transform():
    """Build the preprocessing pipeline applied to an input image.

    The pipeline converts the image to ``torch.float`` with value scaling
    enabled, then strips the ``tv_tensors`` wrapper so a plain tensor is
    returned.

    Returns:
        torchvision.transforms.v2.Compose: The composed transformations.
    """
    # Named ``steps`` (not ``transforms``) so the module-level
    # ``torchvision.transforms`` import is not shadowed inside this function.
    steps = [
        # Convert to a float tensor; scale=True rescales integer pixel values
        # into the float range (e.g. uint8 0-255 -> 0.0-1.0).
        T.ToDtype(torch.float, scale=True),
        # Drop the tv_tensors wrapper, yielding a pure torch.Tensor.
        T.ToPureTensor(),
    ]
    return T.Compose(steps)
def draw_boxes_on_image(image, target, confidence_threshold=0.7):
    """Draw bounding boxes and their class ids onto an image.

    Args:
        image (torch.Tensor): Image tensor laid out as (C, H, W). Values are
            assumed to be floats in [0, 1]; they are scaled by 255, clipped to
            [0, 255] and cast to uint8. A single-channel image is rendered as
            grayscale.
        target (dict): Holds ``'boxes'`` (one ``[xmin, ymin, xmax, ymax]`` row
            per detection) and ``'labels'`` (one integer class id per
            detection), and optionally ``'scores'`` (one confidence per
            detection).
        confidence_threshold (float): Detections whose score is below this
            value are not drawn. Applied only when ``'scores'`` is present;
            without scores every box is drawn.

    Returns:
        io.BytesIO: In-memory PNG of the annotated image, rewound to the
        start so it can be read immediately.
    """
    # Label 1 -> 'red', label 2 -> 'blue'; other ids wrap around the palette
    # instead of raising IndexError.
    colors = ['red', 'blue']

    # (C, H, W) tensor -> (H, W, C) uint8 array for PIL. Clipping keeps
    # out-of-range values from wrapping around when cast to uint8.
    pixels = image.detach().permute(1, 2, 0).cpu().numpy()
    pixels = np.clip(pixels * 255, 0, 255).astype(np.uint8)
    if pixels.shape[2] == 1:
        # PIL cannot build an image from an (H, W, 1) array; drop the channel
        # axis so single-channel input becomes a 2-D grayscale image.
        pixels = pixels[:, :, 0]
    img = Image.fromarray(pixels)
    draw = ImageDraw.Draw(img)

    boxes = target['boxes'].detach().cpu()
    labels = target['labels'].detach().cpu()
    scores = target.get('scores')
    if scores is None:
        # No confidences available: the threshold does not apply.
        keep = [True] * len(labels)
    else:
        keep = (torch.as_tensor(scores).detach().cpu() >= confidence_threshold).tolist()

    for box, label, show in zip(boxes, labels, keep):
        if not show:
            continue
        xmin, ymin, xmax, ymax = box.tolist()
        label_id = label.item()
        color = colors[(label_id - 1) % len(colors)]
        # Box outline, then the numeric class id at its top-left corner.
        draw.rectangle([xmin, ymin, xmax, ymax], outline=color, width=6)
        draw.text((xmin, ymin), f"{label_id}", fill=color)

    # Encode as PNG into an in-memory buffer and rewind it for the caller.
    buf = io.BytesIO()
    img.save(buf, format='PNG')
    buf.seek(0)
    return buf