Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,7 +18,7 @@ from retrying import retry
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
-
"MODEL_PATH": "yolov8n.pt", #
|
| 22 |
"OUTPUT_DIR": "static/output",
|
| 23 |
"VIOLATION_LABELS": {
|
| 24 |
0: "no_helmet",
|
|
@@ -33,8 +33,9 @@ CONFIG = {
|
|
| 33 |
"domain": "login"
|
| 34 |
},
|
| 35 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
|
| 36 |
-
"FRAME_SKIP": 15, #
|
| 37 |
-
"MAX_PROCESSING_TIME": 25 # Cap video processing at 25s
|
|
|
|
| 38 |
}
|
| 39 |
|
| 40 |
# Setup logging
|
|
@@ -57,6 +58,7 @@ def load_model():
|
|
| 57 |
try:
|
| 58 |
model = YOLO(CONFIG["MODEL_PATH"]).to(device)
|
| 59 |
logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
|
|
|
|
| 60 |
return model
|
| 61 |
except Exception as e:
|
| 62 |
logger.error(f"Failed to load model: {e}")
|
|
@@ -250,8 +252,16 @@ def process_video(video_data):
|
|
| 250 |
for box in result.boxes:
|
| 251 |
cls, conf = int(box.cls), float(box.conf)
|
| 252 |
label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 254 |
continue
|
|
|
|
|
|
|
|
|
|
| 255 |
if label in seen_violations:
|
| 256 |
continue
|
| 257 |
seen_violations.add(label)
|
|
|
|
| 18 |
# Configuration
|
| 19 |
# ==========================
|
| 20 |
CONFIG = {
|
| 21 |
+
"MODEL_PATH": "yolov8n.pt", # Lightweight model, ensure trained for violations
|
| 22 |
"OUTPUT_DIR": "static/output",
|
| 23 |
"VIOLATION_LABELS": {
|
| 24 |
0: "no_helmet",
|
|
|
|
| 33 |
"domain": "login"
|
| 34 |
},
|
| 35 |
"PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
|
| 36 |
+
"FRAME_SKIP": 15, # Process every 15th frame
|
| 37 |
+
"MAX_PROCESSING_TIME": 25, # Cap video processing at 25s
|
| 38 |
+
"CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
|
| 39 |
}
|
| 40 |
|
| 41 |
# Setup logging
|
|
|
|
| 58 |
try:
|
| 59 |
model = YOLO(CONFIG["MODEL_PATH"]).to(device)
|
| 60 |
logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
|
| 61 |
+
logger.warning("Ensure yolov8n.pt is trained to detect only 'no_helmet', 'no_harness', 'unsafe_posture'. Use custom yolov8_safety.pt if needed.")
|
| 62 |
return model
|
| 63 |
except Exception as e:
|
| 64 |
logger.error(f"Failed to load model: {e}")
|
|
|
|
| 252 |
for box in result.boxes:
|
| 253 |
cls, conf = int(box.cls), float(box.conf)
|
| 254 |
label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
|
| 255 |
+
# Log unexpected classes for debugging
|
| 256 |
+
if label not in ["no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"]:
|
| 257 |
+
logger.warning(f"Unexpected class detected: {label} (cls: {cls}, conf: {conf})")
|
| 258 |
+
continue
|
| 259 |
+
# Only process specified violations with sufficient confidence
|
| 260 |
if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
|
| 261 |
continue
|
| 262 |
+
if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
|
| 263 |
+
logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
|
| 264 |
+
continue
|
| 265 |
if label in seen_violations:
|
| 266 |
continue
|
| 267 |
seen_violations.add(label)
|