File size: 15,519 Bytes
dc66d57
 
 
 
efdcd31
dc66d57
 
d6b1114
8223186
 
 
 
 
f225f91
9a1b6ee
8a4e253
04fdfdf
dc66d57
 
 
f225f91
97fe32f
f225f91
 
 
 
de9b54b
8a4e253
2d7e132
 
 
 
 
f225f91
 
 
 
 
 
 
 
8050a10
 
 
f225f91
bcf309d
f225f91
d3125e2
37f7ef0
08e92c3
f225f91
 
71a216e
dc66d57
 
 
08e92c3
f225f91
71a216e
dc66d57
f225f91
dc66d57
f225f91
 
d3125e2
 
2d7e132
f225f91
 
 
 
 
 
e04e491
0f434be
f225f91
0f434be
d3125e2
0f434be
55120e4
f225f91
 
cc1585c
55120e4
 
f225f91
55120e4
0f434be
8223186
d6d0b51
8223186
f225f91
8223186
 
 
 
 
 
f225f91
8223186
 
 
 
 
 
 
 
 
 
 
 
8a4e253
 
 
 
 
 
 
 
 
 
 
 
8223186
 
 
 
 
 
 
f225f91
 
8223186
 
f225f91
8223186
 
d3125e2
f225f91
8223186
 
f225f91
8223186
f225f91
8223186
 
 
f225f91
8223186
 
 
d3125e2
8223186
f225f91
8223186
f225f91
 
8223186
d6d0b51
f225f91
8223186
d6b1114
d3125e2
f225f91
55120e4
 
f225f91
2d7e132
f225f91
 
 
0f434be
55120e4
9a1b6ee
 
d846620
 
 
 
d3125e2
9a1b6ee
 
d3125e2
9a1b6ee
 
 
d3125e2
d846620
8223186
d846620
1f0ced7
d846620
9a1b6ee
 
d3125e2
9a1b6ee
 
 
d3125e2
d846620
8223186
d846620
 
d3125e2
d846620
0f434be
 
 
 
de9b54b
0f434be
de9b54b
 
 
0f434be
60028e1
 
de9b54b
 
 
0f434be
dc66d57
 
 
f225f91
08e92c3
f225f91
08e92c3
 
f225f91
08e92c3
 
 
937ffb2
08e92c3
f225f91
969dec6
d3125e2
de9b54b
d3125e2
 
 
 
08e92c3
969dec6
08e92c3
de9b54b
08e92c3
 
de9b54b
 
 
 
d3125e2
 
 
 
 
08e92c3
60028e1
08e92c3
 
f225f91
2d7e132
97fe32f
de9b54b
8a4e253
de9b54b
97fe32f
8050a10
 
 
de9b54b
cc1585c
06a63ac
 
08e92c3
 
 
 
06a63ac
de9b54b
08e92c3
 
 
d3125e2
 
 
 
 
 
 
 
 
 
 
 
 
dc66d57
 
2448900
08e92c3
 
de9b54b
 
d3125e2
de9b54b
 
 
 
 
 
 
 
 
8223186
f225f91
08e92c3
 
 
 
 
55120e4
de9b54b
08e92c3
 
de9b54b
08e92c3
 
 
de9b54b
0f434be
de9b54b
08e92c3
e04e491
efdcd31
 
dc66d57
937ffb2
f225f91
de9b54b
08e92c3
 
 
 
1a7f512
 
 
2d7e132
 
1a7f512
 
2d7e132
 
1a7f512
d846620
1a7f512
 
 
 
2d7e132
1a7f512
 
 
08e92c3
1a7f512
f225f91
5588019
f225f91
de9b54b
08e92c3
 
d3125e2
de9b54b
efdcd31
d1c5c67
08e92c3
de9b54b
08e92c3
1a7f512
08e92c3
1a7f512
55120e4
de9b54b
08e92c3
 
2d7e132
efdcd31
 
37f7ef0
f225f91
1a937c7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
import os
import cv2
import gradio as gr
import torch
import numpy as np
from ultralytics import YOLO
import time
from simple_salesforce import Salesforce
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.units import inch
from io import BytesIO
import base64
import logging
from retrying import retry
import uuid

