PrashanthB461 commited on
Commit
d3125e2
·
verified ·
1 Parent(s): 1f0ced7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -34
app.py CHANGED
@@ -18,8 +18,7 @@ from retrying import retry
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
- "MODEL_PATH": os.getenv("SAFETY_MODEL_PATH", "models/yolov8_safety.pt"),
22
- "FALLBACK_MODEL": "yolov8n.pt",
23
  "OUTPUT_DIR": "static/output",
24
  "VIOLATION_LABELS": {
25
  0: "no_helmet",
@@ -34,12 +33,12 @@ CONFIG = {
34
  "domain": "login"
35
  },
36
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
37
- "FRAME_SKIP": 5,
38
- "MAX_PROCESSING_TIME": 60 # Updated to 60 seconds
39
  }
40
 
41
  # Setup logging
42
- logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
43
  logger = logging.getLogger(__name__)
44
 
45
  # Ensure output directory exists
@@ -55,10 +54,9 @@ logger.info(f"Using device: {device}")
55
  # Model Loading
56
  # ==========================
57
  def load_model():
58
- model_path = CONFIG["MODEL_PATH"] if os.path.isfile(CONFIG["MODEL_PATH"]) else CONFIG["FALLBACK_MODEL"]
59
  try:
60
- model = YOLO(model_path).to(device)
61
- logger.info(f"Model loaded: {model_path}")
62
  return model
63
  except Exception as e:
64
  logger.error(f"Failed to load model: {e}")
@@ -69,13 +67,12 @@ model = load_model()
69
  # ==========================
70
  # Salesforce Integration
71
  # ==========================
72
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
73
  def connect_to_salesforce():
74
  try:
75
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
76
  logger.info("Connected to Salesforce")
77
  sf.describe()
78
- logger.debug("Salesforce object metadata fetched successfully")
79
  return sf
80
  except Exception as e:
81
  logger.error(f"Salesforce connection failed: {e}")
@@ -126,7 +123,7 @@ def generate_violation_pdf(violations, score):
126
  logger.error(f"Error generating PDF: {e}")
127
  return "", "", None
128
 
129
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
130
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
131
  try:
132
  if not pdf_file:
@@ -139,9 +136,8 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
139
  "VersionData": encoded_pdf,
140
  "FirstPublishLocationId": report_id
141
  }
142
- logger.debug(f"Uploading PDF with data: {content_version_data}")
143
  content_version = sf.ContentVersion.create(content_version_data)
144
- result = sf.query(f"SELECT Id, ContentDocumentId FROM ContentVersion WHERE Id = '{content_version['id']}'")
145
  if not result['records']:
146
  logger.error("Failed to retrieve ContentVersion")
147
  return ""
@@ -152,7 +148,7 @@ def upload_pdf_to_salesforce(sf, pdf_file, report_id):
152
  logger.error(f"Error uploading PDF to Salesforce: {e}")
153
  return ""
154
 
155
- @retry(stop_max_attempt_number=3, wait_fixed=2000)
156
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
157
  try:
158
  sf = connect_to_salesforce()
@@ -169,15 +165,14 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
169
  "Status__c": "Pending",
170
  "PDF_Report_URL__c": pdf_url
171
  }
172
- logger.debug(f"Attempting to create Salesforce record with data: {record_data}")
173
  try:
174
  record = sf.Safety_Video_Report__c.create(record_data)
175
- logger.info(f"Successfully created Safety_Video_Report__c record: {record['id']}")
176
  except Exception as e:
177
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
178
- # Fallback to Account object
179
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
180
- logger.warning(f"Fell back to creating Account record: {record['id']}")
181
  record_id = record["id"]
182
 
183
  if pdf_file:
@@ -185,16 +180,16 @@ def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
185
  if uploaded_url:
186
  try:
187
  sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
188
- logger.debug(f"Updated Safety_Video_Report__c record {record_id} with PDF URL: {uploaded_url}")
189
  except Exception as e:
190
  logger.error(f"Failed to update Safety_Video_Report__c: {e}")
191
  sf.Account.update(record_id, {"Description": uploaded_url})
192
- logger.debug(f"Updated Account record {record_id} with PDF URL: {uploaded_url}")
193
  pdf_url = uploaded_url
