Tigernawin commited on
Commit
d287691
·
verified ·
1 Parent(s): f4b81d7

Create modules/thermal_fault_detection.py

Browse files
Files changed (1) hide show
  1. modules/thermal_fault_detection.py +33 -0
modules/thermal_fault_detection.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modules/thermal_fault_detection.py
2
+ from transformers import DetrImageProcessor, DetrForObjectDetection
3
+ from PIL import Image
4
+ import torch
5
+
6
+ # Load processor and model
7
+ processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
8
+ model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
9
+
10
+ # Define class mapping for your use case
11
+ CUSTOM_CLASSES = {
12
+ "overheat": ["person", "fire hydrant"],
13
+ "dust": ["bird", "sheep"],
14
+ "breakage": ["bench", "truck"]
15
+ }
16
+
17
+ def detect_faults(image: Image.Image, threshold: float = 0.7):
18
+ inputs = processor(images=image, return_tensors="pt")
19
+ outputs = model(**inputs)
20
+
21
+ target_sizes = torch.tensor([image.size[::-1]])
22
+ results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
23
+
24
+ faults = []
25
+
26
+ for score, label in zip(results["scores"], results["labels"]):
27
+ class_name = model.config.id2label[label.item()].lower()
28
+ for fault_type, tags in CUSTOM_CLASSES.items():
29
+ if class_name in tags:
30
+ faults.append((fault_type.capitalize(), score.item()))
31
+ break
32
+
33
+ return faults