# ==========================
# Configuration
# ==========================
def sf_credentials_from_env(env=None):
    """Build the keyword arguments used to log in to Salesforce.

    Each value is taken from an environment variable when it is set
    (``SF_USERNAME``, ``SF_PASSWORD``, ``SF_SECURITY_TOKEN``, ``SF_DOMAIN``)
    and otherwise falls back to the value that used to be hard-coded here,
    so existing deployments keep working unchanged.

    Args:
        env: Mapping to read from; defaults to ``os.environ``.

    Returns:
        dict with the keys ``username``, ``password``, ``security_token``
        and ``domain``.
    """
    env = os.environ if env is None else env
    # NOTE(review): the fallback literals below are secrets committed to source
    # control. Rotate them and drop the fallbacks once the environment
    # variables are configured for every deployment.
    return {
        "username": env.get("SF_USERNAME", "prashanth1ai@safety.com"),
        "password": env.get("SF_PASSWORD", "SaiPrash461"),
        "security_token": env.get("SF_SECURITY_TOKEN", "AP4AQnPoidIKPvSvNEfAHyoK"),
        "domain": env.get("SF_DOMAIN", "login"),
    }


CONFIG = {
    "MODEL_PATH": "yolov8n.pt",  # Lightweight model, must be trained for violations only
    "OUTPUT_DIR": "static/output",
    # Class index reported by the model -> internal violation label
    "VIOLATION_LABELS": {
        0: "no_helmet",
        1: "no_harness",
        2: "unsafe_posture",
        3: "unsafe_zone"  # Ignored in processing
    },
    "DISPLAY_NAMES": {  # Mapping for user-friendly violation names
        "no_helmet": "Missing Helmet",
        "no_harness": "Missing Harness",
        "unsafe_posture": "Unsafe Posture"
    },
    # Passed as keyword arguments to simple_salesforce.Salesforce
    "SF_CREDENTIALS": sf_credentials_from_env(),
    "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
    "FRAME_SKIP": 15,  # Process every 15th frame
    "MAX_PROCESSING_TIME": 25,  # Cap video processing at 25s
    "CONFIDENCE_THRESHOLD": 0.5  # Minimum confidence for violation detection
}

# Logging: timestamped, level-tagged records at INFO and above
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Create the output directory up front; an existing directory is not an error
os.makedirs(CONFIG["OUTPUT_DIR"], exist_ok=True)

# ==========================
# Device Setup
# ==========================
# Use CUDA when torch reports it as available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")

# ==========================
# Model Loading
# ==========================
def load_model():
    """Load the YOLO weights named by ``CONFIG["MODEL_PATH"]`` onto ``device``.

    Returns:
        The YOLO model after being moved to ``device``.

    Raises:
        Exception: Any failure is logged and then re-raised unchanged.
    """
    try:
        weights_path = CONFIG["MODEL_PATH"]
        detector = YOLO(weights_path)
        detector = detector.to(device)
        logger.info(f"Model loaded: {weights_path}")
        # The stock weights are not guaranteed to match VIOLATION_LABELS, so
        # remind the operator on every start-up.
        logger.warning("Ensure yolov8n.pt is trained to detect ONLY 'no_helmet', 'no_harness', 'unsafe_posture'. Replace with custom-trained yolov8_safety.pt if unexpected classes are detected.")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise
    else:
        return detector


# Loaded once at import time and shared by every request
model = load_model()

# ==========================
# Salesforce Integration
# ==========================
@retry(stop_max_attempt_number=2, wait_fixed=1000)
def connect_to_salesforce():
    """Open a Salesforce session using ``CONFIG["SF_CREDENTIALS"]``.

    After logging in, ``describe()`` is called once as a check that the
    session works. Any exception is logged and re-raised, which lets the
    ``@retry`` decorator make its second attempt.

    Returns:
        The connected ``Salesforce`` client.

    Raises:
        Exception: Whatever the login or ``describe()`` call raised.
    """
    try:
        credentials = CONFIG["SF_CREDENTIALS"]
        session = Salesforce(**credentials)
        logger.info("Connected to Salesforce")
        session.describe()
    except Exception as e:
        logger.error(f"Salesforce connection failed: {e}")
        raise
    else:
        return session

