Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -37,50 +37,28 @@ logger = logging.getLogger(__name__)
|
|
| 37 |
class BYTETracker:
|
| 38 |
"""Custom implementation of ByteTrack to avoid installation issues"""
|
| 39 |
def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
match_thresh=match_thresh,
|
| 46 |
-
frame_rate=frame_rate
|
| 47 |
-
)
|
| 48 |
-
self._original = True
|
| 49 |
-
except ImportError:
|
| 50 |
-
logger.warning("Using simplified ByteTrack implementation")
|
| 51 |
-
self._original = False
|
| 52 |
-
self.track_thresh = track_thresh
|
| 53 |
-
self.track_buffer = track_buffer
|
| 54 |
-
self.match_thresh = match_thresh
|
| 55 |
-
self.frame_rate = frame_rate
|
| 56 |
-
self.tracked_objects = {}
|
| 57 |
-
self.next_id = 1
|
| 58 |
-
|
| 59 |
-
def update(self, dets, scores, cls):
|
| 60 |
-
if self._original:
|
| 61 |
-
return self.tracker.update(dets, scores, cls)
|
| 62 |
-
|
| 63 |
-
# Simplified tracking logic for fallback
|
| 64 |
-
if len(dets) == 0:
|
| 65 |
-
return []
|
| 66 |
|
|
|
|
| 67 |
tracks = []
|
| 68 |
for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
|
| 69 |
if score < self.track_thresh:
|
| 70 |
continue
|
| 71 |
|
| 72 |
x, y, w, h = det
|
| 73 |
-
track_id = self.next_id
|
| 74 |
-
self.next_id += 1
|
| 75 |
tracks.append({
|
| 76 |
-
'id':
|
| 77 |
'bbox': [x, y, w, h],
|
| 78 |
'score': score,
|
| 79 |
'cls': cl
|
| 80 |
})
|
|
|
|
| 81 |
return tracks
|
| 82 |
|
| 83 |
-
|
| 84 |
# ==========================
|
| 85 |
# Optimized Configuration
|
| 86 |
# ==========================
|
|
@@ -522,6 +500,67 @@ def process_video(video_data):
|
|
| 522 |
yield f"Error processing video: {e}", "", "", "", ""
|
| 523 |
|
| 524 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 525 |
# ==========================
|
| 526 |
# Gradio Interface
|
| 527 |
# ==========================
|
|
|
|
| 37 |
class BYTETracker:
|
| 38 |
"""Custom implementation of ByteTrack to avoid installation issues"""
|
| 39 |
def __init__(self, track_thresh=0.5, track_buffer=30, match_thresh=0.8, frame_rate=30):
|
| 40 |
+
self.track_thresh = track_thresh
|
| 41 |
+
self.track_buffer = track_buffer
|
| 42 |
+
self.match_thresh = match_thresh
|
| 43 |
+
self.frame_rate = frame_rate
|
| 44 |
+
self.next_id = 1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
+
def update(self, dets, scores, cls):
|
| 47 |
tracks = []
|
| 48 |
for i, (det, score, cl) in enumerate(zip(dets, scores, cls)):
|
| 49 |
if score < self.track_thresh:
|
| 50 |
continue
|
| 51 |
|
| 52 |
x, y, w, h = det
|
|
|
|
|
|
|
| 53 |
tracks.append({
|
| 54 |
+
'id': self.next_id,
|
| 55 |
'bbox': [x, y, w, h],
|
| 56 |
'score': score,
|
| 57 |
'cls': cl
|
| 58 |
})
|
| 59 |
+
self.next_id += 1
|
| 60 |
return tracks
|
| 61 |
|
|
|
|
| 62 |
# ==========================
|
| 63 |
# Optimized Configuration
|
| 64 |
# ==========================
|
|
|
|
| 500 |
yield f"Error processing video: {e}", "", "", "", ""
|
| 501 |
|
| 502 |
|
| 503 |
+
# Initialize device and model
|
| 504 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 505 |
+
logger.info(f"Using device: {device}")
|
| 506 |
+
|
| 507 |
+
def load_model():
|
| 508 |
+
try:
|
| 509 |
+
if os.path.isfile(CONFIG["MODEL_PATH"]):
|
| 510 |
+
model_path = CONFIG["MODEL_PATH"]
|
| 511 |
+
logger.info(f"Model loaded: {model_path}")
|
| 512 |
+
else:
|
| 513 |
+
model_path = CONFIG["FALLBACK_MODEL"]
|
| 514 |
+
logger.warning("Using fallback model. Train yolov8_safety.pt for best results.")
|
| 515 |
+
if not os.path.isfile(model_path):
|
| 516 |
+
logger.info(f"Downloading fallback model: {model_path}")
|
| 517 |
+
torch.hub.download_url_to_file('https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt', model_path)
|
| 518 |
+
model = YOLO(model_path).to(device)
|
| 519 |
+
return model
|
| 520 |
+
except Exception as e:
|
| 521 |
+
logger.error(f"Failed to load model: {e}")
|
| 522 |
+
raise
|
| 523 |
+
|
| 524 |
+
model = load_model()
|
| 525 |
+
|
| 526 |
+
# ==========================
|
| 527 |
+
# Helper Functions
|
| 528 |
+
# ==========================
|
| 529 |
+
def draw_detections(frame, detections):
|
| 530 |
+
# ... [your existing implementation] ...
|
| 531 |
+
|
| 532 |
+
def calculate_safety_score(violations):
|
| 533 |
+
# ... [your existing implementation] ...
|
| 534 |
+
|
| 535 |
+
def generate_violation_pdf(violations, score):
|
| 536 |
+
# ... [your existing implementation] ...
|
| 537 |
+
|
| 538 |
+
@retry(stop_max_attempt_number=3, wait_fixed=2000)
|
| 539 |
+
def connect_to_salesforce():
|
| 540 |
+
# ... [your existing implementation] ...
|
| 541 |
+
|
| 542 |
+
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
|
| 543 |
+
# ... [your existing implementation] ...
|
| 544 |
+
|
| 545 |
+
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
|
| 546 |
+
# ... [your existing implementation] ...
|
| 547 |
+
|
| 548 |
+
def process_video(video_data):
|
| 549 |
+
# ... [your existing implementation] ...
|
| 550 |
+
|
| 551 |
+
def gradio_interface(video_file):
|
| 552 |
+
if not video_file:
|
| 553 |
+
return "No file uploaded.", "", "No file uploaded.", "", ""
|
| 554 |
+
try:
|
| 555 |
+
with open(video_file, "rb") as f:
|
| 556 |
+
video_data = f.read()
|
| 557 |
+
|
| 558 |
+
for status, score, snapshots_text, record_id, details_url in process_video(video_data):
|
| 559 |
+
yield status, score, snapshots_text, record_id, details_url
|
| 560 |
+
except Exception as e:
|
| 561 |
+
logger.error(f"Error in Gradio interface: {e}", exc_info=True)
|
| 562 |
+
yield f"Error: {str(e)}", "", "Error in processing.", "", ""
|
| 563 |
+
|
| 564 |
# ==========================
|
| 565 |
# Gradio Interface
|
| 566 |
# ==========================
|