194
 
195
  return record_id, pdf_url
196
  except Exception as e:
197
- logger.error(f"Salesforce record creation failed: {e}", exc_info=True)
198
  return None, ""
199
 
200
  # ==========================
@@ -228,8 +223,12 @@ def process_video(video_data):
228
 
229
  violations, snapshots = [], []
230
  frame_count = 0
 
231
  fps = video.get(cv2.CAP_PROP_FPS)
232
- max_frames = int(CONFIG["MAX_PROCESSING_TIME"] * fps)
 
 
 
233
 
234
  while True:
235
  ret, frame = video.read()
@@ -240,6 +239,11 @@ def process_video(video_data):
240
  frame_count += 1
241
  continue
242
 
 
 
 
 
 
243
  results = model(frame, device=device)
244
  seen_violations = set()
245
  for result in results:
@@ -261,16 +265,19 @@ def process_video(video_data):
261
  }
262
  violations.append(violation)
263
 
264
- snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
265
- cv2.imwrite(snapshot_path, frame)
266
- with open(snapshot_path, "rb") as img_file:
267
- img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
268
- snapshots.append({
269
- "violation": label,
270
- "frame": frame_count,
271
- "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
272
- "snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
273
- })
 
 
 
274
 
275
  frame_count += 1
276
 
@@ -278,7 +285,7 @@ def process_video(video_data):
278
  os.remove(video_path)
279
 
280
  if not violations:
281
- logger.info("No violations detected, skipping score calculation and reporting")
282
  return {
283
  "violations": [],
284
  "snapshots": [],
@@ -345,7 +352,7 @@ def gradio_interface(video_file):
345
  result["violation_details_url"] or "N/A"
346
  )
347
  except Exception as e:
348
- logger.error(f"Error in Gradio interface: {e}", exc_info=True)
349
  return f"Error: {str(e)}", "", "Error in processing.", "", ""
350
 
351
  interface = gr.Interface(
 
18
  # Configuration
19
  # ==========================
20
  CONFIG = {
21
+ "MODEL_PATH": "yolov8n.pt", # Force lightweight Nano model
 
22
  "OUTPUT_DIR": "static/output",
23
  "VIOLATION_LABELS": {
24
  0: "no_helmet",
 
33
  "domain": "login"
34
  },
35
  "PUBLIC_URL_BASE": "https://huggingface.co/spaces/PrashanthB461/AI_Safety_Demo1/resolve/main/static/output/",
36
+ "FRAME_SKIP": 15, # Increased to reduce frames processed
37
+ "MAX_PROCESSING_TIME": 25 # Cap video processing at 25s to leave time for reporting
38
  }
39
 
40
  # Setup logging
41
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
42
  logger = logging.getLogger(__name__)
43
 
44
  # Ensure output directory exists
 
54
  # Model Loading
55
  # ==========================
56
  def load_model():
 
57
  try:
58
+ model = YOLO(CONFIG["MODEL_PATH"]).to(device)
59
+ logger.info(f"Model loaded: {CONFIG['MODEL_PATH']}")
60
  return model
61
  except Exception as e:
62
  logger.error(f"Failed to load model: {e}")
 
67
  # ==========================
68
  # Salesforce Integration
69
  # ==========================
70
+ @retry(stop_max_attempt_number=2, wait_fixed=1000)
71
  def connect_to_salesforce():
72
  try:
73
  sf = Salesforce(**CONFIG["SF_CREDENTIALS"])
74
  logger.info("Connected to Salesforce")
75
  sf.describe()
 
76
  return sf
77
  except Exception as e:
78
  logger.error(f"Salesforce connection failed: {e}")
 
123
  logger.error(f"Error generating PDF: {e}")
124
  return "", "", None
125
 
126
+ @retry(stop_max_attempt_number=2, wait_fixed=1000)
127
  def upload_pdf_to_salesforce(sf, pdf_file, report_id):
128
  try:
129
  if not pdf_file:
 
136
  "VersionData": encoded_pdf,
137
  "FirstPublishLocationId": report_id
138
  }
 
139
  content_version = sf.ContentVersion.create(content_version_data)
