PrashanthB461 commited on
Commit
c3ce28d
·
verified ·
1 Parent(s): d157456

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -73
app.py CHANGED
@@ -2,21 +2,6 @@ import os
2
  import shutil
3
  import sys
4
  import warnings
5
-
6
- # Suppress all warnings
7
- warnings.filterwarnings("ignore")
8
-
9
- # Clean up and set YOLO_CONFIG_DIR before any Ultralytics imports
10
- yolo_config_dir = "/tmp/Ultralytics"
11
- if os.path.exists(yolo_config_dir):
12
- shutil.rmtree(yolo_config_dir)
13
- os.makedirs(yolo_config_dir, exist_ok=True)
14
- os.environ["YOLO_CONFIG_DIR"] = yolo_config_dir
15
-
16
- # Redirect stdout temporarily to suppress Ultralytics initialization messages
17
- original_stdout = sys.stdout
18
- sys.stdout = open(os.devnull, 'w')
19
-
20
  from fastapi import FastAPI, File, UploadFile, HTTPException
21
  from contextlib import asynccontextmanager
22
  import torch
@@ -31,79 +16,61 @@ from datetime import datetime
31
  import logging
32
  import pandas as pd
33
 
34
- # Restore stdout
35
- sys.stdout = original_stdout
 
 
 
 
 
 
 
 
36
 
37
- # --- Configuration ---
 
38
  logging.basicConfig(
39
  level=logging.INFO,
40
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
41
- handlers=[
42
- logging.StreamHandler()
43
- ]
44
  )
 
 
45
  logger = logging.getLogger(__name__)
46
 
47
- # Suppress Ultralytics and other unnecessary logging
48
  logging.getLogger("ultralytics").setLevel(logging.WARNING)
49
  logging.getLogger("PIL").setLevel(logging.WARNING)
50
  logging.getLogger("matplotlib").setLevel(logging.WARNING)
 
51
 
52
- # Log environment variable for debugging
53
  logger.info(f"YOLO_CONFIG_DIR set to: {os.getenv('YOLO_CONFIG_DIR')}")
54
 
 
55
  RTSP_URL = os.getenv("RTSP_URL", "rtsp://localhost:8554/stream")
56
  SALESFORCE_URL = os.getenv("SALESFORCE_URL", "https://your_salesforce_instance_url")
57
  SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "your_salesforce_access_token")
58
  HUGGINGFACE_API_URL = os.getenv("HUGGINGFACE_API_URL", "https://api-inference.huggingface.co/models/PrashanthB461/SafetyViolationAI1")
59
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "your_huggingface_api_token")
60
 
61
- # Global model instance
62
  yolo_model = None
63
 
64
- # --- Lifespan Handler ---
65
- @asynccontextmanager
66
- async def lifespan(app: FastAPI):
67
- # Startup logic
68
- global yolo_model
69
- logger.info("FastAPI application starting up")
70
- try:
71
- # Suppress Ultralytics initialization messages
72
- with open(os.devnull, 'w') as f:
73
- original_stdout = sys.stdout
74
- sys.stdout = f
75
- yolo_model = YOLOv8Model()
76
- sys.stdout = original_stdout
77
- logger.info("YOLOv8 model initialized successfully")
78
- except Exception as e:
79
- logger.error(f"Startup error: {e}")
80
- raise
81
- yield
82
- # Shutdown logic
83
- logger.info("FastAPI application shutting down")
84
- yolo_model = None
85
-
86
- # --- Initialize FastAPI app ---
87
- app = FastAPI(
88
- lifespan=lifespan,
89
- title="Safety Violation Detection API",
90
- description="API for detecting safety violations using YOLOv8",
91
- version="1.0.0"
92
- )
93
-
94
- # --- YOLOv8 Model ---
95
  class YOLOv8Model:
96
  def __init__(self, model_path='yolov8n.pt'):
97
  try:
98
  logger.info("Initializing YOLOv8 model")
99
- # Ensure the config directory exists
100
- os.makedirs(os.environ["YOLO_CONFIG_DIR"], exist_ok=True)
101
 
102
- # Suppress model loading messages
103
- with open(os.devnull, 'w') as f:
104
- original_stdout = sys.stdout
105
- sys.stdout = f
106
- self.model = YOLO(model_path) # Load YOLOv8 model
 
 
 
107
  sys.stdout = original_stdout
108
 
109
  logger.info("YOLOv8 model loaded successfully")
@@ -113,14 +80,13 @@ class YOLOv8Model:
113
 
114
  def predict(self, image):
115
  try:
116
- # Run prediction
117
- results = self.model(image) # Inference
118
- return results.pandas().xyxy[0] # Bounding boxes, class names, confidence score
119
  except Exception as e:
120
  logger.error(f"Prediction error: {e}")
