DSatishchandra commited on
Commit
72e5fb9
Β·
verified Β·
1 Parent(s): 0d70434

Create detection_service.py

Browse files
Files changed (1) hide show
  1. services/detection_service.py +32 -0
services/detection_service.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import DetrImageProcessor, DetrForObjectDetection
2
+ import torch
3
+ from PIL import Image
4
+
5
+ class DetectionService:
6
+ def __init__(self, model_name="facebook/detr-resnet-50"):
7
+ self.processor = DetrImageProcessor.from_pretrained(model_name, revision="no_timm")
8
+ self.model = DetrForObjectDetection.from_pretrained(model_name, revision="no_timm")
9
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
+ self.model.to(self.device)
11
+ self.model.eval()
12
+
13
+ def detect_objects(self, image, confidence_threshold=0.9):
14
+ """Detect objects in an image."""
15
+ inputs = self.processor(images=image, return_tensors="pt").to(self.device)
16
+ with torch.no_grad():
17
+ outputs = self.model(**inputs)
18
+ target_sizes = torch.tensor([image.size[::-1]]).to(self.device)
19
+ results = self.processor.post_process_object_detection(
20
+ outputs, target_sizes=target_sizes, threshold=confidence_threshold
21
+ )[0]
22
+ detections = []
23
+ for score, label, box in zip(
24
+ results["scores"], results["labels"], results["boxes"]
25
+ ):
26
+ box = box.cpu().numpy().astype(int)
27
+ detections.append({
28
+ "score": score.item(),
29
+ "label": self.model.config.id2label[label.item()],
30
+ "box": {"xmin": box[0], "ymin": box[1], "xmax": box[2], "ymax": box[3]}
31
+ })
32
+ return detections