# sparsh007's picture
# Update app.py
# ec84a0f verified
import gradio as gr
from azure.storage.blob import BlobServiceClient
import os
import cv2
import tempfile
from ultralytics import YOLO
import logging
import time
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Azure Configuration
# SECURITY: the SAS token is a credential and must not be committed to source
# control. It is read from the AZURE_SAS_TOKEN environment variable (e.g. a
# deployment secret). The token that used to be hardcoded here was committed
# to the repository; treat it as compromised and revoke/rotate it in Azure.
# When the variable is unset the credential is None, so the client falls back
# to anonymous access, which only works for public containers.
AZURE_CONFIG = {
    "account_name": os.environ.get("AZURE_ACCOUNT_NAME", "assentian"),
    "sas_token": os.environ.get("AZURE_SAS_TOKEN") or None,
    "container_name": os.environ.get("AZURE_CONTAINER_NAME", "logs"),
    # Blobs larger than this (in MiB) are rejected by validate_video_size.
    "max_size_mb": 500,
}
if not AZURE_CONFIG["sas_token"]:
    logger.warning("AZURE_SAS_TOKEN is not set; Azure Blob access may fail.")
# YOLO Model Configuration
MODEL_CONFIG = dict(
    # Weights file; "./" makes it relative to the current working directory.
    model_path="./best_yolov11 (1).pt",
    # Detections with a confidence below this are not drawn on the frame.
    conf_threshold=0.5,
    # Process every frame for testing.
    # NOTE(review): no code visible in this file reads this key.
    frame_skip=0,
)
# Initialize YOLO Model
# The weights are loaded once at import time. A failure is logged and then
# re-raised, so the app refuses to start without a model.
try:
    MODEL = YOLO(MODEL_CONFIG["model_path"])
except Exception as e:
    logger.error(f"Model loading failed: {e}")
    raise
else:
    logger.info(f"Loaded YOLO model: {MODEL_CONFIG['model_path']}")
def get_azure_client():
    """Create a Blob service client for the configured storage account.

    The account URL is derived from ``AZURE_CONFIG['account_name']`` and the
    client authenticates with ``AZURE_CONFIG['sas_token']``. A new client is
    built on every call.
    """
    account = AZURE_CONFIG['account_name']
    service_url = f"https://{account}.blob.core.windows.net"
    return BlobServiceClient(account_url=service_url, credential=AZURE_CONFIG['sas_token'])
def list_videos():
    """Return the names of all ``.mp4`` blobs in the configured container.

    The extension check is case-insensitive. Any error (auth, network, missing
    container) is logged, and an empty list is returned so the UI can still
    render.
    """
    try:
        service = get_azure_client()
        container_client = service.get_container_client(AZURE_CONFIG['container_name'])
        mp4_names = []
        for item in container_client.list_blobs():
            if item.name.lower().endswith(".mp4"):
                mp4_names.append(item.name)
        return mp4_names
    except Exception as e:
        logger.error(f"Error listing videos: {e}")
        return []
def validate_video_size(blob_client):
    """Reject blobs larger than ``AZURE_CONFIG["max_size_mb"]`` MiB.

    Args:
        blob_client: Client for the blob; only its properties are fetched.

    Raises:
        ValueError: If the blob's size exceeds the configured limit.
    """
    size_in_mb = blob_client.get_blob_properties().size / (1024 * 1024)
    limit_mb = AZURE_CONFIG["max_size_mb"]
    if size_in_mb > limit_mb:
        raise ValueError(f"Video exceeds {limit_mb}MB limit")
def download_video(blob_name):
    """Download a blob from the configured container into a local temp file.

    The blob's size is checked with ``validate_video_size`` before any data is
    transferred. The payload is written to disk chunk by chunk instead of being
    read in a single call.

    Args:
        blob_name: Name of the blob inside ``AZURE_CONFIG['container_name']``.

    Returns:
        Path of the downloaded ``.mp4`` temp file, or ``None`` if validation or
        the download failed. The caller owns the returned file and is
        responsible for deleting it. A partially written file is removed
        before returning ``None``.
    """
    local_path = None
    try:
        client = get_azure_client()
        blob = client.get_blob_client(
            container=AZURE_CONFIG['container_name'],
            blob=blob_name,
        )
        validate_video_size(blob)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as f:
            local_path = f.name
            for chunk in blob.download_blob().chunks():
                f.write(chunk)
        return local_path
    except Exception as e:
        logger.error(f"Download failed: {e}")
        # With delete=False nothing else removes a partial download, so
        # clean it up here to avoid leaking temp files on every failure.
        if local_path and os.path.exists(local_path):
            try:
                os.remove(local_path)
            except OSError as err:
                logger.warning(f"Could not remove partial download {local_path}: {err}")
        return None
