# Solar_Panel_Faults_Detection/services/detection_service.py
# Hosting-page metadata: uploaded by DSatishchandra;
# commit message "Update services/detection_service.py" (f1fb617, verified).
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image
class DetectionService:
    """Object detection over a stream of frames using a DETR model.

    Detection only runs on every ``frame_skip``-th call to
    :meth:`detect_objects`; all other calls return an empty list without
    touching the model.
    """

    def __init__(self, model_name="facebook/detr-resnet-50", frame_skip=5):
        """Load the processor and model and prepare them for inference.

        Args:
            model_name: Model identifier passed to ``from_pretrained`` (the
                ``"no_timm"`` revision is always requested).
            frame_skip: Run detection on every ``frame_skip``-th call to
                :meth:`detect_objects`. ``1`` processes every frame.

        Raises:
            ValueError: If ``frame_skip`` is smaller than 1 (it is used as a
                modulus, so 0 would fail later with ``ZeroDivisionError``).
        """
        # Validate before the model load so a bad argument fails fast.
        if frame_skip < 1:
            raise ValueError(f"frame_skip must be >= 1, got {frame_skip!r}")

        self.processor = DetrImageProcessor.from_pretrained(
            model_name, revision="no_timm"
        )
        self.model = DetrForObjectDetection.from_pretrained(
            model_name, revision="no_timm"
        )
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

        self.frame_counter = 0  # Number of detect_objects() calls so far.
        self.frame_skip = frame_skip

    def detect_objects(self, image, confidence_threshold=0.9):
        """Detect objects in ``image``, processing only every Nth call.

        Each call increments the frame counter. Unless the counter is a
        multiple of ``frame_skip`` the call returns ``[]`` immediately, so the
        first ``frame_skip - 1`` calls never run the model and a skipped frame
        looks the same as a frame with no detections.

        Args:
            image: Image to analyse. Its ``size`` attribute is read as
                ``(width, height)`` (PIL convention; assumed, since the file
                imports ``PIL.Image``).
            confidence_threshold: Minimum score a detection needs to be kept;
                forwarded to the processor as ``threshold``.

        Returns:
            A list of dicts with keys ``"score"`` (float), ``"label"`` (name
            looked up in the model's ``id2label``) and ``"box"`` (dict of
            native ``int`` values under ``xmin``, ``ymin``, ``xmax``,
            ``ymax``). The result contains only built-in Python types, so it
            can be passed to ``json.dumps``.
        """
        self.frame_counter += 1
        if self.frame_counter % self.frame_skip != 0:
            return []  # Skip detection for this frame

        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
        with torch.no_grad():
            outputs = self.model(**inputs)

        # image.size is (width, height); post-processing takes (height, width).
        target_sizes = torch.tensor([image.size[::-1]]).to(self.device)
        results = self.processor.post_process_object_detection(
            outputs, target_sizes=target_sizes, threshold=confidence_threshold
        )[0]

        id2label = self.model.config.id2label
        detections = []
        for score, label, box in zip(
            results["scores"], results["labels"], results["boxes"]
        ):
            # tolist() yields Python floats on any device; int() truncates
            # toward zero and gives built-in ints rather than numpy.int64,
            # which json cannot encode.
            xmin, ymin, xmax, ymax = (int(coord) for coord in box.tolist())
            detections.append({
                "score": score.item(),
                "label": id2label[label.item()],
                "box": {"xmin": xmin, "ymin": ymin, "xmax": xmax, "ymax": ymax},
            })
        return detections