PrashanthB461 commited on
Commit
69350d2
·
verified ·
1 Parent(s): 473a717

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +113 -69
app.py CHANGED
@@ -1,5 +1,6 @@
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  import torch
 
3
  import cv2
4
  import requests
5
  import json
@@ -7,96 +8,125 @@ import time
7
  import os
8
  import numpy as np
9
  from pathlib import Path
 
 
10
 
11
  # --- Configuration ---
12
- RTSP_URL = "rtsp://your_rtsp_stream_url"
13
- SALESFORCE_URL = "https://your_salesforce_instance_url"
14
- SALESFORCE_TOKEN = "your_salesforce_access_token"
15
- HUGGINGFACE_API_URL = "https://huggingface.co/your_model_endpoint"
16
- HUGGINGFACE_TOKEN = "your_huggingface_api_token"
 
 
 
17
 
18
  # --- Initialize FastAPI app ---
19
  app = FastAPI()
20
 
21
  # --- YOLOv8 Model ---
22
  class YOLOv8Model:
23
- def __init__(self, model_name='yolov8'):
24
- self.model = torch.hub.load('ultralytics/yolov5', model_name) # YOLOv8 based on YOLOv5
25
- self.model.eval()
 
 
 
 
26
 
27
  def predict(self, image):
28
- results = self.model(image) # Inference
29
- return results.pandas().xywh[0] # Bounding boxes, class names, confidence score
30
-
 
 
 
31
 
32
  # --- Preprocessing RTSP Frame ---
33
  def preprocess_frame(frame):
34
- img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Convert frame to RGB
35
- img_resized = cv2.resize(img, (640, 640)) # Resize to 640x640 (YOLOv8 input size)
36
- return img_resized
37
-
 
 
 
38
 
39
  # --- RTSP Stream Handler ---
40
  def capture_rtsp_frames(rtsp_url: str):
41
- cap = cv2.VideoCapture(rtsp_url)
42
- while cap.isOpened():
43
- ret, frame = cap.read()
44
- if ret:
45
- timestamp = time.time()
46
- yield frame, timestamp
47
- else:
48
- break
49
- cap.release()
50
-
 
 
 
 
 
 
 
51
 
52
  # --- Save Violations and Snapshots ---
53
  def save_snapshot(frame):
54
- filename = f"snapshot_{int(time.time())}.jpg"
55
- snapshot_path = Path("./snapshots") / filename
56
- os.makedirs("./snapshots", exist_ok=True)
57
- cv2.imwrite(str(snapshot_path), frame)
58
- return f"http://localhost/snapshots/{filename}" # URL for local testing
59
-
 
 
 
60
 
61
  def log_violation(violation_data):
62
- log_file = Path("./violation_logs.json")
63
- if log_file.exists():
64
- with open(log_file, "r") as f:
65
- logs = json.load(f)
66
- else:
67
  logs = []
68
-
69
- logs.append(violation_data)
70
-
71
- with open(log_file, "w") as f:
72
- json.dump(logs, f, indent=4)
73
-
 
 
 
74
 
75
  # --- Notification System ---
76
  def send_alert(violation):
77
- print(f"Alert! {violation['violation_type']} detected. Severity: {violation['severity']}")
78
- # Implement your notification logic (email/SMS) here
79
-
80
 
81
  # --- Salesforce Integration ---
82
  def create_salesforce_violation_record(violation_data):
83
- salesforce_url = f"{SALESFORCE_URL}/services/data/vXX.0/sobjects/Safety_Violation_Log__c/"
84
- headers = {
85
- 'Authorization': f'Bearer {SALESFORCE_TOKEN}',
86
- 'Content-Type': 'application/json'
87
- }
88
- violation_obj = {
89
- 'Site_ID__c': violation_data['site_id'],
90
- 'Violation_Type__c': violation_data['violation_type'],
91
- 'Timestamp__c': violation_data['timestamp'],
92
- 'Snapshot_URL__c': violation_data['snapshot_url'],
93
- 'Severity__c': violation_data['severity'],
94
- 'Alert_Sent__c': True,
95
- 'Resolved__c': False
96
- }
97
- response = requests.post(salesforce_url, headers=headers, data=json.dumps(violation_obj))
98
- return response.json()
99
-
 
 
 
 
 
100
 
101
  # --- API Routes ---
