File size: 1,814 Bytes
d287691
ec20e0f
 
 
d287691
ac00aa0
ec20e0f
ac00aa0
ec20e0f
 
d287691
ec20e0f
 
ac00aa0
ec20e0f
d287691
 
 
 
 
ac00aa0
 
 
 
 
 
 
 
 
 
 
ec20e0f
ac00aa0
 
 
ec20e0f
 
d287691
ac00aa0
ec20e0f
 
ac00aa0
ec20e0f
 
 
 
 
ac00aa0
ec20e0f
 
 
 
 
 
 
 
 
 
ac00aa0
ec20e0f
ac00aa0
 
ec20e0f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw, ImageFont

# Load pretrained COCO model
def load_model():
    """Return a pretrained Mask R-CNN (ResNet-50 FPN) switched to inference mode.

    ``Module.eval()`` returns the module itself, so the freshly built network is
    put into eval mode and handed back in a single expression.
    """
    # NOTE(review): `pretrained=True` is torchvision's legacy argument; newer
    # releases prefer `weights=...` and emit a deprecation warning for this
    # form — confirm the target torchvision version before migrating.
    return maskrcnn_resnet50_fpn(pretrained=True).eval()


# Built once at import time; detect_faults() reads this module-level name.
model = load_model()

# Map each fault type to the COCO class names whose detections are reported as
# that fault. NOTE(review): these look like stand-in proxies for a demo — a
# COCO-trained model does not detect real faults; confirm before relying on it.
CLASS_MAPPING = {
    "overheat": ["person", "fire hydrant"],
    "dust": ["bird", "sheep"],
    "breakage": ["bench", "truck"],
}

# COCO category id -> class name, covering the classes used in CLASS_MAPPING.
# Ids follow the original (gapped) COCO numbering that torchvision's detection
# models emit in outputs["labels"]: e.g. 10 is "traffic light" and 11 is
# "fire hydrant"; 14 is "parking meter" and 15 is "bench".
COCO_LABELS = {
    1: "person",
    11: "fire hydrant",
    16: "bird",
    20: "sheep",
    15: "bench",
    8: "truck",
    # Add more if needed
}

def get_fault_label(coco_class_name: str):
    """Return the capitalized fault type that *coco_class_name* maps to.

    CLASS_MAPPING is scanned in insertion order and the first fault type whose
    class list contains the name wins. Returns ``None`` when no fault type
    lists the class.
    """
    candidates = (
        fault.capitalize()
        for fault, members in CLASS_MAPPING.items()
        if coco_class_name in members
    )
    return next(candidates, None)

def detect_faults(image: Image.Image, threshold: float = 0.7):
    """Detect objects in *image*, keep those mapped to a fault type, and annotate them.

    Args:
        image: Input PIL image. Boxes and captions are drawn onto this object
            **in place**; the same object is also returned.
        threshold: Minimum detection score; lower-scoring detections are skipped.

    Returns:
        ``(results, image)`` where ``results`` is a list of
        ``(fault_type, score)`` pairs (score as a Python float) for each kept
        detection, in the order the model produced them, and ``image`` is the
        annotated input image.
    """
    # The detector is fed a 3-channel tensor; RGBA/palette images would
    # otherwise produce a tensor with the wrong channel count and fail inside
    # the model. Only the model's copy is converted, so drawing below still
    # happens on the caller's image.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    image_tensor = F.to_tensor(rgb_image).unsqueeze(0)  # batch of one image

    with torch.no_grad():
        outputs = model(image_tensor)[0]  # result for the single image in the batch

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    results = []

    for score, label, box in zip(outputs["scores"], outputs["labels"], outputs["boxes"]):
        if score < threshold:
            continue

        label_id = label.item()
        class_name = COCO_LABELS.get(label_id, f"class_{label_id}")
        fault_type = get_fault_label(class_name)
        if not fault_type:
            continue

        score_value = score.item()
        results.append((fault_type, score_value))

        # Draw bounding box
        x0, y0, x1, y1 = box.tolist()
        draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
        # Caption sits just above the box; clamp so boxes touching the top
        # edge do not push the text off-image.
        draw.text(
            (x0, max(y0 - 10, 0)),
            f"{fault_type} ({score_value:.2f})",
            fill="red",
            font=font,
        )

    return results, image