Spaces:
Sleeping
Sleeping
File size: 1,814 Bytes
d287691 ec20e0f d287691 ac00aa0 ec20e0f ac00aa0 ec20e0f d287691 ec20e0f ac00aa0 ec20e0f d287691 ac00aa0 ec20e0f ac00aa0 ec20e0f d287691 ac00aa0 ec20e0f ac00aa0 ec20e0f ac00aa0 ec20e0f ac00aa0 ec20e0f ac00aa0 ec20e0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw, ImageFont
# Load pretrained COCO model
def load_model():
    """Build a COCO-pretrained Mask R-CNN (ResNet-50 FPN) in inference mode.

    Returns:
        The torchvision detection model with ``eval()`` already applied.
    """
    try:
        # torchvision >= 0.13 replaces the deprecated ``pretrained`` flag with
        # a weights enum; COCO_V1 is the checkpoint ``pretrained=True`` loaded.
        from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights
    except ImportError:
        # Older torchvision has no weights enum: keep the legacy call.
        model = maskrcnn_resnet50_fpn(pretrained=True)
    else:
        model = maskrcnn_resnet50_fpn(
            weights=MaskRCNN_ResNet50_FPN_Weights.COCO_V1
        )
    model.eval()
    return model
# Shared model instance, built once when this module is imported and used by
# detect_faults() below.
model = load_model()
# Fault type -> COCO class names that are reported as that fault.
CLASS_MAPPING = dict(
    overheat=["person", "fire hydrant"],
    dust=["bird", "sheep"],
    breakage=["bench", "truck"],
)
# COCO category id -> class name, using the ids the pretrained detection model
# emits in its "labels" output. Only the classes referenced by CLASS_MAPPING
# are listed. In this id space 10 is "traffic light" and 14 is "parking meter",
# so "fire hydrant" is 11 and "bench" is 15.
COCO_LABELS = {
    1: "person",
    8: "truck",
    11: "fire hydrant",
    15: "bench",
    16: "bird",
    20: "sheep",
    # Add more if needed
}
def get_fault_label(coco_class_name: str):
    """Translate a COCO class name into a capitalized fault type.

    Returns the first CLASS_MAPPING key whose class list contains
    ``coco_class_name`` (capitalized, e.g. ``"Dust"``), or ``None`` when no
    fault type lists that class.
    """
    matching_faults = (
        fault_name.capitalize()
        for fault_name, member_classes in CLASS_MAPPING.items()
        if coco_class_name in member_classes
    )
    return next(matching_faults, None)
def detect_faults(image: Image.Image, threshold: float = 0.7):
    """Detect mapped fault types in ``image`` and annotate it in place.

    Runs the module-level ``model`` on the image, keeps detections whose score
    is at least ``threshold`` and whose COCO class maps to a fault type, and
    draws a red box with a ``"<Fault> (<score>)"`` caption for each of them.

    Args:
        image: PIL image to analyse. It is drawn on directly, so the caller's
            object is modified.
        threshold: Minimum detection score; lower-scoring detections are
            skipped.

    Returns:
        A ``(results, image)`` tuple where ``results`` is a list of
        ``(fault_type, score)`` pairs and ``image`` is the same object that
        was passed in, now annotated.
    """
    # to_tensor keeps the source channel count, so RGBA, greyscale or palette
    # images would reach the 3-channel model with the wrong shape. Feed it an
    # RGB conversion; annotations are still drawn on the caller's image.
    image_tensor = F.to_tensor(image.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        outputs = model(image_tensor)[0]  # single image in, single result out

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    results = []
    for score, label, box in zip(
        outputs["scores"], outputs["labels"], outputs["boxes"]
    ):
        confidence = score.item()  # plain float, reused for results and caption
        if confidence < threshold:
            continue
        label_id = label.item()
        class_name = COCO_LABELS.get(label_id, f"class_{label_id}")
        fault_type = get_fault_label(class_name)
        if not fault_type:
            continue  # detected class is not mapped to any fault type
        results.append((fault_type, confidence))
        # Draw bounding box
        x0, y0, x1, y1 = box.tolist()
        draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
        # Keep the caption inside the image when the box touches the top edge.
        caption_y = max(y0 - 10, 0)
        draw.text(
            (x0, caption_y),
            f"{fault_type} ({confidence:.2f})",
            fill="red",
            font=font,
        )
    return results, image
|