102
  @app.post("/detect_violation/")
@@ -108,21 +138,24 @@ async def detect_violation():
108
  results = model.predict(frame_processed)
109
 
110
  for index, row in results.iterrows():
 
111
  violation = {
112
- 'site_id': "Site1", # Placeholder, should be dynamic
 
113
  'violation_type': row['name'],
114
  'timestamp': timestamp,
115
  'snapshot_url': save_snapshot(frame),
116
- 'severity': row['confidence']
117
  }
118
- create_salesforce_violation_record(violation) # Log to Salesforce
119
- send_alert(violation) # Send alert to site HSE
 
120
 
121
  return {"status": "Violation detection complete."}
122
  except Exception as e:
 
123
  raise HTTPException(status_code=500, detail=f"Error processing stream: {e}")
124
 
125
-
126
  @app.post("/upload_image/")
127
  async def upload_image(file: UploadFile = File(...)):
128
  try:
@@ -132,9 +165,20 @@ async def upload_image(file: UploadFile = File(...)):
132
  results = model.predict(image)
133
  return {"results": results.to_dict()}
134
  except Exception as e:
 
135
  raise HTTPException(status_code=500, detail=f"Error processing image: {e}")
136
 
137
-
138
  @app.get("/health_check/")
139
  async def health_check():
140
- return {"status": "Running smoothly"}
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException
2
  import torch
3
+ from ultralytics import YOLO
4
  import cv2
5
  import requests
6
  import json
 
8
  import os
9
  import numpy as np
10
  from pathlib import Path
11
+ from datetime import datetime
12
+ import logging
13
 
14
  # --- Configuration ---
15
+ logging.basicConfig(level=logging.INFO)
16
+ logger = logging.getLogger(__name__)
17
+
18
+ RTSP_URL = os.getenv("RTSP_URL", "rtsp://localhost:8554/stream")
19
+ SALESFORCE_URL = os.getenv("SALESFORCE_URL", "https://your_salesforce_instance_url")
20
+ SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "your_salesforce_access_token")
21
+ HUGGINGFACE_API_URL = os.getenv("HUGGINGFACE_API_URL", "https://api-inference.huggingface.co/models/PrashanthB461/SafetyViolationAI1")
22
+ HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "your_huggingface_api_token")
23
 
24
  # --- Initialize FastAPI app ---
25
  app = FastAPI()
26
 
27
  # --- YOLOv8 Model ---
28
  class YOLOv8Model:
29
+ def __init__(self, model_path='yolov8n.pt'):
30
+ try:
31
+ self.model = YOLO(model_path) # Load YOLOv8 model
32
+ logger.info("YOLOv8 model loaded successfully")
33
+ except Exception as e:
34
+ logger.error(f"Failed to load YOLOv8 model: {e}")
35
+ raise
36
 
37
  def predict(self, image):
38
+ try:
39
+ results = self.model(image) # Inference
40
+ return results.pandas().xyxy[0] # Bounding boxes, class names, confidence score
41
+ except Exception as e:
42
+ logger.error(f"Prediction error: {e}")
43
+ raise
44
 
45
  # --- Preprocessing RTSP Frame ---
46
  def preprocess_frame(frame):
47
+ try:
48
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
49
+ img_resized = cv2.resize(img, (640, 640))
50
+ return img_resized
51
+ except Exception as e:
52
+ logger.error(f"Frame preprocessing error: {e}")
53
+ raise
54
 
55
  # --- RTSP Stream Handler ---
56
  def capture_rtsp_frames(rtsp_url: str):
57
+ try:
58
+ cap = cv2.VideoCapture(rtsp_url)
59
+ if not cap.isOpened():
60
+ logger.error(f"Failed to open RTSP stream: {rtsp_url}")
61
+ raise ValueError("RTSP stream not accessible")
62
+ while cap.isOpened():
63
+ ret, frame = cap.read()
64
+ if ret:
65
+ timestamp = datetime.utcnow().isoformat()
66
+ yield frame, timestamp
67
+ else:
68
+ logger.warning("Failed to read frame from RTSP stream")
69
+ break
70
+ cap.release()
71
+ except Exception as e:
72
+ logger.error(f"RTSP capture error: {e}")
73
+ raise
74
 
75
  # --- Save Violations and Snapshots ---
76
  def save_snapshot(frame):
