"""
safety_classifier.py
────────────────────
Classifies a scene caption as SAFE or DANGEROUS using a curated set of
regular-expression patterns grouped by hazard category.

Pipeline step 3:
    Caption → Regex Engine → ClassificationResult
"""

import re
import logging
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


# ── Hazard pattern registry ───────────────────────────────────────────────────
# Each entry: (category_name, compiled_regex).
# Patterns may contain capturing groups purely to express alternations; the
# classifier always reports the whole match (Match.group(0)), so group
# contents never leak into the reported tokens.

HAZARD_PATTERNS: list[tuple[str, re.Pattern]] = [
    # Fire & heat
    ("fire", re.compile(r"\b(fire|flame|flames|burning|blaze|smoke|ember|inferno|wildfire|arson)\b", re.I)),
    ("heat", re.compile(r"\b(hot\s+surface|scalding|steam|boiling|molten)\b", re.I)),

    # Water & weather
    ("flood", re.compile(r"\b(flood(ing|ed|s)?|flash\s+flood|inundation|submerged|overflow(ing)?)\b", re.I)),
    ("storm", re.compile(r"\b(storm|lightning|tornado|hurricane|cyclone|typhoon|hail|blizzard)\b", re.I)),

    # Vehicles & traffic
    ("traffic", re.compile(r"\b(oncoming\s+(car|truck|vehicle|bus|motorcycle)|speeding\s+(car|vehicle)|near\s+collision)\b", re.I)),
    ("crash", re.compile(r"\b(crash|collision|accident|wreck(age)?|overturned\s+(car|truck|vehicle))\b", re.I)),

    # Weapons & violence
    ("weapon", re.compile(r"\b(gun|pistol|rifle|shotgun|firearm|knife|blade|sword|machete|weapon|explosive|bomb|grenade)\b", re.I)),
    ("violence", re.compile(r"\b(fight(ing)?|brawl|riot|mob|attack(ing)?|assault|shooting|stabbing)\b", re.I)),

    # Falls & heights
    ("fall", re.compile(r"\b(fall(ing|en)?|cliff|ledge|precipice|drop\s+(off|down)|steep\s+(slope|drop)|scaffolding)\b", re.I)),
    # "collapse"/"collapses" were previously unmatched (only -ing/-ed forms).
    ("collapse", re.compile(r"\b(collaps(e|es|ed|ing)|rubble|debris|structural\s+(failure|damage)|cave-in)\b", re.I)),

    # Electricity (plural forms such as "power lines" / "exposed wires" included)
    ("electrical", re.compile(r"\b(exposed\s+(wire|cable)s?|live\s+wires?|electr(ic|ical)\s+hazard|power\s+lines?|sparking)\b", re.I)),

    # Blood / injury ("injured" was previously unmatched: the alternation read "ied")
    ("injury", re.compile(r"\b(blood|bleeding|wound(ed)?|injur(y|ed|ies)|unconscious|laceration|trauma)\b", re.I)),

    # Slips & construction
    ("slip", re.compile(r"\b(wet\s+floor|slippery|icy\s+(road|surface|path)|black\s+ice)\b", re.I)),
    ("construction", re.compile(r"\b(construction\s+zone|heavy\s+machinery|crane|excavator|unsafe\s+structure)\b", re.I)),

    # Chemical & biological
    ("chemical", re.compile(r"\b(chemical\s+(spill|leak)|toxic|hazardous\s+material|biohazard|gas\s+leak|fumes?)\b", re.I)),

    # Crowd / panic
    ("crowd", re.compile(r"\b(stampede|crowd\s+crush|panic(king)?|evacuation|emergency\s+exit)\b", re.I)),

    # General danger keywords
    ("generic", re.compile(r"\b(danger(ous)?|hazard(ous)?|warning|caution|emergency|critical\s+risk|life-threatening)\b", re.I)),
]


# ── Result dataclass ──────────────────────────────────────────────────────────

@dataclass
class ClassificationResult:
    """Outcome of classifying one caption."""

    label: str                                        # "SAFE" or "DANGEROUS"
    hazards: list[str] = field(default_factory=list)  # matched categories, in pattern order
    matches: list[str] = field(default_factory=list)  # distinct matched text spans, in scan order

    @property
    def is_dangerous(self) -> bool:
        """True when the label is "DANGEROUS"."""
        return self.label == "DANGEROUS"

    def __str__(self) -> str:
        if self.is_dangerous:
            return f"[DANGEROUS] Categories: {', '.join(self.hazards)} | Tokens: {', '.join(self.matches)}"
        return "[SAFE] No hazards detected."


# ── Classifier ────────────────────────────────────────────────────────────────

class SafetyClassifier:
    """
    Applies a list of (category, compiled_regex) hazard patterns to a caption
    string and returns a ClassificationResult.
    """

    def __init__(self, patterns: list[tuple[str, re.Pattern]] = HAZARD_PATTERNS):
        # Copy the list so that mutating ``self.patterns`` on one instance never
        # alters the shared module-level HAZARD_PATTERNS (or the caller's list).
        self.patterns = list(patterns)
        logger.info(
            "SafetyClassifier initialised — %d hazard patterns loaded.",
            len(self.patterns),
        )

    def _scan(self, caption: str) -> list[tuple[str, list[str]]]:
        """
        Return ``(category, matched_texts)`` for every pattern that fires, in
        pattern order. ``matched_texts`` holds the full text of each match
        (``Match.group(0)``), including repeats. Falsy captions yield ``[]``.
        """
        if not caption:
            return []
        found: list[tuple[str, list[str]]] = []
        for category, pattern in self.patterns:
            # finditer + group(0) gives the whole matched span; findall would
            # instead return tuples of capture groups for grouped patterns.
            hits = [m.group(0) for m in pattern.finditer(caption)]
            if hits:
                found.append((category, hits))
        return found

    def classify(self, caption: str) -> ClassificationResult:
        """
        Classify a caption string.

        Parameters
        ----------
        caption : str
            Plain-text scene description produced by the captioning model.

        Returns
        -------
        ClassificationResult
            label "DANGEROUS" if any pattern matched, otherwise "SAFE".
            ``hazards`` lists the matched categories (deduplicated, pattern
            order). ``matches`` lists the distinct matched text spans, compared
            case-sensitively, in scan order.
        """
        if not caption or not caption.strip():
            return ClassificationResult(label="SAFE")

        matched_categories: list[str] = []
        matched_tokens: list[str] = []

        for category, hits in self._scan(caption):
            matched_categories.append(category)
            for hit in hits:
                token = hit.strip()
                # Dedupe on the same (stripped) form that gets stored.
                if token and token not in matched_tokens:
                    matched_tokens.append(token)

        result = ClassificationResult(
            label="DANGEROUS" if matched_categories else "SAFE",
            # Several patterns may share a category name; keep first-seen order.
            hazards=list(dict.fromkeys(matched_categories)),
            matches=matched_tokens,
        )
        logger.debug("%s", result)
        return result

    def explain(self, caption: str) -> dict[str, list[str]]:
        """
        Return a detailed breakdown of which patterns fired and what they matched.

        Maps each category that fired to the list of matched text spans
        (including repeats, in scan order). Returns ``{}`` for an empty or
        ``None`` caption. Useful for debugging / transparency.
        """
        breakdown: dict[str, list[str]] = {}
        for category, hits in self._scan(caption):
            # Merge rather than overwrite when patterns share a category name.
            breakdown.setdefault(category, []).extend(hits)
        return breakdown