sparsh007 commited on
Commit
220997f
·
verified ·
1 Parent(s): 108d1fd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -166
app.py CHANGED
@@ -1,231 +1,161 @@
1
  import gradio as gr
2
- from azure.storage.blob import BlobServiceClient, BlobClient
3
  import os
4
  import cv2
5
  import tempfile
6
  from ultralytics import YOLO
7
  import logging
8
- from datetime import datetime
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
- # Azure Storage Configuration
15
- AZURE_ACCOUNT_NAME = "assentian"
16
- AZURE_SAS_TOKEN = "sv=2024-11-04&ss=bfqt&srt=sco&sp=rwdlacupiytfx&se=2025-04-30T04:25:22Z&st=2025-04-16T20:25:22Z&spr=https&sig=HYrJBoOYc4PRe%2BoqBMl%2FmoL5Kz4ZYugbTLuEh63sbeo%3D"
17
- CONTAINER_NAME = "logs"
18
- VIDEO_PREFIX = ""
 
19
 
20
- # Initialize YOLO Model
21
  try:
22
- YOLO_MODEL = YOLO("./best_yolov11 (1).pt")
23
- logger.info("YOLO model loaded successfully")
24
  except Exception as e:
25
- logger.error(f"Failed to load YOLO model: {e}")
26
  raise
27
 
28
- # Azure Blob Service Client
29
- def get_blob_service_client():
30
  return BlobServiceClient(
31
- account_url=f"https://{AZURE_ACCOUNT_NAME}.blob.core.windows.net",
32
- credential=AZURE_SAS_TOKEN
33
  )
34
 
35
- def list_azure_videos():
36
  try:
37
- blob_service_client = get_blob_service_client()
38
- container_client = blob_service_client.get_container_client(CONTAINER_NAME)
39
- blobs = container_client.list_blobs(name_starts_with=VIDEO_PREFIX)
40
- return [blob.name for blob in blobs if blob.name.lower().endswith(".mp4")]
 
 
41
  except Exception as e:
42
  logger.error(f"Error listing videos: {e}")
43
  return []
44
 
45
- def get_latest_azure_video():
46
  try:
47
- blob_service_client = get_blob_service_client()
48
- container_client = blob_service_client.get_container_client(CONTAINER_NAME)
49
- blobs = container_client.list_blobs(name_starts_with=VIDEO_PREFIX)
50
-
51
- latest_blob = None
52
- latest_time = None
53
-
54
- for blob in blobs:
55
- if blob.name.lower().endswith(".mp4"):
56
- blob_client = container_client.get_blob_client(blob.name)
57
- properties = blob_client.get_blob_properties()
58
- if not latest_time or properties.last_modified > latest_time:
59
- latest_time = properties.last_modified
60
- latest_blob = blob.name
61
-
62
- return latest_blob if latest_blob else None
63
- except Exception as e:
64
- logger.error(f"Error finding latest video: {e}")
65
- return None
66
-
67
- def download_azure_video(blob_name):
68
- try:
69
- blob_service_client = get_blob_service_client()
70
- blob_client = blob_service_client.get_blob_client(
71
- container=CONTAINER_NAME,
72
  blob=blob_name
73
  )
74
 
75
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_file:
76
- download_stream = blob_client.download_blob()
77
- temp_file.write(download_stream.readall())
78
- return temp_file.name
79
-
80
  except Exception as e:
81
  logger.error(f"Download failed: {e}")
82
  return None
83
 
84
- def annotate_video(input_video_path):
85
  try:
86
- if not input_video_path or not os.path.exists(input_video_path):
87
- logger.error("Invalid input video path")
88
- return None
89
-
90
- cap = cv2.VideoCapture(input_video_path)
91
- if not cap.isOpened():
92
- logger.error("Failed to open video file")
93
- return None
94
-
95
- # Video writer setup
96
- frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
97
- frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
98
  fps = cap.get(cv2.CAP_PROP_FPS)
 
 
99
 
100
- with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as temp_output:
101
- output_path = temp_output.name
 
102
 
103
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
104
- writer = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
105
-
106
- # Processing loop
107
- while cap.isOpened():
108
  ret, frame = cap.read()
109
  if not ret:
110
  break
111
-
112
- # YOLO inference
113
- results = YOLO_MODEL(frame)
114
- class_counts = {}
115
-
116
  for result in results:
117
  for box in result.boxes:
118
- cls_id = int(box.cls[0])
119
- conf = float(box.conf[0])
120
- if conf < 0.5:
121
  continue
