File size: 6,073 Bytes
4fd9791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
safety_classifier.py
────────────────────
Classifies a scene caption as SAFE or DANGEROUS using a curated
set of regular-expression patterns grouped by hazard category.

Pipeline step 3: Caption → Regex Engine → ClassificationResult
"""

import re
import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


# ── Hazard pattern registry ───────────────────────────────────────────────────
# Each entry: (category_name, compiled_regex). Patterns are applied in list
# order and all are case-insensitive. Every group is non-capturing (?:...) so
# that re.findall() / Match.group(0) yield the complete matched phrase rather
# than fragments of inner alternations.
HAZARD_PATTERNS: list[tuple[str, re.Pattern]] = [

    # Fire & heat
    ("fire",        re.compile(r"\b(?:fire|flames?|burning|blaze|smoke|embers?|inferno|wildfire|arson)\b", re.I)),
    ("heat",        re.compile(r"\b(?:hot\s+surface|scalding|steam|boiling|molten)\b", re.I)),

    # Water & weather
    ("flood",       re.compile(r"\b(?:flood(?:ing|ed|s)?|flash\s+floods?|inundation|submerged|overflow(?:ing)?)\b", re.I)),
    ("storm",       re.compile(r"\b(?:storm|lightning|tornado|hurricane|cyclone|typhoon|hail|blizzard)\b", re.I)),

    # Vehicles & traffic
    ("traffic",     re.compile(r"\b(?:oncoming\s+(?:cars?|trucks?|vehicles?|bus(?:es)?|motorcycles?)|speeding\s+(?:cars?|vehicles?)|near\s+collision)\b", re.I)),
    ("crash",       re.compile(r"\b(?:crash|collision|accident|wreck(?:age)?|overturned\s+(?:car|truck|vehicle))\b", re.I)),

    # Weapons & violence
    ("weapon",      re.compile(r"\b(?:guns?|pistols?|rifles?|shotguns?|firearms?|knife|knives|blades?|swords?|machetes?|weapons?|explosives?|bombs?|grenades?)\b", re.I)),
    ("violence",    re.compile(r"\b(?:fight(?:ing)?|brawl|riot|mob|attack(?:ing)?|assault|shooting|stabbing)\b", re.I)),

    # Falls & heights
    ("fall",        re.compile(r"\b(?:fall(?:ing|en|s)?|cliff|ledge|precipice|drop\s+(?:off|down)|steep\s+(?:slope|drop)|scaffolding)\b", re.I)),
    # "cave in" and "cave-in" are both accepted; "collapse"/"collapses" were previously missed.
    ("collapse",    re.compile(r"\b(?:collaps(?:e|es|ed|ing)|rubble|debris|structural\s+(?:failure|damage)|cave[-\s]in)\b", re.I)),

    # Electricity
    ("electrical",  re.compile(r"\b(?:exposed\s+(?:wires?|cables?)|live\s+wires?|electr(?:ic|ical)\s+hazard|power\s+lines?|sparking)\b", re.I)),

    # Blood / injury  ("injured" was previously unmatched: the old alternation had "ied")
    ("injury",      re.compile(r"\b(?:blood|bleeding|wound(?:s|ed)?|injur(?:y|ies|ed)|unconscious|laceration|trauma)\b", re.I)),

    # Slips & construction
    ("slip",        re.compile(r"\b(?:wet\s+floor|slippery|icy\s+(?:road|surface|path)|black\s+ice)\b", re.I)),
    ("construction",re.compile(r"\b(?:construction\s+zone|heavy\s+machinery|crane|excavator|unsafe\s+structure)\b", re.I)),

    # Chemical & biological
    ("chemical",    re.compile(r"\b(?:chemical\s+(?:spill|leak)|toxic|hazardous\s+material|biohazard|gas\s+leak|fumes?)\b", re.I)),

    # Crowd / panic
    ("crowd",       re.compile(r"\b(?:stampede|crowd\s+crush|panic(?:king)?|evacuation|emergency\s+exit)\b", re.I)),

    # General danger keywords
    ("generic",     re.compile(r"\b(?:danger(?:ous)?|hazard(?:ous)?|warning|caution|emergency|critical\s+risk|life-threatening)\b", re.I)),
]


# ── Result dataclass ──────────────────────────────────────────────────────────

@dataclass
class ClassificationResult:
    """
    Outcome of classifying a single caption.

    ``label`` is expected to be "SAFE" or "DANGEROUS"; only the exact string
    "DANGEROUS" makes ``is_dangerous`` true.
    """

    label: str                                              # "SAFE" or "DANGEROUS"
    hazards: list[str] = field(default_factory=list)        # matched categories
    matches: list[str] = field(default_factory=list)        # raw matched tokens

    @property
    def is_dangerous(self) -> bool:
        """Whether this result carries the "DANGEROUS" label."""
        return "DANGEROUS" == self.label

    def __str__(self) -> str:
        # Anything not labelled DANGEROUS renders as the fixed SAFE message.
        if not self.is_dangerous:
            return "[SAFE] No hazards detected."
        category_list = ", ".join(self.hazards)
        token_list = ", ".join(self.matches)
        return f"[DANGEROUS] Categories: {category_list} | Tokens: {token_list}"


# ── Classifier ────────────────────────────────────────────────────────────────

class SafetyClassifier:
    """
    Applies a list of (category, compiled regex) hazard patterns to a caption
    string and returns a ClassificationResult.

    Tokens are always reported as the *whole* matched text (``Match.group(0)``),
    so patterns may contain capturing groups without producing fragmentary or
    concatenated tokens.
    """

    def __init__(self, patterns: list[tuple[str, re.Pattern]] = HAZARD_PATTERNS):
        """
        Parameters
        ----------
        patterns : list of (str, re.Pattern)
            Hazard patterns to apply, in order. Defaults to the module-level
            HAZARD_PATTERNS. The sequence is copied so that later mutation of
            the caller's list (or of the shared default) does not affect this
            instance.
        """
        self.patterns = list(patterns)
        logger.info("SafetyClassifier initialised — %d hazard patterns loaded.", len(self.patterns))

    @staticmethod
    def _matched_texts(pattern: re.Pattern, caption: str) -> list[str]:
        """Return the full text of every non-empty match of *pattern* in *caption*, in order."""
        # re.findall() returns the captured group(s) instead of the whole match
        # when the pattern has capturing groups; group(0) is always the full span.
        return [m.group(0) for m in pattern.finditer(caption) if m.group(0)]

    def classify(self, caption: str) -> ClassificationResult:
        """
        Classify a caption string.

        Parameters
        ----------
        caption : str
            Plain-text scene description produced by the captioning model.
            None, empty, or whitespace-only captions are classified SAFE.

        Returns
        -------
        ClassificationResult
            label is "DANGEROUS" if at least one pattern matched, else "SAFE".
            hazards lists the matching categories in pattern order (deduplicated).
            matches lists the stripped matched tokens, deduplicated exactly
            (case-sensitive), in first-seen order.
        """
        if not caption or not caption.strip():
            return ClassificationResult(label="SAFE")

        matched_categories: list[str] = []
        matched_tokens: list[str] = []
        seen_tokens: set[str] = set()

        for category, pattern in self.patterns:
            hits = self._matched_texts(pattern, caption)
            if not hits:
                continue
            matched_categories.append(category)
            for hit in hits:
                # Strip before the membership test so dedupe and storage agree.
                token = hit.strip()
                if token and token not in seen_tokens:
                    seen_tokens.add(token)
                    matched_tokens.append(token)

        result = ClassificationResult(
            label="DANGEROUS" if matched_categories else "SAFE",
            # A custom pattern list may repeat a category; keep first-seen order.
            hazards=list(dict.fromkeys(matched_categories)),
            matches=matched_tokens,
        )
        logger.debug("%s", result)  # lazy: __str__ only runs if DEBUG is enabled
        return result

    def explain(self, caption: str) -> dict:
        """
        Return a detailed breakdown of which patterns fired and on what text.
        Useful for debugging / transparency.

        Returns a dict mapping category -> list of every matched text, in match
        order. Tokens are not deduplicated. Hits from repeated categories are
        merged. A None, empty, or whitespace-only caption yields {}, mirroring
        classify() treating such captions as SAFE.
        """
        if not caption or not caption.strip():
            return {}

        breakdown: dict[str, list[str]] = {}
        for category, pattern in self.patterns:
            hits = self._matched_texts(pattern, caption)
            if hits:
                breakdown.setdefault(category, []).extend(hits)
        return breakdown