"""Flag placeholder "fault" detections on images with a pretrained Mask R-CNN.

NOTE(review): the detector is torchvision's stock COCO Mask R-CNN and knows
nothing about faults. ``CLASS_MAPPING`` relabels a few ordinary COCO classes
as fault types, presumably as a stand-in until a model trained on real fault
classes is available.
"""

from typing import List, Optional, Tuple

import torch
from torchvision.models.detection import maskrcnn_resnet50_fpn
from torchvision.transforms import functional as F
from PIL import Image, ImageDraw, ImageFont


# Load pretrained COCO model
def load_model():
    """Return a COCO-pretrained Mask R-CNN (ResNet-50 FPN) switched to eval mode.

    ``pretrained=True`` is deprecated since torchvision 0.13 in favour of the
    ``weights=`` argument; it is kept so older torchvision releases still work.
    torchvision downloads the weights if they are not already cached.
    """
    model = maskrcnn_resnet50_fpn(pretrained=True)
    model.eval()  # eval mode: the detector returns boxes/labels/scores, not losses
    return model


# Loaded once at import time so every detect_faults() call reuses the same model.
model = load_model()

# Map COCO classes to your fault types
CLASS_MAPPING = {
    "overheat": ["person", "fire hydrant"],
    "dust": ["bird", "sheep"],
    "breakage": ["bench", "truck"],
}

# COCO id-to-name mapping. The ids are the 91-slot COCO category ids that
# torchvision's COCO detection models emit (0 is background; some ids are unused).
COCO_LABELS = {
    1: "person",
    11: "fire hydrant",  # not 10: id 10 is "traffic light"
    16: "bird",
    20: "sheep",
    15: "bench",  # not 14: id 14 is "parking meter"
    8: "truck",
    # Add more if needed
}


def get_fault_label(coco_class_name: str) -> Optional[str]:
    """Return the capitalized fault type whose class list contains ``coco_class_name``.

    Returns ``None`` when the class is not mapped to any fault type. If a class
    is listed under several fault types, the first one in ``CLASS_MAPPING`` wins.
    """
    for fault_type, classes in CLASS_MAPPING.items():
        if coco_class_name in classes:
            return fault_type.capitalize()
    return None


def detect_faults(
    image: Image.Image, threshold: float = 0.7
) -> Tuple[List[Tuple[str, float]], Image.Image]:
    """Run the detector on ``image`` and annotate detections that map to a fault.

    Args:
        image: Input PIL image. It is annotated IN PLACE: for every reported
            detection, a red box and a "<Fault> (<score>)" caption are drawn on it.
        threshold: Minimum detection score (inclusive) for a detection to be kept.

    Returns:
        ``(results, image)``. ``results`` is a list of ``(fault_type, score)``
        pairs, one per kept detection whose COCO class maps to a fault type.
        ``image`` is the same, now annotated, object that was passed in.
    """
    # Mask R-CNN needs 3-channel input. Convert a copy so that RGBA, greyscale
    # or palette images don't yield a 4- or 1-channel tensor. Drawing below
    # still happens on the caller's original image.
    rgb_image = image if image.mode == "RGB" else image.convert("RGB")
    # Put the input on whatever device the model currently lives on.
    device = next(model.parameters()).device
    image_tensor = F.to_tensor(rgb_image).unsqueeze(0).to(device)  # (1, 3, H, W) batch of one

    with torch.no_grad():
        outputs = model(image_tensor)[0]  # Get first (and only) result

    draw = ImageDraw.Draw(image)
    font = ImageFont.load_default()
    results = []

    for score, label, box in zip(outputs["scores"], outputs["labels"], outputs["boxes"]):
        # Compare as a tensor to keep the original float32 boundary behaviour.
        if score < threshold:
            continue
        label_id = label.item()
        class_name = COCO_LABELS.get(label_id, f"class_{label_id}")
        fault_type = get_fault_label(class_name)
        if not fault_type:
            continue

        score_value = score.item()
        results.append((fault_type, score_value))

        # Draw bounding box (torchvision boxes are x0, y0, x1, y1 in pixels).
        x0, y0, x1, y1 = box.tolist()
        draw.rectangle([x0, y0, x1, y1], outline="red", width=2)
        # Clamp the caption so it stays inside the image for boxes at the top edge.
        draw.text(
            (x0, max(0.0, y0 - 10)),
            f"{fault_type} ({score_value:.2f})",
            fill="red",
            font=font,
        )

    return results, image