PrashanthB461 commited on
Commit
8050a10
·
verified ·
1 Parent(s): d3125e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -3
app.py CHANGED
@@ -18,7 +18,7 @@ from retrying import retry
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": "yolov8n.pt", # Force lightweight Nano model
22
  "OUTPUT_DIR": "static/output",
23
  "VIOLATION_LABELS": {
24
  0: "no_helmet",
@@ -33,8 +33,9 @@ CONFIG = {
33
  "domain": "login"
34
  },
35
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
36
- "FRAME_SKIP": 15, # Increased to reduce frames processed
37
- "MAX_PROCESSING_TIME": 25 # Cap video processing at 25s to leave time for reporting
 
38
  }
39
 
40
  # Setup logging
@@ -57,6 +58,7 @@ def load_model():
57
  try:
58
  model = YOLO(CONFIG["MODEL_PATH"]).to(device)
59
  logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
 
60
  return model
61
  except Exception as e:
62
  logger.error(f"Failed to load model: {e}")
@@ -250,8 +252,16 @@ def process_video(video_data):
250
  for box in result.boxes:
251
  cls, conf = int(box.cls), float(box.conf)
252
  label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
 
 
 
 
 
253
  if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
254
  continue
 
 
 
255
  if label in seen_violations:
256
  continue
257
  seen_violations.add(label)
 
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8n.pt", # Lightweight model, ensure trained for violations
22
  "OUTPUT_DIR": "static/output",
23
  "VIOLATION_LABELS": {
24
  0: "no_helmet",
 
33
  "domain": "login"
34
  },
35
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
36
+ "FRAME_SKIP": 15, # Process every 15th frame
37
+ "MAX_PROCESSING_TIME": 25, # Cap video processing at 25s
38
+ "CONFIDENCE_THRESHOLD": 0.5 # Minimum confidence for violation detection
39
  }
40
 
41
  # Setup logging
 
58
  try:
59
  model = YOLO(CONFIG["MODEL_PATH"]).to(device)
60
  logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
61
+ logger.warning("Ensure yolov8n.pt is trained to detect only 'no_helmet', 'no_harness', 'unsafe_posture'. Use custom yolov8_safety.pt if needed.")
62
  return model
63
  except Exception as e:
64
  logger.error(f"Failed to load model: {e}")
 
252
  for box in result.boxes:
253
  cls, conf = int(box.cls), float(box.conf)
254
  label = CONFIG["VIOLATION_LABELS"].get(cls, f"class_{cls}")
255
+ # Log unexpected classes for debugging
256
+ if label not in ["no_helmet", "no_harness", "unsafe_posture", "unsafe_zone"]:
257
+ logger.warning(f"Unexpected class detected: {label} (cls: {cls}, conf: {conf})")
258
+ continue
259
+ # Only process specified violations with sufficient confidence
260
  if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
261
  continue
262
+ if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
263
+ logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
264
+ continue
265
  if label in seen_violations:
266
  continue
267
  seen_violations.add(label)