# Author: DSatishchandra
# Last change: "Update app.py" (commit 59c6524, verified)
import html
import io
import os
import time
from datetime import datetime

import cv2
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
from PIL import Image

from config.settings import VIDEO_FOLDER
from models.solar_model import load_solar_model
from models.windmill_model import load_windmill_model
from services.anomaly_service import predict_anomaly, track_anomalies
from services.detection_service import detect_faults_solar, detect_faults_windmill
# Page configuration. Streamlit requires set_page_config to be the first
# Streamlit command executed in the script, so keep it above every other st.* call.
st.set_page_config(page_title="Thermal Anomaly Monitoring Dashboard", layout="wide")
# Custom CSS for the dashboard. The class names defined here (main-header, status,
# section-title, section-box, log-entry, metrics-text, snapshot-img) are referenced
# by the raw-HTML markdown emitted in main(); unsafe_allow_html is required for the
# <style> tag to be injected.
st.markdown(
"""
<style>
.main-header {
text-align: center;
font-size: 24px;
font-weight: bold;
color: #333;
}
.status {
text-align: center;
font-size: 16px;
color: #333;
margin-bottom: 20px;
}
.section-title {
font-size: 16px;
font-weight: bold;
color: #333;
text-transform: uppercase;
margin-bottom: 10px;
}
.section-box {
border: 1px solid #4A90E2;
padding: 10px;
border-radius: 5px;
margin-bottom: 20px;
}
.log-entry {
font-size: 14px;
color: #333;
margin-bottom: 5px;
}
.metrics-text {
font-size: 14px;
color: #333;
margin-bottom: 5px;
}
.snapshot-img {
max-width: 100%;
height: auto;
margin-bottom: 10px;
}
</style>
""",
unsafe_allow_html=True
)
# Initialize session state.
# Each key is seeded independently, not all of them behind a single check for
# 'paused'. That way, a key added in a later version of the app is still created
# for browser sessions that were already open. Under the all-or-nothing check,
# such a session would fail on its first access to the new key.
# The list/dict literals below are rebuilt every time Streamlit re-executes this
# script, so sessions never share mutable default objects.
_SESSION_DEFAULTS = {
    "paused": False,
    "frame_rate": 0.2,  # seconds slept between frames; default adjusted to reduce flicker
    "frame_count": 0,  # index of the next video frame to read
    "logs": [],
    "anomaly_counts": [],
    "frame_numbers": [],
    "total_detected": 0,
    "snapshots": [],
    "last_frame": None,
    "last_metrics": {},
    "last_timestamp": "",
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default

# Create snapshots directory (annotated frames with detected faults are saved here).
SNAPSHOT_FOLDER = "./snapshots"
os.makedirs(SNAPSHOT_FOLDER, exist_ok=True)
# Function to resize and pad an image to a fixed size while preserving aspect ratio
def preprocess_image(image, target_size=(640, 640)):
    """Letterbox ``image`` into ``target_size`` without distorting it.

    The image is scaled by the largest factor that still fits inside the target.
    It is then centred and surrounded by zero (black) padding.

    Args:
        image: NumPy array of shape ``(h, w)`` or ``(h, w, c)``. Any channel
            axis and the dtype are carried through to the result.
        target_size: ``(height, width)`` of the result. The order is height
            first, unlike the ``(width, height)`` that ``cv2.resize`` takes.

    Returns:
        A new array of shape ``target_size + image.shape[2:]`` with dtype
        ``image.dtype``.

    Raises:
        ValueError: If ``image`` has zero height or zero width.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError(f"cannot letterbox an empty image of shape {image.shape}")
    target_h, target_w = target_size
    # Largest uniform scale factor that fits both dimensions inside the target.
    scale = min(target_w / w, target_h / h)
    # Use round() rather than int(): truncation can turn a float such as 639.9999
    # into 639 and leave a stray 1-px border. Clamping to [1, target] stops extreme
    # aspect ratios from producing a zero-sized (invalid) resize or overshooting.
    new_w = min(target_w, max(1, round(w * scale)))
    new_h = min(target_h, max(1, round(h * scale)))
    if (new_w, new_h) == (w, h):
        resized = image  # already the right size; skip a needless resample
    else:
        # INTER_AREA is best for shrinking but behaves like nearest-neighbour when
        # enlarging, so use bilinear interpolation for upscaling.
        interpolation = cv2.INTER_AREA if scale < 1 else cv2.INTER_LINEAR
        resized = cv2.resize(image, (new_w, new_h), interpolation=interpolation)
        # cv2.resize drops a trailing single-channel axis; restore the input layout.
        resized = resized.reshape((new_h, new_w) + image.shape[2:])
    # Black canvas with the input's channel layout and dtype (not a fixed 3 x uint8).
    padded = np.zeros((target_h, target_w) + image.shape[2:], dtype=image.dtype)
    # Centre the resized image; any odd leftover pixel goes to the bottom/right.
    top = (target_h - new_h) // 2
    left = (target_w - new_w) // 2
    padded[top:top + new_h, left:left + new_w] = resized
    return padded
def _read_frame(video_path, frame_index):
    """Return frame ``frame_index`` of ``video_path`` as an RGB array, or None.

    None is returned when the file cannot be opened or the index is past the end.
    A fresh capture is opened on each call because the read position lives in
    session state. The capture is always released, even if seeking or decoding
    raises.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def _annotate_faults(frame_rgb, faults):
    """Return a copy of ``frame_rgb`` with a 60x60 box and type label per fault."""
    annotated = frame_rgb.copy()
    for fault in faults:
        x, y = int(fault['location'][0]), int(fault['location'][1])
        cv2.rectangle(annotated, (x - 30, y - 30), (x + 30, y + 30), (255, 0, 0), 2)
        cv2.putText(annotated, f"{fault['type']}", (x, y - 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
    return annotated


def _filename_tag(faults, max_len=100):
    """Build a filesystem-safe, length-capped tag from the distinct fault types."""
    safe_types = (
        "".join(ch if ch.isalnum() or ch in "-_" else "_" for ch in str(fault['type']))
        for fault in faults
    )
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return "_".join(dict.fromkeys(safe_types))[:max_len]


def _save_snapshot(annotated_frame, faults, frame_index):
    """Save ``annotated_frame`` to SNAPSHOT_FOLDER and keep the 5 newest in session state."""
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fault_types = "_".join(str(fault['type']).replace(" ", "_") for fault in faults)
    # The filename uses a sanitised, de-duplicated, capped tag. Raw types joined
    # per fault can exceed filesystem name limits, and a '/' in a type name would
    # point the write at a directory that does not exist.
    snapshot_filename = f"snapshot_{timestamp}_frame_{frame_index}_{_filename_tag(faults)}.png"
    snapshot_path = os.path.join(SNAPSHOT_FOLDER, snapshot_filename)
    # cv2.imwrite reports failure through its return value rather than raising.
    # Only record the snapshot if it was actually written: the gallery later
    # loads it by path.
    if not cv2.imwrite(snapshot_path, cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)):
        return
    st.session_state.snapshots.append({
        "path": snapshot_path,
        "log": f"{timestamp} - Frame {frame_index} - Anomalies: {len(faults)} ({fault_types})",
    })
    if len(st.session_state.snapshots) > 5:
        st.session_state.snapshots.pop(0)


def _save_trend_chart(chart_path="chart_temp.png"):
    """Plot anomaly counts for the last 50 logged frames and save the PNG to ``chart_path``."""
    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        ax.plot(st.session_state.frame_numbers[-50:], st.session_state.anomaly_counts[-50:],
                marker='o', color='blue')
        ax.set_xlabel("Frame", fontsize=10)
        ax.set_ylabel("Count", fontsize=10)
        ax.grid(True)
        ax.tick_params(axis='both', which='major', labelsize=8)
        # NOTE(review): this fixed path in the working directory is shared by every
        # session of the app, so concurrent sessions overwrite each other's chart.
        fig.savefig(chart_path)
    finally:
        # Always free the figure, even if savefig fails, so figures don't accumulate.
        plt.close(fig)
    return chart_path


# Core monitor function
def monitor_feed(video_path, detection_type, model):
    """Advance the monitor by one frame, or re-show the last frame while paused.

    If paused and a previous frame exists, that frame is shown again and no
    session counters change. Otherwise the function:

    1. Reads the frame at ``st.session_state.frame_count`` and letterboxes it
       to 640x640.
    2. Runs the solar or windmill detector on it and draws the faults found.
    3. Saves a snapshot if any faults were found.
    4. Records the frame in the session logs, counters and metrics.

    Args:
        video_path: Path of the video file to read from.
        detection_type: "Solar Panel" selects the solar detector; any other
            value selects the windmill detector.
        model: Model object passed straight through to the detector.

    Returns:
        ``(frame, metrics_str, logs_str, chart_path, snapshot_paths)``. Here
        ``frame`` is the 640x480 RGB display image with the frame number and
        timestamp overlaid. If no frame could be read (end of video or an
        unreadable file), five ``None`` values are returned instead.
    """
    state = st.session_state
    if state.paused and state.last_frame is not None:
        frame = state.last_frame
    else:
        frame_rgb = _read_frame(video_path, state.frame_count)
        if frame_rgb is None:
            return None, None, None, None, None
        # Preprocess frame to 640x640
        frame_rgb = preprocess_image(frame_rgb, target_size=(640, 640))
        if detection_type == "Solar Panel":
            faults = detect_faults_solar(model, frame_rgb)
        else:
            faults = detect_faults_windmill(model, frame_rgb)
        num_anomalies = len(faults)
        annotated_frame = _annotate_faults(frame_rgb, faults)
        if faults:
            # Snapshot labels use the 0-based index of the frame just read.
            _save_snapshot(annotated_frame, faults, state.frame_count)
        state.frame_count += 1
        state.last_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        # No defensive copy is needed: the display image below comes from
        # cv2.resize, which allocates a new array, so annotated_frame is never
        # drawn on again.
        state.last_frame = annotated_frame
        state.total_detected += num_anomalies
        state.last_metrics = {"anomalies": faults, "total_detected": state.total_detected}
        # Logs and chart data use the post-increment count, i.e. 1-based frame numbers.
        log_entry = f"{state.last_timestamp} - Frame {state.frame_count} - Anomalies: {num_anomalies}"
        state.logs.append(log_entry)
        state.anomaly_counts.append(num_anomalies)
        state.frame_numbers.append(state.frame_count)
        # Keep the three parallel history lists bounded and aligned.
        if len(state.logs) > 100:
            state.logs.pop(0)
            state.anomaly_counts.pop(0)
            state.frame_numbers.pop(0)
        frame = annotated_frame
    snapshot_paths = [snapshot["path"] for snapshot in state.snapshots]
    # NOTE(review): this squashes the 640x640 letterboxed frame to 4:3. That
    # distorts the aspect ratio that preprocess_image preserved; confirm it is
    # intended.
    frame = cv2.resize(frame, (640, 480))
    cv2.putText(frame, f"Frame: {state.frame_count}", (10, 25),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    cv2.putText(frame, f"{state.last_timestamp}", (10, 50),
                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
    metrics_str = "\n".join([f"{k}: {v}" for k, v in state.last_metrics.items()])
    logs_str = "\n".join(state.logs[-10:])
    chart_path = _save_trend_chart()
    return frame, metrics_str, logs_str, chart_path, snapshot_paths
def _reset_stream_state():
    """Forget per-video progress so that a newly selected video starts at frame 0."""
    st.session_state.frame_count = 0
    st.session_state.last_frame = None
    st.session_state.last_metrics = {}
    st.session_state.last_timestamp = ""
    st.session_state.logs = []
    st.session_state.anomaly_counts = []
    st.session_state.frame_numbers = []
    st.session_state.total_detected = 0


def _build_layout():
    """Lay out the dashboard sections and return their placeholders keyed by name.

    NOTE(review): each st.markdown call is rendered as its own element. The
    '<div class="section-box">' / '</div>' pairs below therefore most likely
    do not wrap the placeholder between them; confirm the bordered box renders
    as intended.
    """
    slots = {}
    # Layout: Video feed and metrics
    col1, col2 = st.columns([3, 1])
    with col1:
        st.markdown('<div class="section-title">LIVE VIDEO FEED</div>', unsafe_allow_html=True)
        with st.container():
            st.markdown('<div class="section-box">', unsafe_allow_html=True)
            slots["video"] = st.empty()
            st.markdown('</div>', unsafe_allow_html=True)
    with col2:
        st.markdown('<div class="section-title">LIVE METRICS</div>', unsafe_allow_html=True)
        with st.container():
            st.markdown('<div class="section-box">', unsafe_allow_html=True)
            slots["metrics"] = st.empty()
            slots["prediction"] = st.empty()
            st.markdown('</div>', unsafe_allow_html=True)
    # Layout: Logs and trends
    col3, col4 = st.columns([1, 2])
    with col3:
        st.markdown('<div class="section-title">LIVE LOGS</div>', unsafe_allow_html=True)
        with st.container():
            st.markdown('<div class="section-box">', unsafe_allow_html=True)
            slots["logs"] = st.empty()
            st.markdown('</div>', unsafe_allow_html=True)
        st.markdown('<div class="section-title">LAST 5 CAPTURED EVENTS</div>', unsafe_allow_html=True)
        with st.container():
            st.markdown('<div class="section-box">', unsafe_allow_html=True)
            slots["gallery"] = st.empty()
            st.markdown('</div>', unsafe_allow_html=True)
    with col4:
        st.markdown('<div class="section-title">DETECTION TRENDS</div>', unsafe_allow_html=True)
        with st.container():
            st.markdown('<div class="section-box">', unsafe_allow_html=True)
            st.markdown('<div style="font-size: 14px; font-weight: bold; margin-bottom: 10px;">Anomalies Over Time</div>', unsafe_allow_html=True)
            slots["trends"] = st.empty()
            st.markdown('</div>', unsafe_allow_html=True)
    return slots


def _render_outputs(slots, frame, metrics, logs, chart, snapshots):
    """Push one monitor_feed result into the dashboard placeholders.

    Text is HTML-escaped before it is embedded in unsafe_allow_html markdown.
    This covers the repr of the fault dicts and the snapshot logs that contain
    model-provided type names.
    """
    slots["video"].image(frame, channels="RGB", width=640)
    metrics_formatted = html.escape(metrics).replace("\n", "<br>")
    slots["metrics"].markdown(
        f'<div class="metrics-text">{metrics_formatted}</div>',
        unsafe_allow_html=True
    )
    # Turn each log line into its own log-entry div.
    logs_formatted = html.escape(logs).replace("\n", "</div><div class='log-entry'>")
    slots["logs"].markdown(
        f'<div class="log-entry">{logs_formatted}</div>',
        unsafe_allow_html=True
    )
    slots["trends"].image(chart)
    # Gallery: up to five snapshot thumbnails, each with its log line.
    with slots["gallery"].container():
        for i, col in enumerate(st.columns(5)):
            if i < len(snapshots):
                with col:
                    st.image(snapshots[i], width=100)
                    entry_log = html.escape(st.session_state.snapshots[i]["log"])
                    st.markdown(f'<div class="log-entry">{entry_log}</div>', unsafe_allow_html=True)
    # Predictive anomaly detection
    if predict_anomaly(st.session_state.anomaly_counts):
        slots["prediction"].warning("**Prediction:** Potential issue detected - anomaly spike detected!")
    else:
        slots["prediction"].empty()


def main():
    """Render the dashboard and process at most one video frame per script run.

    While running, each run does three things: process one frame, sleep for the
    configured interval, then call ``st.rerun()`` to move on to the next frame.
    While paused, the run shows the last frame and simply ends. The next widget
    interaction (e.g. Resume) starts a fresh run, so the script no longer loops
    re-rendering identical content.
    """
    st.markdown('<div class="main-header">THERMAL ANOMALY MONITORING DASHBOARD</div>', unsafe_allow_html=True)
    # Status banner. It is drawn now and refreshed after the Pause/Resume
    # buttons further down the page have been read.
    status_placeholder = st.empty()

    def show_status():
        label = "⏸️ Paused" if st.session_state.paused else "🟢 Running"
        status_placeholder.markdown(f'<div class="status">{label}</div>', unsafe_allow_html=True)

    show_status()
    # Sidebar for video selection and detection type
    st.sidebar.header("Settings")
    try:
        entries = os.listdir(VIDEO_FOLDER)
    except OSError:
        entries = []  # a missing/unreadable folder is reported like an empty one
    # Sorted so the selectbox order (and its default choice) is deterministic.
    video_files = sorted(f for f in entries if f.lower().endswith('.mp4'))
    if not video_files:
        st.error(f"No videos found in the '{VIDEO_FOLDER}' folder. Please add .mp4 files.")
        return
    video_file = st.sidebar.selectbox("Select Video", video_files)
    detection_type = st.sidebar.selectbox("Detection Type", ["Solar Panel", "Windmill"])
    video_path = os.path.join(VIDEO_FOLDER, video_file)
    # frame_count is a position inside one specific video. When the user picks
    # a different video, start it from frame 0 instead of the old video's offset.
    previous_video = st.session_state.get("active_video")
    if previous_video != video_path:
        if previous_video is not None:
            _reset_stream_state()
        st.session_state.active_video = video_path
    # Load the appropriate model.
    # NOTE(review): this runs on every script run, i.e. every frame. Unless the
    # loaders cache internally, consider st.cache_resource.
    model = load_solar_model() if detection_type == "Solar Panel" else load_windmill_model()
    slots = _build_layout()
    # Controls
    with st.container():
        col5, col6, col7 = st.columns([1, 1, 2])
        with col5:
            pause_btn = st.button("⏸️ Pause")
        with col6:
            resume_btn = st.button("▶️ Resume")
        with col7:
            frame_rate = st.slider("Frame Interval (seconds)", 0.1, 1.0, st.session_state.frame_rate)
    # Handle button clicks and slider
    if pause_btn:
        st.session_state.paused = True
    if resume_btn:
        st.session_state.paused = False
    st.session_state.frame_rate = frame_rate
    show_status()
    frame, metrics, logs, chart, snapshots = monitor_feed(video_path, detection_type, model)
    if frame is None:
        st.success("Video processing completed.")
        return
    _render_outputs(slots, frame, metrics, logs, chart, snapshots)
    if st.session_state.paused:
        return  # nothing changes until the next interaction triggers a rerun
    time.sleep(st.session_state.frame_rate)
    st.rerun()
# Entry point when this file is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()