PrashanthB461 commited on
Commit
f36c6cb
·
verified ·
1 Parent(s): ddb1bda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -31
app.py CHANGED
@@ -19,7 +19,7 @@ from retrying import retry
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
22
- "FALLBACK_MODEL_PATH": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
@@ -39,7 +39,7 @@ CONFIG = {
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
  "FRAME_SKIP": 15,
42
- "MAX_PROCESSING_TIME": 30, # seconds
43
  "CONFIDENCE_THRESHOLD": 0.5
44
  }
45
 
@@ -54,14 +54,11 @@ logger.info(f"Using device: {device}")
54
 
55
  def load_model():
56
  try:
57
- model_path = CONFIG["MODEL_PATH"]
58
- if not os.path.exists(model_path):
59
- logger.warning(f"Custom model {model_path} not found. Falling back to {CONFIG['FALLBACK_MODEL_PATH']}")
60
- model_path = CONFIG["FALLBACK_MODEL_PATH"]
61
  model = YOLO(model_path).to(device)
62
  logger.info(f"Model loaded: {model_path}")
63
- if model_path == CONFIG["FALLBACK_MODEL_PATH"]:
64
- logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
65
  return model
66
  except Exception as e:
67
  logger.error(f"Failed to load model: {e}")
@@ -69,12 +66,15 @@ def load_model():
69
 
70
  model = load_model()
71
 
72
- @retry(stop_max_attempt_number=2, wait_fixed=1000)
 
 
 
73
  def connect_to_salesforce():
74
  try:
75
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
76
  logger.info("Connected to Salesforce")
77
- sf.describe()
78
  return sf
79
  except Exception as e:
80
  logger.error(f"Salesforce connection failed: {e}")
@@ -129,7 +129,7 @@ def generate_violation_pdf(violations, score):
129
  logger.error(f"Error generating PDF: {e}")
130
  return "", "", None
131
 
132
- @retry(stop_max_attempt_number=2, wait_fixed=1000)
133
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
134
  try:
135
  if not pdf_file:
@@ -143,7 +143,7 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
143
  "FirstPublishLocationId": report_id
144
  }
145
  content_version = sf.ContentVersion.create(content_version_data)
146
- result = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{content_version['id']}'")
147
  if not result['records']:
148
  logger.error("Failed to retrieve ContentVersion")
149
  return ""
@@ -154,7 +154,7 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
154
  logger.error(f"Error uploading PDF to Salesforce: {e}")
155
  return ""
156
 
157
- @retry(stop_max_attempt_number=2, wait_fixed=1000)
158
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
159
  try:
160
  sf = connect_to_salesforce()
@@ -173,10 +173,10 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
173
  }
174
  logger.info(f"Creating Salesforce record with data: {record_data}")
175
  try:
176
- record = sf.Safety_Violation_Report__c.create(record_data)
177
- logger.info(f"Created Safety_Violation_Report__c record: {record['id']}")
178
  except Exception as e:
179
- logger.error(f"Failed to create Safety_Violation_Report__c: {e}")
180
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
181
  logger.warning(f"Fell back to Account record: {record['id']}")
182
  record_id = record["id"]
@@ -185,17 +185,17 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
185
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
186
  if uploaded_url:
187
  try:
188
- sf.Safety_Violation_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
189
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
190
  except Exception as e:
191
- logger.error(f"Failed to update Safety_Violation_Report__c: {e}")
192
  sf.Account.update(record_id, {"Description": uploaded_url})
193
  logger.info(f"Updated Account record {record_id} with PDF URL")
194
  pdf_url = uploaded_url
195
 
196
  return record_id, pdf_url
197
  except Exception as e:
198
- logger.error(f"Salesforce record creation failed: {e}")
199
  return None, ""
200
 
201
  def calculate_safety_score(violations):
@@ -204,10 +204,7 @@ def calculate_safety_score(violations):
204
  "no_harness": 30,
205
  "unsafe_posture": 20
206
  }
207
- score = 100
208
- for v in violations:
209
- if v["violation"] in penalties:
210
- score -= penalties[v["violation"]]
211
  return max(score, 0)
212
 
213
  def process_video(video_data):
@@ -226,19 +223,19 @@ def process_video(video_data):
226
  start_time = time.time()
227
  fps = video.get(cv2.CAP_PROP_FPS)
228
 
229
- snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
230
 
231
  while True:
232
  ret, frame = video.read()
233
  if not ret:
234
- break # End of video
235
 
236
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
237
  frame_count += 1
238
  continue
239
 
240
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
241
- logger.info("Processing time limit of 30 seconds reached")
242
  break
243
 
244
  results = model(frame, device=device)
@@ -247,7 +244,7 @@ def process_video(video_data):
247
  for box in result.boxes:
248
  cls, conf = int(box.cls), float(box.conf)
249
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
250
- if label not in ["no_helmet", "no_harness", "unsafe_posture"]:
251
  continue
252
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
253
  continue
@@ -306,7 +303,7 @@ def process_video(video_data):
306
  "message": ""
307
  }
308
  except Exception as e:
309
- logger.error(f"Error processing video: {e}")
310
  return {
311
  "violations": [],
312
  "snapshots": [],
@@ -320,7 +317,6 @@ def gradio_interface(video_file):
320
  if not video_file:
321
  return "No file uploaded.", "", "No file uploaded.", "", ""
322
  try:
323
- # Show processing message early in the UI
324
  yield "Processing video... please wait.", "", "", "", ""
325
 
326
  with open(video_file, "rb") as f:
@@ -329,7 +325,6 @@ def gradio_interface(video_file):
329
  result = process_video(video_data)
330
 
331
  if result.get("message"):
332
- # If message present (either no violations or error), show it plainly
333
  yield result["message"], "", "", "", ""
334
  return
335
 
@@ -361,7 +356,7 @@ def gradio_interface(video_file):
361
  result["violation_details_url"] or "N/A"
362
  )
