# Solar_Panel/modules/thermal_fault_detection.py
# Last commit: ac00aa0 "Update modules/thermal_fault_detection.py" by Tigernawin (verified)
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw, ImageFont
def load_model():
    """Return a COCO-pretrained Mask R-CNN (ResNet-50 FPN) in inference mode.

    NOTE(review): ``pretrained=True`` is deprecated in newer torchvision
    releases in favour of ``weights=...``; it is kept here so the module still
    works with older torchvision versions.
    """
    # nn.Module.eval() switches off training-only behaviour and returns the
    # module itself, so construction and mode switch can be chained.
    return maskrcnn_resnet50_fpn(pretrained=True).eval()


# Built once at import time and used by detect_faults(); importing this module
# therefore loads the pretrained weights (torchvision fetches them if they are
# not already cached).
model = load_model()
# Fault type -> COCO class names whose detections are reported as that fault.
# get_fault_label() scans this in insertion order and returns the first match.
# NOTE(review): these COCO classes look like placeholders; a COCO-pretrained
# detector has no notion of solar-panel faults.
CLASS_MAPPING = dict(
    overheat=["person", "fire hydrant"],
    dust=["bird", "sheep"],
    breakage=["bench", "truck"],
)
# COCO category id -> class name, following torchvision's 91-slot COCO indexing
# (0 is background and some ids are unused "N/A" slots, so ids are NOT the
# 80-class contiguous indices). Only the classes referenced by CLASS_MAPPING
# are listed; any other id is reported as "class_<id>" by detect_faults().
COCO_LABELS = {
    1: "person",
    8: "truck",
    11: "fire hydrant",  # id 10 is "traffic light"
    15: "bench",  # id 14 is "parking meter"
    16: "bird",
    20: "sheep",
    # Add more if needed -- ids must match torchvision's COCO category list.
}
def get_fault_label(coco_class_name: str, *, mapping=None):
    """Return the capitalized fault type that a COCO class name maps to.

    Args:
        coco_class_name: COCO class name as produced by the detector,
            e.g. ``"person"``.
        mapping: Optional table of fault type -> collection of COCO class
            names. Defaults to the module-level ``CLASS_MAPPING``.

    Returns:
        The first fault type (in the mapping's iteration order) whose class
        collection contains ``coco_class_name``, passed through
        ``str.capitalize()``; ``None`` if no fault type matches.
    """
    if mapping is None:
        mapping = CLASS_MAPPING
    return next(
        (
            fault_type.capitalize()
            for fault_type, classes in mapping.items()
            if coco_class_name in classes
        ),
        None,
    )
def detect_faults(image: Image.Image, threshold: float = 0.7):
    """Detect objects in *image* and annotate those that map to a fault type.

    Args:
        image: Input PIL image. An RGB image is annotated in place and the
            same object is returned. Any other mode (e.g. single-channel "L"
            thermal images, "RGBA", palette "P") is first converted to an RGB
            copy, and that copy is annotated and returned.
        threshold: Minimum detection score; detections scoring below it are
            ignored.

    Returns:
        ``(results, annotated_image)`` where ``results`` is a list of
        ``(fault_type, score)`` tuples -- ``fault_type`` as returned by
        ``get_fault_label`` and ``score`` a Python float -- in the order the
        model returned the detections.
    """
    # The detector expects 3-channel input: to_tensor() on RGBA gives 4
    # channels, which fails normalization, and red annotations would not be
    # visible on a grayscale image. convert() on an RGB image returns a copy,
    # so skip it there to keep annotating the caller's image in place.
    if image.mode != "RGB":
        image = image.convert("RGB")

    # Send the input to wherever the module-level model lives, in case it has
    # been moved off the CPU.
    device = next(model.parameters()).device
    image_tensor = F.to_tensor(image).unsqueeze(0).to(device)  # batch of one
    with torch.no_grad():
        outputs = model(image_tensor)[0]  # result for the single image

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    results = []
    for score, label, box in zip(outputs["scores"], outputs["labels"], outputs["boxes"]):
        score_value = score.item()
        if score_value < threshold:
            continue
        label_id = label.item()
        class_name = COCO_LABELS.get(label_id, f"class_{label_id}")
        fault_type = get_fault_label(class_name)
        if not fault_type:
            continue  # detected object is not mapped to any fault type
        results.append((fault_type, score_value))

        x0, y0, x1, y1 = box.tolist()
        draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
        # Clamp the caption so it stays visible when the box touches the top edge.
        draw.text(
            (x0, max(0, y0 - 10)),
            f"{fault_type} ({score_value:.2f})",
            fill="red",
            font=font,
        )
    return results, image