77
+ try:
78
+ filename = f"snapshot_{int(time.time())}.jpg"
79
+ snapshot_path = Path("/snapshots") / filename
80
+ os.makedirs("/snapshots", exist_ok=True)
81
+ cv2.imwrite(str(snapshot_path), frame)
82
+ return f"/snapshots/{filename}" # Relative path for containerized env
83
+ except Exception as e:
84
+ logger.error(f"Snapshot saving error: {e}")
85
+ raise
86
 
87
  def log_violation(violation_data):
88
+ try:
89
+ log_file = Path("/snapshots/violation_logs.json")
 
 
 
90
  logs = []
91
+ if log_file.exists():
92
+ with open(log_file, "r") as f:
93
+ logs = json.load(f)
94
+ logs.append(violation_data)
95
+ with open(log_file, "w") as f:
96
+ json.dump(logs, f, indent=4)
97
+ except Exception as e:
98
+ logger.error(f"Violation logging error: {e}")
99
+ raise
100
 
101
  # --- Notification System ---
102
  def send_alert(violation):
103
+ logger.info(f"Alert! {violation['violation_type']} detected. Severity: {violation['severity']}")
104
+ # Placeholder for email/SMS notification logic
 
105
 
106
  # --- Salesforce Integration ---
107
  def create_salesforce_violation_record(violation_data):
108
+ try:
109
+ salesforce_url = f"{SALESFORCE_URL}/services/data/v60.0/sobjects/Safety_Violation_Log__c/"
110
+ headers = {
111
+ 'Authorization': f'Bearer {SALESFORCE_TOKEN}',
112
+ 'Content-Type': 'application/json'
113
+ }
114
+ violation_obj = {
115
+ 'Site_ID__c': violation_data['site_id'],
116
+ 'Camera_ID__c': violation_data['camera_id'],
117
+ 'Violation_Type__c': violation_data['violation_type'],
118
+ 'Timestamp__c': violation_data['timestamp'],
119
+ 'Snapshot_URL__c': violation_data['snapshot_url'],
120
+ 'Severity__c': violation_data['severity'],
121
+ 'Alert_Sent__c': True,
122
+ 'Resolved__c': False
123
+ }
124
+ response = requests.post(salesforce_url, headers=headers, data=json.dumps(violation_obj))
125
+ response.raise_for_status()
126
+ return response.json()
127
+ except Exception as e:
128
+ logger.error(f"Salesforce integration error: {e}")
129
+ raise
130
 
131
  # --- API Routes ---
132
  @app.post("/detect_violation/")
 
138
  results = model.predict(frame_processed)
139
 
140
  for index, row in results.iterrows():
141
+ severity = "Critical" if row['conf'] > 0.8 else "Moderate" if row['conf'] > 0.5 else "Minor"
142
  violation = {
143
+ 'site_id': "Site1",
144
+ 'camera_id': "Camera1",
145
  'violation_type': row['name'],
146
  'timestamp': timestamp,
147
  'snapshot_url': save_snapshot(frame),
148
+ 'severity': severity
149
  }
150
+ log_violation(violation)
151
+ create_salesforce_violation_record(violation)
152
+ send_alert(violation)
153
 
154
  return {"status": "Violation detection complete."}
155
  except Exception as e:
156
+ logger.error(f"Error processing stream: {e}")
157
  raise HTTPException(status_code=500, detail=f"Error processing stream: {e}")
158
 
 
159
  @app.post("/upload_image/")
160
  async def upload_image(file: UploadFile = File(...)):
161
  try:
 
165
  results = model.predict(image)
166
  return {"results": results.to_dict()}
167
  except Exception as e:
168
+ logger.error(f"Error processing image: {e}")
169
  raise HTTPException(status_code=500, detail=f"Error processing image: {e}")
170
 
 
171
  @app.get("/health_check/")
172
  async def health_check():
173
+ return {"status": "Running smoothly"}
174
+
175
+ @app.on_event("startup")
176
+ async def startup_event():
177
+ logger.info("FastAPI application starting up")
178
+ # Initialize any resources (e.g., check RTSP connection, model load)
179
+ try:
180
+ model = YOLOv8Model()
181
+ logger.info("Startup: YOLOv8 model initialized")
182
+ except Exception as e:
183
+ logger.error(f"Startup error: {e}")
184
+ raise