Tigernawin commited on
Commit
ec20e0f
·
verified ·
1 Parent(s): cd4bf3c

Update modules/thermal_fault_detection.py

Browse files
Files changed (1) hide show
  1. modules/thermal_fault_detection.py +52 -24
modules/thermal_fault_detection.py CHANGED
@@ -1,33 +1,61 @@
1
- # modules/thermal_fault_detection.py
2
- from transformers import DetrImageProcessor, DetrForObjectDetection
3
- from PIL import Image
4
  import torch
 
 
 
5
 
6
- # Load processor and model
7
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
8
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
 
 
9
 
10
- # Define class mapping for your use case
11
- CUSTOM_CLASSES = {
 
 
12
  "overheat": ["person", "fire hydrant"],
13
  "dust": ["bird", "sheep"],
14
  "breakage": ["bench", "truck"]
15
  }
16
 
 
 
 
 
 
 
17
  def detect_faults(image: Image.Image, threshold: float = 0.7):
18
- inputs = processor(images=image, return_tensors="pt")
19
- outputs = model(**inputs)
20
-
21
- target_sizes = torch.tensor([image.size[::-1]])
22
- results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=threshold)[0]
23
-
24
- faults = []
25
-
26
- for score, label in zip(results["scores"], results["labels"]):
27
- class_name = model.config.id2label[label.item()].lower()
28
- for fault_type, tags in CUSTOM_CLASSES.items():
29
- if class_name in tags:
30
- faults.append((fault_type.capitalize(), score.item()))
31
- break
32
-
33
- return faults
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from torchvision.models.detection import maskrcnn_resnet50_fpn
3
+ from torchvision.transforms import functional as F
4
+ from PIL import Image, ImageDraw, ImageFont
5
 
6
+ # Load model - assuming you use a pretrained COCO version
7
+ def load_model():
8
+ model = maskrcnn_resnet50_fpn(pretrained=True) # or load your fine-tuned weights
9
+ model.eval()
10
+ return model
11
 
12
+ model = load_model()
13
+
14
+ # Map COCO classes to your custom fault categories
15
+ CLASS_MAPPING = {
16
  "overheat": ["person", "fire hydrant"],
17
  "dust": ["bird", "sheep"],
18
  "breakage": ["bench", "truck"]
19
  }
20
 
21
+ def get_fault_label(coco_class_name: str):
22
+ for fault, coco_aliases in CLASS_MAPPING.items():
23
+ if coco_class_name in coco_aliases:
24
+ return fault.capitalize()
25
+ return None
26
+
27
  def detect_faults(image: Image.Image, threshold: float = 0.7):
28
+ image_tensor = F.to_tensor(image).unsqueeze(0) # Convert to tensor and add batch dimension
29
+
30
+ with torch.no_grad():
31
+ output = model(image_tensor)[0]
32
+
33
+ draw = ImageDraw.Draw(image)
34
+ font = ImageFont.load_default()
35
+ results = []
36
+
37
+ for score, label, box in zip(output["scores"], output["labels"], output["boxes"]):
38
+ if score < threshold:
39
+ continue
40
+
41
+ label_id = label.item()
42
+ coco_label = model.coco_labels[label_id] if hasattr(model, "coco_labels") else label_id
43
+ class_name = model.coco_labels.get(label_id, f"class_{label_id}") if hasattr(model, "coco_labels") else f"class_{label_id}"
44
+
45
+ # COCO id2label mapping from torchvision
46
+ COCO_LABELS = {
47
+ 1: "person", 10: "fire hydrant", 16: "bird", 20: "sheep", 14: "bench", 8: "truck",
48
+ # Add more if needed
49
+ }
50
+ class_name = COCO_LABELS.get(label_id, f"class_{label_id}")
51
+ fault_type = get_fault_label(class_name)
52
+
53
+ if fault_type:
54
+ results.append((fault_type, score.item()))
55
+
56
+ # Draw box
57
+ box = box.tolist()
58
+ draw.rectangle(box, outline="red", width=3)
59
+ draw.text((box[0], box[1] - 10), f"{fault_type}: {score:.2f}", fill="red", font=font)
60
+
61
+ return results, image