Spaces:
Build error
Build error
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| from huggingface_hub import hf_hub_download | |
| import gradio as gr | |
class ObjectDetection:
    """DETR (ResNet-50) object detector with a 12-output class head, for Gradio.

    Builds a ``detr_resnet50`` architecture via ``torch.hub``, swaps in a
    12-output classification layer, restores weights from a local checkpoint,
    and exposes :meth:`predict`. That method returns detections in the
    ``(image, [((x1, y1, x2, y2), label), ...])`` shape that is handed to
    ``gr.AnnotatedImage``.
    """

    # Minimum class probability for a query to be reported.
    score_threshold = 0.05
    # A box whose IoU with a higher-scoring kept box exceeds this value is
    # suppressed. NMS is class-agnostic: labels are ignored.
    iou_threshold = 0.1

    def __init__(self, ckpt_path):
        """Build the preprocessing pipeline and load the fine-tuned model.

        Args:
            ckpt_path: Path to a checkpoint loadable with ``torch.load`` whose
                contents are a ``state_dict`` for DETR with a 12-output
                ``class_embed`` layer.
        """
        self.test_transform = A.Compose(
            [
                # albumentations Resize takes (height, width).
                A.Resize(800, 600),
                # Contrast enhancement, presumably to help with low-light input
                # (the checkpoint repo is "detr_low_light").
                A.CLAHE(clip_limit=10, p=1),
                # Dataset-specific mean/std, presumably computed on the
                # training images.
                A.Normalize(
                    [0.29278653, 0.25276296, 0.22975405],
                    [0.22653664, 0.19836408, 0.17775835],
                ),
                ToTensorV2(),
            ],
        )
        # Architecture only; the weights come from the checkpoint loaded below.
        self.model = torch.hub.load(
            "facebookresearch/detr", "detr_resnet50", pretrained=False
        )
        # Replace the classification head so its shape matches the checkpoint.
        in_features = self.model.class_embed.in_features
        self.model.class_embed = nn.Linear(
            in_features=in_features,
            out_features=12,
        )
        # NOTE(review): predict() drops the last logit column as DETR's
        # "no object" class. With 12 outputs, only the first 11 labels below
        # can ever be reported, so "Boat" is unreachable. Confirm this against
        # the training setup.
        self.labels = [
            "Dog",
            "Motorbike",
            "People",
            "Cat",
            "Chair",
            "Table",
            "Car",
            "Bicycle",
            "Bottle",
            "Bus",
            "Cup",
            "Boat",
        ]
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. Consider weights_only=True if the installed torch
        # supports it.
        model_ckpt = torch.load(ckpt_path, map_location=torch.device("cpu"))
        self.model.load_state_dict(model_ckpt)
        self.model.eval()

    def predict(self, img):
        """Detect objects in a PIL image.

        Args:
            img: A PIL image, or ``None`` when the Gradio input is empty.

        Returns:
            ``None`` if ``img`` is ``None``. Otherwise ``(img, annotations)``,
            where ``annotations`` is a list of ``((x1, y1, x2, y2), label)``
            tuples. Coordinates are integer pixels in the original image,
            clipped to its bounds, and each label is formatted as
            ``"<class> <score:.2f>"``. Overlapping boxes are removed with
            class-agnostic NMS.
        """
        if img is None:
            # Gradio passes None when the user submits without an image.
            return None
        img_w, img_h = img.size
        inp = self.test_transform(image=np.array(img.convert("RGB")))["image"]
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            out = self.model(inp.unsqueeze(0))
        # Drop the last column, treated as DETR's "no object" class.
        probas = out["pred_logits"].softmax(-1)[0, :, :-1]
        best_scores, best_labels = probas.max(-1)

        def to_px(value, limit):
            # Truncate to int (as before), then clip into [0, limit].
            return min(max(int(value), 0), limit)

        annotations = []
        scores = []
        # pred_boxes are normalised (center_x, center_y, width, height).
        for (x_c, y_c, w, h), score, label_idx in zip(
            out["pred_boxes"][0].tolist(),
            best_scores.tolist(),
            best_labels.tolist(),
        ):
            # `not >=` rather than `<`, so NaN scores are discarded too.
            if not score >= self.score_threshold:
                continue
            box = (
                to_px((x_c - 0.5 * w) * img_w, img_w),
                to_px((y_c - 0.5 * h) * img_h, img_h),
                to_px((x_c + 0.5 * w) * img_w, img_w),
                to_px((y_c + 0.5 * h) * img_h, img_h),
            )
            annotations.append((box, f"{self.labels[label_idx]} {score:.2f}"))
            scores.append(score)
        keep = self.non_max_suppression(annotations, scores, self.iou_threshold)
        return (img, [annotations[i] for i in keep])

    def non_max_suppression(self, boxes, scores, iou_threshold):
        """Greedy, class-agnostic non-maximum suppression.

        Args:
            boxes: Sequence of ``((x1, y1, x2, y2), label)`` entries; only the
                coordinate tuple is used.
            scores: One score per entry of ``boxes``.
            iou_threshold: A box is dropped when its IoU with an
                already-selected box is strictly greater than this value.

        Returns:
            Indices into ``boxes`` of the kept entries, highest score first.
            Ties keep their original order because ``sorted`` is stable.
        """
        order = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)
        selected = []
        while order:
            best, *rest = order
            selected.append(best)
            best_box = boxes[best][0]
            # Keep only candidates that do not overlap the box just selected.
            order = [
                i
                for i in rest
                if self.calculate_iou(best_box, boxes[i][0]) <= iou_threshold
            ]
        return selected

    def calculate_iou(self, box1, box2):
        """
        Calculate the Intersection over Union (IoU) of two bounding boxes.

        Args:
            box1: [x1, y1, x2, y2] for the first box.
            box2: [x1, y1, x2, y2] for the second box.

        Returns:
            IoU value in [0, 1]. Returns 0.0 when the union has zero area
            (for example, two degenerate zero-width boxes) instead of
            dividing by zero.
        """
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        intersection_area = max(0, x2 - x1) * max(0, y2 - y1)
        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union_area = box1_area + box2_area - intersection_area
        if union_area <= 0:
            return 0.0
        return intersection_area / union_area
# Download the fine-tuned checkpoint from the Hugging Face Hub and build the
# detector once, at start-up.
model_path = hf_hub_download(
    repo_id="SatwikKambham/detr_low_light",
    filename="detr.pt",
)
detector = ObjectDetection(ckpt_path=model_path)

# Single-image-in, annotated-image-out demo. `examples` is given as the string
# "Examples", presumably a directory of sample images next to this script.
iface = gr.Interface(
    fn=detector.predict,
    inputs=[
        gr.Image(
            type="pil",
            label="Input",
            height=400,
        ),
    ],
    outputs=gr.AnnotatedImage(
        height=400,
    ),
    examples="Examples",
)

# Start the server only when run as a script, so importing this module
# does not launch it.
if __name__ == "__main__":
    iface.launch()