121
  raise
122
 
123
- # --- Preprocessing RTSP Frame ---
124
  def preprocess_frame(frame):
125
  try:
126
  img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
@@ -130,13 +96,13 @@ def preprocess_frame(frame):
130
  logger.error(f"Frame preprocessing error: {e}")
131
  raise
132
 
133
- # --- RTSP Stream Handler ---
134
  def capture_rtsp_frames(rtsp_url: str):
135
  try:
136
  cap = cv2.VideoCapture(rtsp_url)
137
  if not cap.isOpened():
138
  logger.error(f"Failed to open RTSP stream: {rtsp_url}")
139
  raise ValueError("RTSP stream not accessible")
 
140
  while cap.isOpened():
141
  ret, frame = cap.read()
142
  if ret:
@@ -150,14 +116,14 @@ def capture_rtsp_frames(rtsp_url: str):
150
  logger.error(f"RTSP capture error: {e}")
151
  raise
152
 
153
- # --- Save Violations and Snapshots ---
154
  def save_snapshot(frame):
155
  try:
156
  filename = f"snapshot_{int(time.time())}.jpg"
157
  snapshot_path = Path("/snapshots") / filename
158
  os.makedirs("/snapshots", exist_ok=True)
159
  cv2.imwrite(str(snapshot_path), frame)
160
- return f"/snapshots/{filename}" # Relative path for containerized env
161
  except Exception as e:
162
  logger.error(f"Snapshot saving error: {e}")
163
  raise
@@ -176,12 +142,9 @@ def log_violation(violation_data):
176
  logger.error(f"Violation logging error: {e}")
177
  raise
178
 
179
- # --- Notification System ---
180
  def send_alert(violation):
181
  logger.info(f"Alert! {violation['violation_type']} detected. Severity: {violation['severity']}")
182
- # Placeholder for email/SMS notification logic
183
 
184
- # --- Salesforce Integration ---
185
  def create_salesforce_violation_record(violation_data):
186
  try:
187
  salesforce_url = f"{SALESFORCE_URL}/services/data/v60.0/sobjects/Safety_Violation_Log__c/"
@@ -206,13 +169,36 @@ def create_salesforce_violation_record(violation_data):
206
  logger.error(f"Salesforce integration error: {e}")
207
  raise
208
 
209
- # --- API Routes ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
210
  @app.post("/detect_violation/")
211
  async def detect_violation():
212
  try:
213
  global yolo_model
214
  if yolo_model is None:
215
  raise HTTPException(status_code=500, detail="YOLO model not initialized")
 
216
  for frame, timestamp in capture_rtsp_frames(RTSP_URL):
217
  frame_processed = preprocess_frame(frame)
218
  results = yolo_model.predict(frame_processed)
@@ -242,6 +228,7 @@ async def upload_image(file: UploadFile = File(...)):
242
  global yolo_model
243
  if yolo_model is None:
244
  raise HTTPException(status_code=500, detail="YOLO model not initialized")
 
245
  image_data = await file.read()
246
  image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
247
  results = yolo_model.predict(image)
 
2
  import shutil
3
  import sys
4
  import warnings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  from fastapi import FastAPI, File, UploadFile, HTTPException
6
  from contextlib import asynccontextmanager
7
  import torch
 
16
  import logging
17
  import pandas as pd
18
 
19
+ # --- Initial Configuration ---
20
+ # Suppress all warnings
21
+ warnings.filterwarnings("ignore")
22
+
23
+ # Clean up and set YOLO_CONFIG_DIR
24
+ yolo_config_dir = "/tmp/Ultralytics"
25
+ if os.path.exists(yolo_config_dir):
26
+ shutil.rmtree(yolo_config_dir)
27
+ os.makedirs(yolo_config_dir, exist_ok=True)
28
+ os.environ["YOLO_CONFIG_DIR"] = yolo_config_dir
29
 
30
+ # --- Logging Configuration ---
31
+ # Configure logging before any other operations
32
  logging.basicConfig(
33
  level=logging.INFO,
34
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
35
+ handlers=[logging.StreamHandler()]
 
 
36
  )
37
+
38
+ # Get logger instance
39
  logger = logging.getLogger(__name__)
40
 
41
+ # Suppress third-party logging
42
  logging.getLogger("ultralytics").setLevel(logging.WARNING)
43
  logging.getLogger("PIL").setLevel(logging.WARNING)
44
  logging.getLogger("matplotlib").setLevel(logging.WARNING)
45
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
46
 
47
+ # Log environment variable once
48
  logger.info(f"YOLO_CONFIG_DIR set to: {os.getenv('YOLO_CONFIG_DIR')}")
49
 
50
+ # --- Environment Variables ---
51
  RTSP_URL = os.getenv("RTSP_URL", "rtsp://localhost:8554/stream")
