# Real_Time_Image_Captioning / safety_classifier.py
# Provenance: A7med-Ame3's Hugging Face repo — "Upload 7 files", commit 4fd9791 (verified)
"""
safety_classifier.py
────────────────────
Classifies a scene caption as SAFE or DANGEROUS using a curated
set of regular-expression patterns grouped by hazard category.
Pipeline step 3: Caption → Regex Engine → ClassificationResult
"""
import re
import logging
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
# ── Hazard pattern registry ───────────────────────────────────────────────────
# Each entry: (category_name, compiled_regex)
HAZARD_PATTERNS: list[tuple[str, re.Pattern]] = [
# Fire & heat
("fire", re.compile(r"\b(fire|flame|flames|burning|blaze|smoke|ember|inferno|wildfire|arson)\b", re.I)),
("heat", re.compile(r"\b(hot\s+surface|scalding|steam|boiling|molten)\b", re.I)),
# Water & weather
("flood", re.compile(r"\b(flood(ing|ed|s)?|flash\s+flood|inundation|submerged|overflow(ing)?)\b", re.I)),
("storm", re.compile(r"\b(storm|lightning|tornado|hurricane|cyclone|typhoon|hail|blizzard)\b", re.I)),
# Vehicles & traffic
("traffic", re.compile(r"\b(oncoming\s+(car|truck|vehicle|bus|motorcycle)|speeding\s+(car|vehicle)|near\s+collision)\b", re.I)),
("crash", re.compile(r"\b(crash|collision|accident|wreck(age)?|overturned\s+(car|truck|vehicle))\b", re.I)),
# Weapons & violence
("weapon", re.compile(r"\b(gun|pistol|rifle|shotgun|firearm|knife|blade|sword|machete|weapon|explosive|bomb|grenade)\b", re.I)),
("violence", re.compile(r"\b(fight(ing)?|brawl|riot|mob|attack(ing)?|assault|shooting|stabbing)\b", re.I)),
# Falls & heights
("fall", re.compile(r"\b(fall(ing|en)?|cliff|ledge|precipice|drop\s+(off|down)|steep\s+(slope|drop)|scaffolding)\b", re.I)),
("collapse", re.compile(r"\b(collaps(ing|ed)|rubble|debris|structural\s+(failure|damage)|cave(-)in)\b", re.I)),
# Electricity
("electrical", re.compile(r"\b(exposed\s+(wire|cable)|live\s+wire|electr(ic|ical)\s+hazard|power\s+line|sparking)\b", re.I)),
# Blood / injury
("injury", re.compile(r"\b(blood|bleeding|wound(ed)?|injur(y|ied|ies)|unconscious|laceration|trauma)\b", re.I)),
# Slips & construction
("slip", re.compile(r"\b(wet\s+floor|slippery|icy\s+(road|surface|path)|black\s+ice)\b", re.I)),
("construction",re.compile(r"\b(construction\s+zone|heavy\s+machinery|crane|excavator|unsafe\s+structure)\b", re.I)),
# Chemical & biological
("chemical", re.compile(r"\b(chemical\s+(spill|leak)|toxic|hazardous\s+material|biohazard|gas\s+leak|fumes?)\b", re.I)),
# Crowd / panic
("crowd", re.compile(r"\b(stampede|crowd\s+crush|panic(king)?|evacuation|emergency\s+exit)\b", re.I)),
# General danger keywords
("generic", re.compile(r"\b(danger(ous)?|hazard(ous)?|warning|caution|emergency|critical\s+risk|life-threatening)\b", re.I)),
]
# ── Result dataclass ──────────────────────────────────────────────────────────
@dataclass
class ClassificationResult:
    """Outcome of running the hazard patterns over one caption.

    Attributes
    ----------
    label : str
        Either "SAFE" or "DANGEROUS".
    hazards : list[str]
        Hazard category names that matched (e.g. "fire", "flood").
    matches : list[str]
        Raw caption tokens that triggered the match.
    """
    label: str
    hazards: list[str] = field(default_factory=list)
    matches: list[str] = field(default_factory=list)

    @property
    def is_dangerous(self) -> bool:
        """True exactly when the label is "DANGEROUS"."""
        return self.label == "DANGEROUS"

    def __str__(self) -> str:
        # Guard clause: the safe case needs no detail.
        if not self.is_dangerous:
            return "[SAFE] No hazards detected."
        cats = ", ".join(self.hazards)
        toks = ", ".join(self.matches)
        return f"[DANGEROUS] Categories: {cats} | Tokens: {toks}"
# ── Classifier ────────────────────────────────────────────────────────────────
class SafetyClassifier:
    """
    Applies all HAZARD_PATTERNS to a caption string and returns a
    ClassificationResult (pipeline step 3: Caption → Regex Engine → result).
    """

    def __init__(self, patterns: list[tuple[str, re.Pattern]] | None = None):
        """
        Parameters
        ----------
        patterns : list[tuple[str, re.Pattern]] | None
            (category_name, compiled_regex) pairs.  None (the default) means
            "use the module-level HAZARD_PATTERNS registry"; resolving it
            lazily here keeps the signature backward compatible while
            avoiding evaluating the shared default at definition time.
        """
        self.patterns = HAZARD_PATTERNS if patterns is None else patterns
        # Lazy %-style args so formatting is skipped when INFO is disabled.
        logger.info("SafetyClassifier initialised — %d hazard patterns loaded.", len(self.patterns))

    def classify(self, caption: str) -> ClassificationResult:
        """
        Classify a caption string.

        Parameters
        ----------
        caption : str
            Plain-text scene description produced by the captioning model.

        Returns
        -------
        ClassificationResult
            label is "DANGEROUS" when at least one pattern matched, else "SAFE".
        """
        # An empty / whitespace-only caption cannot contain hazards.
        if not caption or not caption.strip():
            return ClassificationResult(label="SAFE")

        matched_categories: list[str] = []
        matched_tokens: list[str] = []
        for category, pattern in self.patterns:
            # BUGFIX: use finditer + group(0) to always get the full matched
            # phrase.  findall on a pattern with capturing groups returns the
            # group fragments instead — e.g. r"\b(flood(ing)?)\b" matching
            # "flooding" yields ("flooding", "ing"), which the old flattening
            # glued into the bogus token "flooding ing".
            hits = [m.group(0).strip() for m in pattern.finditer(caption)]
            if hits:
                matched_categories.append(category)
                for token in hits:
                    # Dedupe on the already-stripped token (the old code
                    # compared unstripped but stored stripped values).
                    if token and token not in matched_tokens:
                        matched_tokens.append(token)

        label = "DANGEROUS" if matched_categories else "SAFE"
        result = ClassificationResult(
            label=label,
            hazards=list(dict.fromkeys(matched_categories)),  # preserve order, dedupe
            matches=matched_tokens,
        )
        logger.debug("%s", result)
        return result

    def explain(self, caption: str) -> dict[str, list[str]]:
        """
        Return a detailed breakdown of which patterns fired and why.

        Maps each matched category name to the list of full matched phrases.
        Useful for debugging / transparency.
        """
        breakdown: dict[str, list[str]] = {}
        for category, pattern in self.patterns:
            # Same finditer/group(0) fix as classify(): whole phrases, not
            # capture-group fragments.
            hits = [m.group(0) for m in pattern.finditer(caption)]
            if hits:
                breakdown[category] = hits
        return breakdown