363
  except Exception as e:
364
- logger.error(f"Error in Gradio interface: {e}")
365
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
366
 
367
  interface = gr.Interface(
 
19
  # ==========================
20
  CONFIG = {
21
  "MODEL_PATH": "yolov8_safety.pt",
22
+ "FALLBACK_MODEL": "yolov8n.pt", # updated key to match first code
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
 
39
  },
40
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo2/resolve/main/static/output/",
41
  "FRAME_SKIP": 15,
42
+ "MAX_PROCESSING_TIME": 30,
43
  "CONFIDENCE_THRESHOLD": 0.5
44
  }
45
 
 
54
 
55
  def load_model():
56
  try:
57
+ model_path = CONFIG["MODEL_PATH"] if os.path.isfile(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
58
+ if model_path == CONFIG["FALLBACK_MODEL"]:
59
+ logger.warning("Using fallback model. Detection accuracy may be poor. Train yolov8_safety.pt for best results.")
 
60
  model = YOLO(model_path).to(device)
61
  logger.info(f"Model loaded: {model_path}")
 
 
62
  return model
63
  except Exception as e:
64
  logger.error(f"Failed to load model: {e}")
 
66
 
67
  model = load_model()
68
 
69
+ # ==========================
70
+ # Salesforce Integration
71
+ # ==========================
72
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
73
  def connect_to_salesforce():
74
  try:
75
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
76
  logger.info("Connected to Salesforce")
77
+ sf.describe() # verify connection and metadata fetch
78
  return sf
79
  except Exception as e:
80
  logger.error(f"Salesforce connection failed: {e}")
 
129
  logger.error(f"Error generating PDF: {e}")
130
  return "", "", None
131
 
132
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
133
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
134
  try:
135
  if not pdf_file:
 
143
  "FirstPublishLocationId": report_id
144
  }
145
  content_version = sf.ContentVersion.create(content_version_data)
146
+ result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
147
  if not result['records']:
148
  logger.error("Failed to retrieve ContentVersion")
149
  return ""
 
154
  logger.error(f"Error uploading PDF to Salesforce: {e}")
155
  return ""
156
 
157
+ @retry(stop_max_attempt_number=3, wait_fixed=2000)
158
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
159
  try:
160
  sf = connect_to_salesforce()
 
173
  }
174
  logger.info(f"Creating Salesforce record with data: {record_data}")
175
  try:
176
+ record = sf.Safety_Video_Report__c.create(record_data)
177
+ logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
178
  except Exception as e:
179
+ logger.error(f"Failed to create Safety_Video_Report__c: {e}")
180
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
181
  logger.warning(f"Fell back to Account record: {record['id']}")
182
  record_id = record["id"]
 
185
  uploaded_url = upload_pdf_to_salesforce(sf, pdf_file, record_id)
186
  if uploaded_url:
187
  try:
188
+ sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
189
  logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
190
  except Exception as e:
191
+ logger.error(f"Failed to update Safety_Video_Report__c: {e}")
192
  sf.Account.update(record_id, {"Description": uploaded_url})
193
  logger.info(f"Updated Account record {record_id} with PDF URL")
194
  pdf_url = uploaded_url
195
 
196
  return record_id, pdf_url
197
  except Exception as e:
198
+ logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
199
  return None, ""
200
 
201
  def calculate_safety_score(violations):
 
204
  "no_harness": 30,
205
  "unsafe_posture": 20
206
  }
207
+ score = 100 - sum(penalties.get(v["violation"], 0) for v in violations)
 
 
 
208
  return max(score, 0)
209
 
210
  def process_video(video_data):
 
223
  start_time = time.time()
224
  fps = video.get(cv2.CAP_PROP_FPS)
225
 
226
+ snapshot_taken = {label: False for label in CONFIG["VIOLATION_LABELS"].values()}
227
 
228
  while True:
229
  ret, frame = video.read()
230
  if not ret:
231
+ break
232
 
233
  if frame_count % CONFIG["FRAME_SKIP"] != 0:
234
  frame_count += 1
235
  continue
236
 
237
  if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
238
+ logger.info("Processing time limit reached")
239
  break
240
 
241
  results = model(frame, device=device)
 
244
  for box in result.boxes:
245
  cls, conf = int(box.cls), float(box.conf)
246
  label = CONFIG["VIOLATION_LABELS"].get(cls, None)
247
+ if label not in CONFIG["VIOLATION_LABELS"].values():
248
  continue
249
  if conf < CONFIG["CONFIDENCE_THRESHOLD"]:
250
  continue
 
303
  "message": ""
304
  }
305
  except Exception as e:
306
+ logger.error(f"Error processing video: {e}", exc_info=True)
307
  return {
308
  "violations": [],
309
  "snapshots": [],
 
317
  if not video_file:
318
  return "No file uploaded.", "", "No file uploaded.", "", ""
319
  try:
 
320
  yield "Processing video... please wait.", "", "", "", ""
321
 
322
  with open(video_file, "rb") as f:
 
325
  result = process_video(video_data)
326
 
327
  if result.get("message"):
 
328
  yield result["message"], "", "", "", ""
329
  return
330
 
 
356
  result["violation_details_url"] or "N/A"
357
  )
358
  except Exception as e:
359
+ logger.error(f"Error in Gradio interface: {e}", exc_info=True)
360
  yield f"Error: {str(e)}", "", "Error in processing.", "", ""
361
 
362
  interface = gr.Interface(