def generate_violation_pdf(violations, score):
    """Render the violation report as a PDF and save it in the output directory.

    The PDF is built in memory, then written to ``CONFIG["OUTPUT_DIR"]``.

    Args:
        violations: List of dicts, each read for its ``"violation"``,
            ``"timestamp"`` and ``"confidence"`` keys.
        score: Compliance score printed in the report header.

    Returns:
        Tuple ``(pdf_path, public_url, pdf_file)`` where ``pdf_file`` is the
        in-memory ``BytesIO`` positioned at offset 0. On any error the
        failure is logged and ``("", "", None)`` is returned.
    """
    try:
        # The random suffix keeps two reports created within the same second
        # from overwriting each other.
        pdf_filename = f"violations_{int(time.time())}_{uuid.uuid4().hex[:8]}.pdf"
        pdf_path = os.path.join(CONFIG["OUTPUT_DIR"], pdf_filename)
        pdf_file = BytesIO()
        c = canvas.Canvas(pdf_file, pagesize=letter)
        c.setFont("Helvetica", 12)
        c.drawString(1 * inch, 10 * inch, "Worksite Safety Violation Report")
        c.setFont("Helvetica", 10)

        # Summary block
        y_position = 9.5 * inch
        report_data = {
            "Compliance Score": f"{score}%",
            "Violations Found": len(violations),
            "Timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        for key, value in report_data.items():
            c.drawString(1 * inch, y_position, f"{key}: {value}")
            y_position -= 0.3 * inch

        # Detail block
        y_position -= 0.3 * inch
        c.drawString(1 * inch, y_position, "Violation Details:")
        y_position -= 0.3 * inch
        if not violations:
            c.drawString(1 * inch, y_position, "No violations detected.")
        else:
            for v in violations:
                # Break the page *before* drawing so a new page is started
                # only when there is another line to put on it; breaking
                # after the last line would leave a trailing blank page.
                if y_position < 1 * inch:
                    c.showPage()
                    c.setFont("Helvetica", 10)
                    y_position = 10 * inch
                display_name = CONFIG["DISPLAY_NAMES"].get(v["violation"], v["violation"])
                text = f"{display_name} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
                c.drawString(1 * inch, y_position, text)
                y_position -= 0.3 * inch

        c.showPage()
        c.save()
        pdf_file.seek(0)

        with open(pdf_path, "wb") as f:
            f.write(pdf_file.getvalue())
        public_url = f"{CONFIG['PUBLIC_URL_BASE']}{pdf_filename}"
        logger.info(f"PDF generated: {public_url}")
        return pdf_path, public_url, pdf_file
    except Exception as e:
        logger.error(f"Error generating PDF: {e}")
        return "", "", None

@retry(stop_max_attempt_number=2, wait_fixed=1000)
def upload_pdf_to_salesforce(sf, pdf_file, report_id):
    """Attach the in-memory PDF to a Salesforce record as a ContentVersion.

    Args:
        sf: Connected Salesforce client.
        pdf_file: File-like object exposing ``getvalue()``; a falsy value
            aborts the upload.
        report_id: Id of the record used as ``FirstPublishLocationId``.

    Returns:
        Download URL of the new ContentVersion, or ``""`` when no file was
        given, the created version cannot be queried back, or any exception
        occurs. Because every exception is handled here, the ``@retry``
        decorator never sees a failure.
    """
    try:
        if not pdf_file:
            logger.error("No PDF file provided for upload")
            return ""

        # The payload carries the document as base64 text
        encoded_pdf = base64.b64encode(pdf_file.getvalue()).decode('utf-8')
        payload = {
            "Title": f"Safety_Violation_Report_{int(time.time())}",
            "PathOnClient": f"safety_violation_{int(time.time())}.pdf",
            "VersionData": encoded_pdf,
            "FirstPublishLocationId": report_id
        }
        created = sf.ContentVersion.create(payload)

        # Read the row back to confirm the version really exists
        lookup = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{created['id']}'")
        if lookup['records']:
            file_url = f"https://{sf.sf_instance}/sfc/servlet.shepherd/version/download/{created['id']}"
            logger.info(f"PDF uploaded to Salesforce: {file_url}")
            return file_url

        logger.error("Failed to retrieve ContentVersion")
        return ""
    except Exception as e:
        logger.error(f"Error uploading PDF to Salesforce: {e}")
        return ""

@retry(stop_max_attempt_number=2, wait_fixed=1000)
def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
    """Create a Salesforce record for the report and attach the PDF to it.

    A ``Safety_Video_Report__c`` record is created first; if that fails, an
    ``Account`` record is created instead as a fallback. When ``pdf_file`` is
    given it is uploaded and the resulting URL is written back to whichever
    record was actually created.

    Args:
        violations: List of dicts read for ``"violation"``, ``"timestamp"``
            and ``"confidence"``.
        score: Compliance score stored on the record.
        pdf_path: Local path of the generated PDF; only its base name is used
            to build the public URL. May be empty.
        pdf_file: In-memory PDF passed on to ``upload_pdf_to_salesforce``;
            falsy to skip the upload.

    Returns:
        Tuple ``(record_id, pdf_url)``. ``(None, "")`` is returned only when
        no record could be created. Once a record exists its id is always
        returned, even if attaching the PDF URL fails afterwards.

    Note:
        Every exception is handled inside this function, so the ``@retry``
        decorator never triggers a second attempt.
    """
    try:
        sf = connect_to_salesforce()
        violations_text = "\n".join(
            f"{CONFIG['DISPLAY_NAMES'].get(v['violation'], v['violation'])} at {v['timestamp']:.2f}s (Confidence: {v['confidence']})"
            for v in violations
        ) or "No violations detected."
        pdf_url = f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(pdf_path)}" if pdf_path else ""

        record_data = {
            "Compliance_Score__c": score,
            "Violations_Found__c": len(violations),
            "Violations_Details__c": violations_text,
            "Status__c": "Pending",
            "PDF_Report_URL__c": pdf_url
        }
        logger.info(f"Creating Salesforce record with data: {record_data}")
        # Remember which object type holds the report so the later URL update
        # goes to the right one instead of guessing.
        try:
            record = sf.Safety_Video_Report__c.create(record_data)
            used_account_fallback = False
            logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
        except Exception as e:
            logger.error(f"Failed to create Safety_Video_Report__c: {e}")
            record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
            used_account_fallback = True
            logger.warning(f"Fell back to Account record: {record['id']}")
        record_id = record["id"]
    except Exception as e:
        logger.error(f"Salesforce record creation failed: {e}")
        return None, ""

    if not pdf_file:
        return record_id, pdf_url

    # From here on the record exists: a failure while attaching the PDF must
    # not discard its id, so it is logged rather than propagated.
    try:
        uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
        if uploaded_url:
            pdf_url = uploaded_url
            if used_account_fallback:
                sf.Account.update(record_id, {"Description": uploaded_url})
                logger.info(f"Updated Account record {record_id} with PDF URL")
            else:
                sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
                logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
    except Exception as e:
        logger.error(f"Failed to store PDF URL on record {record_id}: {e}")

    return record_id, pdf_url