52
  SALESFORCE_URL = os.getenv("SALESFORCE_URL", "https://your_salesforce_instance_url")
53
  SALESFORCE_TOKEN = os.getenv("SALESFORCE_TOKEN", "your_salesforce_access_token")
54
  HUGGINGFACE_API_URL = os.getenv("HUGGINGFACE_API_URL", "https://api-inference.huggingface.co/models/PrashanthB461/SafetyViolationAI1")
55
  HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN", "your_huggingface_api_token")
56
 
57
+ # --- Global Model Instance ---
58
  yolo_model = None
59
 
60
+ # --- Model Class ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  class YOLOv8Model:
62
  def __init__(self, model_path='yolov8n.pt'):
63
  try:
64
  logger.info("Initializing YOLOv8 model")
 
 
65
 
66
+ # Temporarily suppress stdout for model loading
67
+ original_stdout = sys.stdout
68
+ sys.stdout = open(os.devnull, 'w')
69
+
70
+ try:
71
+ self.model = YOLO(model_path)
72
+ finally:
73
+ sys.stdout.close()
74
  sys.stdout = original_stdout
75
 
76
  logger.info("YOLOv8 model loaded successfully")
 
80
 
81
  def predict(self, image):
82
  try:
83
+ results = self.model(image)
84
+ return results.pandas().xyxy[0]
 
85
  except Exception as e:
86
  logger.error(f"Prediction error: {e}")
87
  raise
88
 
89
+ # --- Frame Processing Functions ---
90
  def preprocess_frame(frame):
91
  try:
92
  img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
96
  logger.error(f"Frame preprocessing error: {e}")
97
  raise
98
 
 
99
  def capture_rtsp_frames(rtsp_url: str):
100
  try:
101
  cap = cv2.VideoCapture(rtsp_url)
102
  if not cap.isOpened():
103
  logger.error(f"Failed to open RTSP stream: {rtsp_url}")
104
  raise ValueError("RTSP stream not accessible")
105
+
106
  while cap.isOpened():
107
  ret, frame = cap.read()
108
  if ret:
 
116
  logger.error(f"RTSP capture error: {e}")
117
  raise
118
 
119
+ # --- Violation Handling Functions ---
120
  def save_snapshot(frame):
121
  try:
122
  filename = f"snapshot_{int(time.time())}.jpg"
123
  snapshot_path = Path("/snapshots") / filename
124
  os.makedirs("/snapshots", exist_ok=True)
125
  cv2.imwrite(str(snapshot_path), frame)
126
+ return f"/snapshots/{filename}"
127
  except Exception as e:
128
  logger.error(f"Snapshot saving error: {e}")
129
  raise
 
142
  logger.error(f"Violation logging error: {e}")
143
  raise
144
 
 
145
  def send_alert(violation):
146
  logger.info(f"Alert! {violation['violation_type']} detected. Severity: {violation['severity']}")
 
147
 
 
148
  def create_salesforce_violation_record(violation_data):
149
  try:
150
  salesforce_url = f"{SALESFORCE_URL}/services/data/v60.0/sobjects/Safety_Violation_Log__c/"
 
169
  logger.error(f"Salesforce integration error: {e}")
170
  raise
171
 
172
+ # --- FastAPI Application ---
173
+ @asynccontextmanager
174
+ async def lifespan(app: FastAPI):
175
+ global yolo_model
176
+ logger.info("FastAPI application starting up")
177
+ try:
178
+ yolo_model = YOLOv8Model()
179
+ logger.info("YOLOv8 model initialized successfully")
180
+ except Exception as e:
181
+ logger.error(f"Startup error: {e}")
182
+ raise
183
+ yield
184
+ logger.info("FastAPI application shutting down")
185
+ yolo_model = None
186
+
187
+ app = FastAPI(
188
+ lifespan=lifespan,
189
+ title="Safety Violation Detection API",
190
+ description="API for detecting safety violations using YOLOv8",
191
+ version="1.0.0"
192
+ )
193
+
194
+ # --- API Endpoints ---
195
  @app.post("/detect_violation/")
196
  async def detect_violation():
197
  try:
198
  global yolo_model
199
  if yolo_model is None:
200
  raise HTTPException(status_code=500, detail="YOLO model not initialized")
201
+
202
  for frame, timestamp in capture_rtsp_frames(RTSP_URL):
203
  frame_processed = preprocess_frame(frame)
204
  results = yolo_model.predict(frame_processed)
 
228
  global yolo_model
229
  if yolo_model is None:
230
  raise HTTPException(status_code=500, detail="YOLO model not initialized")
231
+
232
  image_data = await file.read()
233
  image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
234
  results = yolo_model.predict(image)