140
+ result = sf.query(f"SELECT Id FROM ContentVersion WHERE Id = '{content_version['id']}'")
141
  if not result['records']:
142
  logger.error("Failed to retrieve ContentVersion")
143
  return ""
 
148
  logger.error(f"Error uploading PDF to Salesforce: {e}")
149
  return ""
150
 
151
+ @retry(stop_max_attempt_number=2, wait_fixed=1000)
152
  def push_report_to_salesforce(violations, score, pdf_path, pdf_file):
153
  try:
154
  sf = connect_to_salesforce()
 
165
  "Status__c": "Pending",
166
  "PDF_Report_URL__c": pdf_url
167
  }
168
+ logger.info(f"Creating Salesforce record with data: {record_data}")
169
  try:
170
  record = sf.Safety_Video_Report__c.create(record_data)
171
+ logger.info(f"Created Safety_Video_Report__c record: {record['id']}")
172
  except Exception as e:
173
  logger.error(f"Failed to create Safety_Video_Report__c: {e}")
 
174
  record = sf.Account.create({"Name": f"Safety_Report_{int(time.time())}"})
175
+ logger.warning(f"Fell back to Account record: {record['id']}")
176
  record_id = record["id"]
177
 
178
  if pdf_file:
 
180
  if uploaded_url:
181
  try:
182
  sf.Safety_Video_Report__c.update(record_id, {"PDF_Report_URL__c": uploaded_url})
183
+ logger.info(f"Updated record {record_id} with PDF URL: {uploaded_url}")
184
  except Exception as e:
185
  logger.error(f"Failed to update Safety_Video_Report__c: {e}")
186
  sf.Account.update(record_id, {"Description": uploaded_url})
187
+ logger.info(f"Updated Account record {record_id} with PDF URL")
188
  pdf_url = uploaded_url
189
 
190
  return record_id, pdf_url
191
  except Exception as e:
192
+ logger.error(f"Salesforce record creation failed: {e}")
193
  return None, ""
194
 
195
  # ==========================
 
223
 
224
  violations, snapshots = [], []
225
  frame_count = 0
226
+ start_time = time.time()
227
  fps = video.get(cv2.CAP_PROP_FPS)
228
+ max_frames = int(60 * fps) # Process up to 1 minute
229
+
230
+ # Track one snapshot per violation type
231
+ snapshot_taken = {"no_helmet": False, "no_harness": False, "unsafe_posture": False}
232
 
233
  while True:
234
  ret, frame = video.read()
 
239
  frame_count += 1
240
  continue
241
 
242
+ # Stop if processing time exceeds 25 seconds
243
+ if time.time() - start_time > CONFIG["MAX_PROCESSING_TIME"]:
244
+ logger.info("Processing time limit reached")
245
+ break
246
+
247
  results = model(frame, device=device)
248
  seen_violations = set()
249
  for result in results:
 
265
  }
266
  violations.append(violation)
267
 
268
+ # Save only one snapshot per violation type
269
+ if not snapshot_taken[label]:
270
+ snapshot_path = os.path.join(CONFIG["OUTPUT_DIR"], f"snapshot_{frame_count}_{label}.jpg")
271
+ cv2.imwrite(snapshot_path, frame)
272
+ with open(snapshot_path, "rb") as img_file:
273
+ img_base64 = base64.b64encode(img_file.read()).decode('utf-8')
274
+ snapshots.append({
275
+ "violation": label,
276
+ "frame": frame_count,
277
+ "snapshot_url": f"{CONFIG['PUBLIC_URL_BASE']}{os.path.basename(snapshot_path)}",
278
+ "snapshot_base64": f"data:image/jpeg;base64,{img_base64}"
279
+ })
280
+ snapshot_taken[label] = True
281
 
282
  frame_count += 1
283
 
 
285
  os.remove(video_path)
286
 
287
  if not violations:
288
+ logger.info("No violations detected")
289
  return {
290
  "violations": [],
291
  "snapshots": [],
 
352
  result["violation_details_url"] or "N/A"
353
  )
354
  except Exception as e:
355
+ logger.error(f"Error in Gradio interface: {e}")
356
  return f"Error: {str(e)}", "", "Error in processing.", "", ""
357
 
358
  interface = gr.Interface(