# ==========================
# Safety Score Calculation
# ==========================
def calculate_safety_score(violations):
    """Return the compliance score for a list of detected violations.

    The score starts at 100 and each violation subtracts a fixed penalty
    according to its ``"violation"`` label; labels without a penalty cost
    nothing. The result never goes below 0.

    Args:
        violations: Iterable of dicts, each with a ``"violation"`` key.

    Returns:
        Integer score between 0 and 100.
    """
    deductions = {"no_helmet": 25, "no_harness": 30, "unsafe_posture": 20}
    total_penalty = sum(deductions.get(entry["violation"], 0) for entry in violations)
    return max(100 - total_penalty, 0)

# ==========================
# Video Processing
# ==========================
def _scan_video_for_violations(video_path):
    """Run the detector over sampled frames of the video at ``video_path``.

    Every ``CONFIG["FRAME_SKIP"]``-th frame is passed to the model. Scanning
    stops at the end of the video, after 60 seconds' worth of frames, or once
    ``CONFIG["MAX_PROCESSING_TIME"]`` seconds of wall-clock time have elapsed.

    Args:
        video_path: Path of a video file readable by OpenCV.

    Returns:
        Tuple ``(violations, snapshots)``: one violation dict per label per
        processed frame, and at most one snapshot dict per violation label.

    Raises:
        ValueError: If OpenCV cannot open the file.
    """
    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise ValueError("Could not open video file")

        fps = video.get(cv2.CAP_PROP_FPS)
        # Some containers report no (or a nonsensical) frame rate. Without a
        # fallback, max_frames would be 0 and timestamps would divide by zero.
        if not fps > 0:
            logger.warning(f"Invalid FPS reported ({fps}); assuming 30 FPS")
            fps = 30.0
        max_frames = int(60 * fps)  # Process up to 1 minute

        violations, snapshots = [], []
        frame_count = 0
        start_time = time.time()

        # Track one snapshot per violation type; the keys double as the set
        # of labels that are treated as violations.
        snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}

        while True:
            ret, frame = video.read()
            if not ret or frame_count >= max_frames:
                break

            if frame_count % CONFIG["FRAME_SKIP"] != 0:
                frame_count += 1
                continue

            # Stop once the wall-clock processing budget is used up
            if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
                logger.info("Processing time limit reached")
                break

            results = model(frame, device=device)
            seen_violations = set()  # Report each label at most once per frame
            for result in results:
                for box in result.boxes:
                    cls, conf = int(box.cls), float(box.conf)
                    label = CONFIG["VIOLATION_LABELS"].get(cls, f"unknown_class_{cls}")
                    # Only process specified violations
                    if label not in snapshot_taken:
                        logger.info(f"Ignoring detection: {label} (cls: {cls}, conf: {conf}) - not a target violation")
                        continue
                    # Apply confidence threshold
                    if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
                        logger.info(f"Skipping low-confidence detection: {label} (conf: {conf})")
                        continue
                    if label in seen_violations:
                        continue
                    seen_violations.add(label)

                    violations.append({
                        "frame": frame_count,
                        "violation": label,
                        "confidence": round(conf, 2),
                        # Convert to plain Python floats so the values print
                        # and serialise cleanly instead of as numpy scalars.
                        "bounding_box": [round(float(x), 2) for x in box.xywh.cpu().numpy()[0]],
                        "timestamp": frame_count / fps
                    })

                    # Save only one snapshot per violation type
                    if not snapshot_taken[label]:
                        snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
                        cv2.imwrite(snapshot_path, frame)
                        with open(snapshot_path, "rb") as img_file:
                            img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
                        snapshots.append({
                            "violation": label,
                            "frame": frame_count,
                            "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
                            "snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
                        })
                        snapshot_taken[label] = True

            frame_count += 1

        return violations, snapshots
    finally:
        # Always release the capture, including when an error is raised
        video.release()