122
-
123
- # Bounding box
124
- x1, y1, x2, y2 = map(int, box.xyxy[0])
125
- class_name = YOLO_MODEL.names[cls_id]
126
- color = (0, 255, 0) # BGR format
127
-
128
- # Draw rectangle
129
- cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
130
 
131
- # Text label
132
- label = f"{class_name} {conf:.2f}"
133
- cv2.putText(frame, label, (x1, y1 - 10),
134
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
135
 
136
- # Update counts
137
- class_counts[class_name] = class_counts.get(class_name, 0) + 1
138
-
139
- # Add summary overlay
140
- summary_text = " | ".join([f"{k}: {v}" for k, v in class_counts.items()])
141
- cv2.putText(frame, summary_text, (10, 30),
142
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
143
 
144
  writer.write(frame)
145
-
146
  # Cleanup
147
  cap.release()
148
  writer.release()
149
- os.remove(input_video_path)
150
 
151
- return output_path
152
-
153
- except Exception as e:
154
- logger.error(f"Annotation failed: {e}")
155
- if 'cap' in locals(): cap.release()
156
- if 'writer' in locals(): writer.release()
157
- return None
158
-
159
- def process_video(blob_name):
160
- try:
161
- local_path = download_azure_video(blob_name)
162
- if not local_path:
163
- return None
164
- return annotate_video(local_path)
165
  except Exception as e:
166
  logger.error(f"Processing failed: {e}")
167
- return None
168
 
169
  # Gradio Interface
170
- with gr.Blocks(title="PRISM Video Annotator", theme=gr.themes.Soft()) as demo:
171
- gr.Markdown("# 🎥 PRISM Site Diary - Video Analyzer")
172
 
173
  with gr.Row():
174
- with gr.Column(scale=1):
175
- gr.Markdown("## Azure Storage Controls")
176
- refresh_btn = gr.Button("🔄 Refresh Video List", variant="secondary")
177
- video_dropdown = gr.Dropdown(
178
  label="Available Videos",
179
- choices=list_azure_videos(),
180
- interactive=True
181
- )
182
- latest_btn = gr.Button("⏩ Process Latest Video", variant="primary")
183
- selected_btn = gr.Button("✅ Process Selected Video", variant="primary")
184
-
185
- with gr.Column(scale=2):
186
- gr.Markdown("## Annotated Output")
187
- output_video = gr.Video(
188
- label="Processed Video",
189
- format="mp4",
190
- interactive=False
191
  )
192
- status = gr.Textbox(label="Processing Status")
193
-
194
- def update_ui():
195
- new_choices = list_azure_videos()
196
- return gr.Dropdown.update(choices=new_choices)
197
-
198
- def handle_latest():
199
- latest = get_latest_azure_video()
200
- if latest:
201
- output = process_video(latest)
202
- return output if output else None
203
- return None
204
 
205
  # Event handlers
206
- refresh_btn.click(
207
- fn=update_ui,
208
- outputs=video_dropdown,
209
- queue=False
210
- )
211
 
212
- latest_btn.click(
213
- fn=handle_latest,
214
- outputs=output_video,
215
- api_name="process_latest"
216
- )
 
 
 
 
 
 
 
217
 
218
- selected_btn.click(
219
- fn=lambda x: process_video(x),
220
- inputs=video_dropdown,
221
- outputs=output_video,
222
- api_name="process_selected"
223
  )
224
 
225
  if __name__ == "__main__":
226
- demo.launch(
227
- server_name="0.0.0.0",
228
- server_port=7860,
229
- show_error=True,
230
- share=False
231
- )
 
1
  import gradio as gr
2
+ from azure.storage.blob import BlobServiceClient
3
  import os
4
  import cv2
5
  import tempfile
6
  from ultralytics import YOLO
7
  import logging
8
+ import time
9
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
+ # Azure Configuration
15
+ AZURE_CONFIG = {
16
+ "account_name": "assentian",
17
+ "sas_token": "sv=2024-11-04&ss=bfqt&srt=sco&sp=rwdlacupiytfx&se=2025-04-30T04:25:22Z&st=2025-04-16T20:25:22Z&spr=https&sig=HYrJBoOYc4PRe%2BoqBMl%2FmoL5Kz4ZYugbTLuEh63sbeo%3D",
18
+ "container_name": "logs"
19
+ }
20
 
21
+ # YOLO Model
22
  try:
23
+ MODEL = YOLO("./best_yolov11 (1).pt")
24
+ logger.info("Model loaded successfully")
25
  except Exception as e:
26
+ logger.error(f"Model loading failed: {e}")
27
  raise
28
 
29
+ def get_azure_client():
 
30
  return BlobServiceClient(
31
+ account_url=f"https://{AZURE_CONFIG['account_name']}.blob.core.windows.net",
32
+ credential=AZURE_CONFIG['sas_token']
33
  )
34
 
35
+ def list_videos():
36
  try:
37
+ client = get_azure_client()
38
+ container = client.get_container_client(AZURE_CONFIG['container_name'])
39
+ return [
40
+ blob.name for blob in container.list_blobs()
41
+ if blob.name.lower().endswith(".mp4")
42
+ ]
43
  except Exception as e:
44
  logger.error(f"Error listing videos: {e}")
45
  return []
46
 
47
+ def get_video(blob_name):
48
  try:
49
+ client = get_azure_client()
50
+ blob = client.get_blob_client(
51
+ container=AZURE_CONFIG['container_name'],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  blob=blob_name
53
  )
54
 
55
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as f:
56
+ f.write(blob.download_blob().readall())
57
+ return f.name
 
 
58
  except Exception as e:
59
  logger.error(f"Download failed: {e}")
60
  return None
61
 
62
+ def process_video(input_path, progress=gr.Progress()):
63
  try:
64
+ # Validate input
65
+ if not input_path or not os.path.exists(input_path):
66
+ return None, "Invalid input video"
67
+
68
+ # Video setup
69
+ cap = cv2.VideoCapture(input_path)
70
+ frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
 
 
 
 
71
  fps = cap.get(cv2.CAP_PROP_FPS)
72
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
73
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
74
 
75
+ # Output setup
76
+ output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
77
+ writer = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
78
 
79
+ # Process frames
80
+ progress(0, desc="Starting processing...")
81
+ for frame_num in progress.tqdm(range(frame_count), desc="Processing"):
 
 
82
  ret, frame = cap.read()
83
  if not ret:
84
  break
85
+
86
+ # YOLO detection
87
+ results = MODEL(frame)
 
 
88
  for result in results:
89
  for box in result.boxes:
90
+ if box.conf.item() < 0.5:
 
 
91
  continue
 
 
 
 
 
 
 
 
92
 
93
+ # Draw bounding box
94
+ x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
95
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0,255,0), 2)
 
