Tigernawin commited on
Commit
ac00aa0
·
verified ·
1 Parent(s): ec20e0f

Update modules/thermal_fault_detection.py

Browse files
Files changed (1) hide show
  1. modules/thermal_fault_detection.py +23 -20
modules/thermal_fault_detection.py CHANGED
@@ -3,59 +3,62 @@ from torchvision.models.detection import maskrcnn_resnet50_fpn
3
  from torchvision.transforms import functional as F
4
  from PIL import Image, ImageDraw, ImageFont
5
 
6
- # Load model - assuming you use a pretrained COCO version
7
  def load_model():
8
- model = maskrcnn_resnet50_fpn(pretrained=True) # or load your fine-tuned weights
9
  model.eval()
10
  return model
11
 
12
  model = load_model()
13
 
14
- # Map COCO classes to your custom fault categories
15
  CLASS_MAPPING = {
16
  "overheat": ["person", "fire hydrant"],
17
  "dust": ["bird", "sheep"],
18
  "breakage": ["bench", "truck"]
19
  }
20
 
 
 
 
 
 
 
 
 
 
 
 
21
  def get_fault_label(coco_class_name: str):
22
- for fault, coco_aliases in CLASS_MAPPING.items():
23
- if coco_class_name in coco_aliases:
24
- return fault.capitalize()
25
  return None
26
 
27
  def detect_faults(image: Image.Image, threshold: float = 0.7):
28
- image_tensor = F.to_tensor(image).unsqueeze(0) # Convert to tensor and add batch dimension
29
 
30
  with torch.no_grad():
31
- output = model(image_tensor)[0]
32
 
33
  draw = ImageDraw.Draw(image)
34
  font = ImageFont.load_default()
35
  results = []
36
 
37
- for score, label, box in zip(output["scores"], output["labels"], output["boxes"]):
38
  if score < threshold:
39
  continue
40
 
41
  label_id = label.item()
42
- coco_label = model.coco_labels[label_id] if hasattr(model, "coco_labels") else label_id
43
- class_name = model.coco_labels.get(label_id, f"class_{label_id}") if hasattr(model, "coco_labels") else f"class_{label_id}"
44
-
45
- # COCO id2label mapping from torchvision
46
- COCO_LABELS = {
47
- 1: "person", 10: "fire hydrant", 16: "bird", 20: "sheep", 14: "bench", 8: "truck",
48
- # Add more if needed
49
- }
50
  class_name = COCO_LABELS.get(label_id, f"class_{label_id}")
51
  fault_type = get_fault_label(class_name)
52
 
53
  if fault_type:
54
  results.append((fault_type, score.item()))
55
 
56
- # Draw box
57
  box = box.tolist()
58
- draw.rectangle(box, outline="red", width=3)
59
- draw.text((box[0], box[1] - 10), f"{fault_type}: {score:.2f}", fill="red", font=font)
60
 
61
  return results, image
 
3
  from torchvision.transforms import functional as F
4
  from PIL import Image, ImageDraw, ImageFont
5
 
6
+ # Load pretrained COCO model
7
  def load_model():
8
+ model = maskrcnn_resnet50_fpn(pretrained=True)
9
  model.eval()
10
  return model
11
 
12
  model = load_model()
13
 
14
+ # Map COCO classes to your fault types
15
  CLASS_MAPPING = {
16
  "overheat": ["person", "fire hydrant"],
17
  "dust": ["bird", "sheep"],
18
  "breakage": ["bench", "truck"]
19
  }
20
 
21
+ # COCO id-to-name mapping
22
+ COCO_LABELS = {
23
+ 1: "person",
24
+ 10: "fire hydrant",
25
+ 16: "bird",
26
+ 20: "sheep",
27
+ 14: "bench",
28
+ 8: "truck",
29
+ # Add more if needed
30
+ }
31
+
32
  def get_fault_label(coco_class_name: str):
33
+ for fault_type, classes in CLASS_MAPPING.items():
34
+ if coco_class_name in classes:
35
+ return fault_type.capitalize()
36
  return None
37
 
38
  def detect_faults(image: Image.Image, threshold: float = 0.7):
39
+ image_tensor = F.to_tensor(image).unsqueeze(0) # Convert to batch tensor
40
 
41
  with torch.no_grad():
42
+ outputs = model(image_tensor)[0] # Get first (and only) result
43
 
44
  draw = ImageDraw.Draw(image)
45
  font = ImageFont.load_default()
46
  results = []
47
 
48
+ for score, label, box in zip(outputs["scores"], outputs["labels"], outputs["boxes"]):
49
  if score < threshold:
50
  continue
51
 
52
  label_id = label.item()
 
 
 
 
 
 
 
 
53
  class_name = COCO_LABELS.get(label_id, f"class_{label_id}")
54
  fault_type = get_fault_label(class_name)
55
 
56
  if fault_type:
57
  results.append((fault_type, score.item()))
58
 
59
+ # Draw bounding box
60
  box = box.tolist()
61
+ draw.rectangle(box, outline="red", width=2)
62
+ draw.text((box[0], box[1] - 10), f"{fault_type} ({score:.2f})", fill="red", font=font)
63
 
64
  return results, image