def process_video(video_data):
    """Detect safety violations in an uploaded video and report them.

    The bytes are written to a temporary file in ``CONFIG["OUTPUT_DIR"]``,
    scanned for violations, and the temporary file is removed again whether
    or not scanning succeeded. When violations are found, a PDF report is
    generated and pushed to Salesforce.

    Args:
        video_data: Raw bytes of the video file.

    Returns:
        dict with the keys ``"violations"``, ``"snapshots"``, ``"score"``,
        ``"salesforce_record_id"`` and ``"violation_details_url"``. If no
        violation is found, or if any error occurs, the lists are empty, the
        score is 100, the record id is ``None`` and the URL is ``""``.
    """
    try:
        # The random suffix keeps concurrent uploads arriving within the same
        # second from sharing (and deleting) each other's temporary file.
        video_path = os.path.join(
            CONFIG["OUTPUT_DIR"], f"temp_{int(time.time())}_{uuid.uuid4().hex[:8]}.mp4"
        )
        try:
            with open(video_path, "wb") as f:
                f.write(video_data)
            logger.info(f"Video saved: {video_path}")
            violations, snapshots = _scan_video_for_violations(video_path)
        finally:
            # Clean up on success and on failure; a cleanup problem must not
            # throw away detections that were already collected.
            if os.path.exists(video_path):
                try:
                    os.remove(video_path)
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove temporary video {video_path}: {cleanup_error}")

        if not violations:
            logger.info("No violations detected")
            return {
                "violations": [],
                "snapshots": [],
                "score": 100,
                "salesforce_record_id": None,
                "violation_details_url": ""
            }

        score = calculate_safety_score(violations)
        pdf_path, pdf_url, pdf_file = generate_violation_pdf(violations, score)
        report_id, final_pdf_url = push_report_to_salesforce(violations, score, pdf_path, pdf_file)

        return {
            "violations": violations,
            "snapshots": snapshots,
            "score": score,
            "salesforce_record_id": report_id,
            "violation_details_url": final_pdf_url
        }
    except Exception as e:
        logger.error(f"Error processing video: {e}")
        # NOTE(review): a failed run is reported with the same payload as a
        # clean video (score 100); kept as-is because callers rely on this
        # shape, but consider surfacing the error to the user.
        return {
            "violations": [],
            "snapshots": [],
            "score": 100,
            "salesforce_record_id": None,
            "violation_details_url": ""
        }

