Tigernawin commited on
Commit
627f06b
·
verified ·
1 Parent(s): 246b3ce

Update modules/thermal_fault_detection.py

Browse files
Files changed (1) hide show
  1. modules/thermal_fault_detection.py +17 -29
modules/thermal_fault_detection.py CHANGED
@@ -1,35 +1,23 @@
1
- from transformers import DetrImageProcessor, DetrForObjectDetection
2
- import torch
3
  from PIL import Image
 
4
 
5
- # Load processor and model
6
- processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
7
- model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
8
 
9
- CUSTOM_CLASSES = ["Overheating", "Dust", "Breakage"]
 
 
 
10
 
11
- def detect_faults(image, threshold=0.1):
12
- # Preprocess the image
13
- inputs = processor(images=image, return_tensors="pt")
14
- outputs = model(**inputs)
 
 
 
15
 
16
- # ✅ Define probs from logits
17
- probs = outputs.logits.softmax(-1)[0] # [num_predictions, num_classes]
18
- boxes = outputs.pred_boxes[0] # [num_predictions, 4]
19
 
20
- results = []
21
-
22
- for i, prob in enumerate(probs):
23
- label_id = prob.argmax().item()
24
- confidence = prob[label_id].item()
25
-
26
- # Get label safely
27
- label = model.config.id2label.get(label_id, None)
28
- if label is None or confidence < threshold:
29
- continue
30
-
31
- if label in CUSTOM_CLASSES:
32
- box = boxes[i].tolist()
33
- results.append((label, confidence, box))
34
-
35
- return results
 
1
+ from ultralytics import YOLO
 
2
  from PIL import Image
3
+ import numpy as np
4
 
5
+ # Load the YOLOv8 model (make sure 'solar_fault.pt' exists in your root directory)
6
+ model = YOLO("solar_fault.pt")
 
7
 
8
+ def detect_faults(image, threshold=0.3):
9
+ # Convert image to numpy array
10
+ results = model.predict(np.array(image), conf=threshold)
11
+ detections = []
12
 
13
+ for r in results:
14
+ for box in r.boxes:
15
+ label_id = int(box.cls[0])
16
+ confidence = float(box.conf[0])
17
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
18
+ label = model.names[label_id]
19
+ detections.append((label, confidence, (x1, y1, x2, y2)))
20
 
21
+ return detections
 
 
22
 
23
+ CUSTOM_CLASSES = ["crack", "burn", "hotspot", "dust"]