def _annotate_frame(frame):
    """Run the YOLO model on one frame and draw its confident detections in place.

    Detections with confidence below ``MODEL_CONFIG["conf_threshold"]`` are
    skipped. Each kept detection gets a green 2px rectangle and a white
    ``"<class name> <confidence>"`` label drawn 10px above its top-left corner.

    Args:
        frame: Image array as returned by ``cv2.VideoCapture.read``; it is
            modified in place.

    Returns:
        The same ``frame`` object.
    """
    threshold = MODEL_CONFIG["conf_threshold"]
    for result in MODEL(frame, verbose=False):
        for box in result.boxes:
            conf = box.conf.item()
            if conf < threshold:
                continue
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            class_name = MODEL.names[int(box.cls.item())]
            cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            cv2.putText(frame, f"{class_name} {conf:.2f}", (x1, y1 - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
    return frame


def process_video(input_path, progress=gr.Progress(), max_frames=200):
    """Annotate the first ``max_frames`` frames of a video with YOLO detections.

    Annotated frames are written with the ``mp4v`` codec to a new temporary
    ``.mp4`` file, at the source resolution and frame rate. The input file is
    treated as disposable and is deleted when processing ends, on both success
    and failure.

    Args:
        input_path: Path to a local video file (typically a temp download).
        progress: Gradio progress tracker, updated every 5 frames.
        max_frames: Maximum number of frames to process (test-mode cap).

    Returns:
        ``(output_path, status_message)`` on success, or
        ``(None, "Error: <reason>")`` on failure. Errors are logged, not raised.
    """
    cap = None
    writer = None
    output_file = None
    try:
        if not input_path or not os.path.exists(input_path):
            raise ValueError("Invalid input video path")
        cap = cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise RuntimeError("Failed to open video file")

        # CAP_PROP_FRAME_COUNT can be 0 or negative for some containers. Fall
        # back to the cap so the progress fraction never divides by zero.
        reported_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        expected_frames = min(reported_frames, max_frames) if reported_frames > 0 else max_frames
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        if not source_fps > 0:  # 0 (or NaN) when the container lacks timing info
            source_fps = 30.0   # arbitrary default so VideoWriter gets a valid rate
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        # mkstemp + close avoids the open handle that
        # NamedTemporaryFile(...).name would leak.
        fd, output_file = tempfile.mkstemp(suffix=".mp4")
        os.close(fd)
        writer = cv2.VideoWriter(output_file,
                                 cv2.VideoWriter_fourcc(*'mp4v'),
                                 source_fps,
                                 (width, height))
        if not writer.isOpened():
            raise RuntimeError("Failed to open video writer")

        processed_frames = 0
        progress(0, desc=f"Processing first {max_frames} frames...")
        start_time = time.perf_counter()
        # NOTE: MODEL_CONFIG["frame_skip"] is not applied; every frame is annotated.
        while processed_frames < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            writer.write(_annotate_frame(frame))
            processed_frames += 1
            if processed_frames % 5 == 0:
                # Clamp: the reported frame count can under-estimate the real one.
                progress(min(processed_frames / expected_frames, 1.0),
                         desc=f"Processed {processed_frames}/{expected_frames} frames")

        duration = time.perf_counter() - start_time
        throughput = processed_frames / duration if duration > 0 else 0
        # Release now so the output file is finalized before the caller reads it.
        writer.release()
        writer = None
        return output_file, f"Processed {processed_frames} frames in {duration:.1f}s ({throughput:.1f} FPS)"
    except Exception as e:
        logger.exception("Processing failed: %s", e)
        # Close the writer before deleting its (incomplete) output file.
        if writer is not None:
            writer.release()
            writer = None
        if output_file and os.path.exists(output_file):
            try:
                os.remove(output_file)
            except OSError as err:
                logger.warning(f"Could not remove {output_file}: {err}")
        return None, f"Error: {str(e)}"
    finally:
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()
        if input_path and os.path.exists(input_path):
            try:
                os.remove(input_path)
            except OSError as err:
                logger.warning(f"Could not remove input {input_path}: {err}")
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft(), title="PRISM Video Analyzer") as app:
    gr.Markdown("# 🏗️ PRISM Site Diary - Video Analysis (TEST MODE: 200 Frames)")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## Video Selection")
            # Choices are fetched once when the UI is built; use Refresh to re-query.
            video_select = gr.Dropdown(
                label="Available Videos",
                choices=list_videos(),
                filterable=False,
            )
            refresh_btn = gr.Button("🔄 Refresh List", variant="secondary")
            process_btn = gr.Button("🚀 Process First 200 Frames", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown("## Results")
            video_output = gr.Video(
                label="Processed Video",
                format="mp4",
                interactive=False,
            )
            status = gr.Textbox(
                label="Status",
                value="Ready to process first 200 frames",
                interactive=False,
            )

    def refresh_video_list():
        """Re-query Azure and replace the dropdown's choices."""
        # `gr.Dropdown.update` was removed in Gradio 4 (this file already uses
        # the 4.x-only `filterable=` argument); `gr.update` works in 3.x and 4.x.
        return gr.update(choices=list_videos())

    def handle_video_processing(blob_name, progress=gr.Progress()):
        """Download the selected blob, annotate it, and return ``(video, status)``.

        Declaring ``progress`` here lets Gradio inject a tracked progress
        object, which is forwarded so process_video's updates reach the UI.
        """
        if not blob_name:
            return None, "No video selected!"
        local_path = None
        try:
            local_path = download_video(blob_name)
            if not local_path:
                return None, "Download failed"
            result, message = process_video(local_path, progress=progress)
            return result, message
        except Exception as e:
            logger.error(f"Processing error: {e}")
            return None, f"Error: {str(e)}"
        finally:
            # The download is a temp file. Remove it if processing did not
            # already do so, e.g. after an error.
            if local_path and os.path.exists(local_path):
                try:
                    os.remove(local_path)
                except OSError as err:
                    logger.warning(f"Could not remove {local_path}: {err}")

    refresh_btn.click(refresh_video_list, outputs=video_select)
    process_btn.click(
        handle_video_processing,
        inputs=video_select,
        outputs=[video_output, status],
        queue=True,
    )
if __name__ == "__main__":
    # Bind to every interface on port 7860; show_error surfaces handler
    # exceptions in the browser UI.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, show_error=True)
    app.launch(**launch_options)