96
 
97
+ # Add label
98
+ label = f"{MODEL.names[int(box.cls)]} {box.conf:.2f}"
99
+ cv2.putText(frame, label, (x1, y1-10),
100
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255,255,255), 2)
 
 
 
101
 
102
  writer.write(frame)
103
+
104
  # Cleanup
105
  cap.release()
106
  writer.release()
107
+ os.remove(input_path)
108
 
109
+ return output_file, "Processing completed successfully"
110
+
 
 
 
 
 
 
 
 
 
 
 
 
111
  except Exception as e:
112
  logger.error(f"Processing failed: {e}")
113
+ return None, f"Error: {str(e)}"
114
 
115
  # Gradio Interface
116
+ with gr.Blocks(theme=gr.themes.Soft()) as app:
117
+ gr.Markdown("# Video Processing Interface")
118
 
119
  with gr.Row():
120
+ with gr.Column():
121
+ gr.Markdown("## Video Selection")
122
+ video_select = gr.Dropdown(
 
123
  label="Available Videos",
124
+ choices=list_videos(),
125
+ interactive=True,
126
+ filterable=False
 
 
 
 
 
 
 
 
 
127
  )
128
+ refresh_btn = gr.Button("🔄 Refresh List")
129
+ process_btn = gr.Button("🚀 Process Selected Video", variant="primary")
130
+
131
+ with gr.Column():
132
+ gr.Markdown("## Output")
133
+ video_output = gr.Video(label="Processed Video", format="mp4")
134
+ status = gr.Textbox(label="Status", interactive=False)
 
 
 
 
 
135
 
136
  # Event handlers
137
+ def refresh_list():
138
+ return gr.Dropdown.update(choices=list_videos())
 
 
 
139
 
140
+ def handle_process(blob_name):
141
+ start_time = time.time()
142
+ if not blob_name:
143
+ return None, "No video selected!"
144
+
145
+ local_path = get_video(blob_name)
146
+ if not local_path:
147
+ return None, "Download failed!"
148
+
149
+ result, message = process_video(local_path)
150
+ duration = time.time() - start_time
151
+ return result, f"{message} | Time: {duration:.1f}s"
152
 
153
+ refresh_btn.click(refresh_list, outputs=video_select)
154
+ process_btn.click(
155
+ handle_process,
156
+ inputs=video_select,
157
+ outputs=[video_output, status]
158
  )
159
 
160
  if __name__ == "__main__":
161
+ app.launch(server_name="0.0.0.0", server_port=7860)