Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
|
|
| 35 |
# ByteTrack Implementation
|
| 36 |
# ==========================
|
| 37 |
class BYTETracker:
|
| 38 |
-
"""
|
| 39 |
def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
|
| 40 |
self.track_thresh = track_thresh
|
| 41 |
self.track_buffer = track_buffer
|
|
@@ -112,12 +112,7 @@ CONFIG = {
|
|
| 112 |
"MATCH_THRESH": 0.8
|
| 113 |
}
|
| 114 |
|
| 115 |
-
#
|
| 116 |
-
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
|
| 117 |
-
logger = logging.getLogger(__name__)
|
| 118 |
-
|
| 119 |
-
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)
|
| 120 |
-
|
| 121 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 122 |
logger.info(f"Using device: {device}")
|
| 123 |
|
|
@@ -140,12 +135,8 @@ def load_model():
|
|
| 140 |
|
| 141 |
model = load_model()
|
| 142 |
|
| 143 |
-
# [Rest of your existing functions remain exactly the same...]
|
| 144 |
-
# draw_detections(), calculate_safety_score(), generate_violation_pdf(),
|
| 145 |
-
# connect_to_salesforce(), upload_pdf_to_salesforce(), push_report_to_salesforce(),
|
| 146 |
-
# process_video(), and gradio_interface() functions should be kept exactly as they were
|
| 147 |
# ==========================
|
| 148 |
-
#
|
| 149 |
# ==========================
|
| 150 |
def draw_detections(frame, detections):
|
| 151 |
for det in detections:
|
|
@@ -226,9 +217,6 @@ def generate_violation_pdf(violations, score):
|
|
| 226 |
logger.error(f"Error generating PDF: {e}")
|
| 227 |
return "", "", None
|
| 228 |
|
| 229 |
-
# ==========================
|
| 230 |
-
# Salesforce Integration
|
| 231 |
-
# ==========================
|
| 232 |
@retry(stop_max_attempt_number=3, wait_fixed=2000)
|
| 233 |
def connect_to_salesforce():
|
| 234 |
try:
|
|
@@ -307,9 +295,6 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
|
|
| 307 |
logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
|
| 308 |
return None, ""
|
| 309 |
|
| 310 |
-
# ==========================
|
| 311 |
-
# Fast Video Processing
|
| 312 |
-
# ==========================
|
| 313 |
def process_video(video_data):
|
| 314 |
try:
|
| 315 |
# Create temp video file
|
|
@@ -412,8 +397,8 @@ def process_video(video_data):
|
|
| 412 |
|
| 413 |
# Process tracked objects
|
| 414 |
for obj, track_input in zip(tracked_objects, track_inputs):
|
| 415 |
-
worker_id = obj
|
| 416 |
-
label = CONFIG["VIOLATION_LABELS"].get(int(obj
|
| 417 |
bbox = track_input["bbox"]
|
| 418 |
conf = track_input["conf"]
|
| 419 |
|
|
@@ -499,55 +484,6 @@ def process_video(video_data):
|
|
| 499 |
logger.error(f"Error processing video: {e}", exc_info=True)
|
| 500 |
yield f"Error processing video: {e}", "", "", "", ""
|
| 501 |
|
| 502 |
-
|
| 503 |
-
# Initialize device and model
|
| 504 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 505 |
-
logger.info(f"Using device: {device}")
|
| 506 |
-
|
| 507 |
-
def load_model():
|
| 508 |
-
try:
|
| 509 |
-
if os.path.isfile(CONFIG["MODEL_PATH"]):
|
| 510 |
-
model_path = CONFIG["MODEL_PATH"]
|
| 511 |
-
logger.info(f"Model loaded: {model_path}")
|
| 512 |
-
else:
|
| 513 |
-
model_path = CONFIG["FALLBACK_MODEL"]
|
| 514 |
-
logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
|
| 515 |
-
if not os.path.isfile(model_path):
|
| 516 |
-
logger.info(f"Downloading fallback model: {model_path}")
|
| 517 |
-
torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
|
| 518 |
-
model = YOLO(model_path).to(device)
|
| 519 |
-
return model
|
| 520 |
-
except Exception as e:
|
| 521 |
-
logger.error(f"Failed to load model: {e}")
|
| 522 |
-
raise
|
| 523 |
-
|
| 524 |
-
model = load_model()
|
| 525 |
-
|
| 526 |
-
# ==========================
|
| 527 |
-
# Helper Functions
|
| 528 |
-
# ==========================
|
| 529 |
-
def draw_detections(frame, detections):
|
| 530 |
-
# ... [your existing implementation] ...
|
| 531 |
-
|
| 532 |
-
def calculate_safety_score(violations):
|
| 533 |
-
# ... [your existing implementation] ...
|
| 534 |
-
|
| 535 |
-
def generate_violation_pdf(violations, score):
|
| 536 |
-
# ... [your existing implementation] ...
|
| 537 |
-
|
| 538 |
-
@retry(stop_max_attempt_number=3, wait_fixed=2000)
|
| 539 |
-
def connect_to_salesforce():
|
| 540 |
-
# ... [your existing implementation] ...
|
| 541 |
-
|
| 542 |
-
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
|
| 543 |
-
# ... [your existing implementation] ...
|
| 544 |
-
|
| 545 |
-
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
|
| 546 |
-
# ... [your existing implementation] ...
|
| 547 |
-
|
| 548 |
-
def process_video(video_data):
|
| 549 |
-
# ... [your existing implementation] ...
|
| 550 |
-
|
| 551 |
def gradio_interface(video_file):
|
| 552 |
if not video_file:
|
| 553 |
return "No file uploaded.", "", "No file uploaded.", "", ""
|
|
|
|
| 35 |
# ByteTrack Implementation
|
| 36 |
# ==========================
|
| 37 |
class BYTETracker:
|
| 38 |
+
"""Robust ByteTrack implementation with fallback"""
|
| 39 |
def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
|
| 40 |
self.track_thresh = track_thresh
|
| 41 |
self.track_buffer = track_buffer
|
|
|
|
| 112 |
"MATCH_THRESH": 0.8
|
| 113 |
}
|
| 114 |
|
| 115 |
+
# Initialize device and model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 117 |
logger.info(f"Using device: {device}")
|
| 118 |
|
|
|
|
| 135 |
|
| 136 |
model = load_model()
|
| 137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
# ==========================
|
| 139 |
+
# Helper Functions
|
| 140 |
# ==========================
|
| 141 |
def draw_detections(frame, detections):
|
| 142 |
for det in detections:
|
|
|
|
| 217 |
logger.error(f"Error generating PDF: {e}")
|
| 218 |
return "", "", None
|
| 219 |
|
|
|
|
|
|
|
|
|
|
| 220 |
@retry(stop_max_attempt_number=3, wait_fixed=2000)
|
| 221 |
def connect_to_salesforce():
|
| 222 |
try:
|
|
|
|
| 295 |
logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
|
| 296 |
return None, ""
|
| 297 |
|
|
|
|
|
|
|
|
|
|
| 298 |
def process_video(video_data):
|
| 299 |
try:
|
| 300 |
# Create temp video file
|
|
|
|
| 397 |
|
| 398 |
# Process tracked objects
|
| 399 |
for obj, track_input in zip(tracked_objects, track_inputs):
|
| 400 |
+
worker_id = obj['id']
|
| 401 |
+
label = CONFIG["VIOLATION_LABELS"].get(int(obj['cls']), None)
|
| 402 |
bbox = track_input["bbox"]
|
| 403 |
conf = track_input["conf"]
|
| 404 |
|
|
|
|
| 484 |
logger.error(f"Error processing video: {e}", exc_info=True)
|
| 485 |
yield f"Error processing video: {e}", "", "", "", ""
|
| 486 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 487 |
def gradio_interface(video_file):
|
| 488 |
if not video_file:
|
| 489 |
return "No file uploaded.", "", "No file uploaded.", "", ""
|