# ==========================
# Gradio Interface
# ==========================
def gradio_interface(video_file):
    """Gradio callback: analyse the uploaded video and format the results.

    Args:
        video_file: Path of the uploaded video; a falsy value means nothing
            was uploaded.

    Returns:
        5-tuple of strings: Markdown table of violations, compliance score
        line, Markdown list of snapshots, Salesforce record id line, and the
        violation-details URL (``"N/A"`` when empty). When nothing was
        uploaded or an exception occurs, fixed placeholder/error strings are
        returned instead.
    """
    if not video_file:
        return "No file uploaded.", "", "No file uploaded.", "", ""

    try:
        with open(video_file, "rb") as upload:
            video_data = upload.read()
        report = process_video(video_data)

        names = CONFIG["DISPLAY_NAMES"]

        # Violations as a Markdown table, one row per detection
        detected = report["violations"]
        if detected:
            header = "| Violation        | Timestamp | Confidence | Bounding Box             | Violation Details       |\n"
            separator = "|------------------|-----------|------------|--------------------------|-------------------------|\n"
            table_rows = [
                f"| {names.get(v['violation'], v['violation']):<16} | {v['timestamp']:.2f}s  | {v['confidence']:.2f}      | {v['bounding_box']} | {report['violation_details_url']} |"
                for v in detected
            ]
            violation_table = header + separator + "\n".join(table_rows)
        else:
            violation_table = "No violations detected."

        # Snapshots as a Markdown list with the images inlined as data URIs
        captured = report["snapshots"]
        if captured:
            snapshot_lines = [
                f"- Snapshot for {names.get(s['violation'], s['violation'])} at frame {s['frame']}: ![]({s['snapshot_base64']})"
                for s in captured
            ]
            snapshots_text = "\n".join(snapshot_lines)
        else:
            snapshots_text = "No snapshots captured."

        score_text = f"Safety Score: {report['score']}%"
        record_text = f"Salesforce Record ID: {report['salesforce_record_id'] or 'N/A'}"
        details_url = report["violation_details_url"] or "N/A"
        return violation_table, score_text, snapshots_text, record_text, details_url
    except Exception as e:
        logger.error(f"Error in Gradio interface: {e}")
        return f"Error: {str(e)}", "", "Error in processing.", "", ""

# Keyword arguments for the Gradio UI. The five outputs are listed in the
# same order as the tuple returned by gradio_interface.
_INTERFACE_OPTIONS = dict(
    fn=gradio_interface,
    inputs=gr.Video(label="Upload Site Video"),
    outputs=[
        gr.Markdown(label="Detected Safety Violations"),
        gr.Textbox(label="Compliance Score"),
        gr.Markdown(label="Snapshots"),
        gr.Textbox(label="Salesforce Record ID"),
        gr.Textbox(label="Violation Details URL"),
    ],
    title="Worksite Safety Violation Analyzer",
    description="Upload site videos to detect safety violations (Missing Helmet, Missing Harness, Unsafe Posture). Non-violations are ignored.",
)
interface = gr.Interface(**_INTERFACE_OPTIONS)

# Start the web app only when this file is executed as a script
if __name__ == "__main__":
    logger.info("Launching Safety Analyzer App...")
    interface.launch()