Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| from ultralytics import YOLO | |
| import cv2 | |
| import tempfile | |
| import os | |
| import base64 | |
| import numpy as np | |
| import time | |
| import csv | |
| import zipfile | |
| from io import BytesIO, StringIO | |
| from datetime import datetime | |
# Page configuration followed by three injected style sheets. Each sheet is
# sent through its own st.markdown call, in this order: page background,
# widget/component theming, header positioning.
st.set_page_config(page_title="Flood Level Detection", layout="wide")

# Horizontal blue-white-blue gradient behind the whole app.
_BACKGROUND_CSS = """
<style>
body {
background: linear-gradient(to right, #e3f2fd 0%, #ffffff 25%, #ffffff 75%, #e3f2fd 100%);
}
.stApp {
background: linear-gradient(to right, #e3f2fd 0%, #ffffff 25%, #ffffff 75%, #e3f2fd 100%);
}
</style>
"""

# Spacing, settings box, status text, and button / number-input / radio theming.
_WIDGET_CSS = """
<style>
.block-container {padding-top: 2rem; padding-bottom: 8rem;}
.main-title {font-size:34px;font-weight:bold;text-align:center;color:#004080;margin-top:40px;margin-bottom:35px;}
.settings-box {border:2px solid #363738;border-radius:10px;padding:15px;background-color:#E8E8E8;}
.yellow-header {background-color:#363738;color:white;font-weight:bold;text-align:center;padding:6px;border-radius:6px;margin-bottom:10px;}
.centered-status {display:flex;justify-content:center;align-items:center;margin-top:15px;}
.status-text {text-align:center;font-size:16px;font-weight:bold;}
.progress-container {display:flex;justify-content:center;align-items:center;width:100%;margin-top:15px;}
.progress-bar-wrapper {width:60%;}
div[data-testid="stButton"] > button {
background: linear-gradient(to right, #e3f2fd 0%, #ffffff 25%, #ffffff 75%, #e3f2fd 100%) !important;
color: #004080 !important;
font-weight: bold !important;
border: 2px solid #004080 !important;
border-radius: 8px !important;
padding: 10px 20px !important;
transition: all 0.3s ease !important;
}
div[data-testid="stButton"] > button:hover {
background: linear-gradient(to right, #bbdefb 0%, #e3f2fd 25%, #e3f2fd 75%, #bbdefb 100%) !important;
border-color: #0d47a1 !important;
box-shadow: 0 4px 12px rgba(13, 71, 161, 0.2) !important;
}
div[data-testid="stDownloadButton"] > button {
background: linear-gradient(to right, #e3f2fd 0%, #ffffff 25%, #ffffff 75%, #e3f2fd 100%) !important;
color: #004080 !important;
font-weight: bold !important;
border: 2px solid #004080 !important;
border-radius: 8px !important;
padding: 10px 20px !important;
transition: all 0.3s ease !important;
}
div[data-testid="stDownloadButton"] > button:hover {
background: linear-gradient(to right, #bbdefb 0%, #e3f2fd 25%, #e3f2fd 75%, #bbdefb 100%) !important;
border-color: #0d47a1 !important;
box-shadow: 0 4px 12px rgba(13, 71, 161, 0.2) !important;
}
div[data-testid="stNumberInput"] label {
font-size: 20px !important;
font-weight: bold !important;
color: #004080 !important;
}
div[data-testid="stNumberInput"] input {
font-size: 24px !important;
font-weight: bold !important;
height: 50px !important;
}
div[data-testid="stRadio"] label {
font-size: 18px !important;
font-weight: bold !important;
color: #004080 !important;
}
</style>
"""

# Centered title with an absolutely positioned logo on the right.
_HEADER_CSS = """
<style>
.header-container {
display: flex;
justify-content: center;
align-items: center;
position: relative;
margin-top: 30px;
margin-bottom: 25px;
background: transparent;
padding: 0px 0;
border: none;
}
.header-title {
font-size: 36px;
font-weight: bold;
color: #004080;
text-align: center;
}
.header-logo {
position: absolute;
right: 80px;
top: 50%;
transform: translateY(-50%);
}
.header-logo img {
height: 70px;
width: auto;
border-radius: 8px;
}
</style>
"""

for _sheet in (_BACKGROUND_CSS, _WIDGET_CSS, _HEADER_CSS):
    st.markdown(_sheet, unsafe_allow_html=True)
| # --- Load logo --- | |
# Read the header logo as base64 so it can be inlined in the page HTML.
# The asset is optional: the other images used by this app (footer logo,
# criteria scheme) are existence-checked, so a missing header logo should
# not crash the whole page either.
LOGO_PATH = "assets/logo3u.png"
logo_data = ""
if os.path.exists(LOGO_PATH):
    with open(LOGO_PATH, "rb") as img_file:
        logo_data = base64.b64encode(img_file.read()).decode()

# Only emit the <img> wrapper when there is image data to show.
logo_html = (
    f'<div class="header-logo">\n<img src="data:image/png;base64,{logo_data}">\n</div>\n'
    if logo_data
    else ""
)

# --- Header layout ---
st.markdown(f"""
<style>
.header-logo img {{
height: 120px; /* increase this value to make the image bigger */
}}
</style>
<div class="header-container">
<div class="header-title">FLOOD-DEPTH-ML</div>
{logo_html}</div>
""", unsafe_allow_html=True)
| # --------------------------- | |
| # LOAD MODEL | |
| # --------------------------- | |
MODEL_PATH = "best_car.pt"
if not os.path.exists(MODEL_PATH):
    st.error("❌ Model file not found! Please place your model as 'best_car.pt'")
    st.stop()

# Streamlit re-executes this script on every widget interaction (including
# the Pause/Resume buttons). Keep the loaded model in per-session state so the
# weights are read from disk once per session instead of once per rerun.
if "yolo_model" not in st.session_state:
    st.session_state.yolo_model = YOLO(MODEL_PATH)
model = st.session_state.yolo_model
| # --------------------------- | |
| # LAYOUT | |
| # --------------------------- | |
# Three-column page: input controls | detection display | settings.
col1, col2, col3 = st.columns([1.2, 2.6, 1.2])

# ---------------------------
# LEFT PANEL
# ---------------------------
col1.subheader("📥 Input Source")
input_type = col1.radio("Choose Input Type", ["Upload File"])
uploaded_file = (
    col1.file_uploader("Upload Image or Video", type=["jpg", "png", "mp4", "avi"])
    if input_type == "Upload File"
    else None
)
analyze_btn = col1.button("🔍 Analyze", width='stretch')
# Placeholder later filled with the image / ZIP download button.
download_area = col1.empty()

# ---------------------------
# CENTER PANEL
# ---------------------------
col2.subheader("🎥 Detection Display")
display_area = col2.empty()        # annotated frame
controls_area = col2.container()   # pause / resume / frame download
status_area = col2.empty()         # status line

# ---------------------------
# RIGHT PANEL (settings)
# ---------------------------
with col3:
    st.markdown("<div class='settings-box'>", unsafe_allow_html=True)
    st.markdown("<div class='yellow-header'>⚙️ SETTINGS</div>", unsafe_allow_html=True)

    # Frame stride; the widget value is also stored in session state under
    # the key "frame_skip".
    default_skip = 1
    try:
        raw_input = st.number_input(
            "⏱️ Analyze every Nth frame (1 = all, 10 = skip 10 frames)",
            min_value=1,
            max_value=60,
            value=default_skip,
            step=1,
            key="frame_skip"
        )
        fps_input = int(raw_input)
    except Exception:
        fps_input = default_skip

    st.markdown("### 🌊 Flood Levels")
    # One placeholder per level ("Level 0" .. "Level 4"), each starting at 0.
    level_placeholders = {}
    for level in (f"Level {i}" for i in range(5)):
        slot = st.empty()
        slot.markdown(
            f"<div style='text-align:left;font-size:20px;font-weight:bold;margin:4px;'>{level}: 0</div>",
            unsafe_allow_html=True
        )
        level_placeholders[level] = slot

    st.markdown("### 🧾 Labelling Criteria")
    if os.path.exists("assets/scheme.png"):
        st.image("assets/scheme.png", caption="Reference Criteria", width='stretch')
    st.markdown("</div>", unsafe_allow_html=True)
| # --------------------------- | |
| # FOOTER | |
| # --------------------------- | |
def show_footer_logos():
    """Render a fixed footer bar containing the logo image.

    Does nothing when ``assets/logo1.png`` is not present. The image is
    inlined into the page as a base64 data URI.
    """
    footer_logo = "assets/logo1.png"
    if os.path.exists(footer_logo):
        with open(footer_logo, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode()
        st.markdown(f"""
<div style="position:fixed;left:0;bottom:0;width:100%;background-color:white;display:flex;justify-content:center;align-items:center;gap:10px;padding:5px 0;box-shadow:0 -2px 8px rgba(0,0,0,0.15);z-index:999;">
<img src="data:image/png;base64,{encoded}" style="height:70px;">
</div>
""", unsafe_allow_html=True)


show_footer_logos()
| # --------------------------- | |
| # HELPERS | |
| # --------------------------- | |
def update_levels(counts, high_danger=False):
    """Rewrite every flood-level placeholder with its current count.

    Args:
        counts: Mapping from level name (e.g. ``"Level 2"``) to a count;
            levels missing from the mapping are shown as 0.
        high_danger: When true, "Level 3" and "Level 4" are drawn in red
            instead of the default blue.
    """
    alert_levels = ("Level 3", "Level 4")
    for key, slot in level_placeholders.items():
        if high_danger and key in alert_levels:
            color = "#FF0000"
        else:
            color = "#004080"
        slot.markdown(
            f"<div style='text-align:left;font-size:18px;font-weight:bold;margin:4px;color:{color};'>{key}: {counts.get(key, 0)}</div>",
            unsafe_allow_html=True
        )
# Counter that feeds the widget key of the paused-frame download button.
if "pause_loop_counter" not in st.session_state:
    st.session_state["pause_loop_counter"] = 0
def detect_and_count(frame):
    """Run the model on one frame and tally detections per flood level.

    Args:
        frame: Image passed straight to the model, or ``None``.

    Returns:
        A tuple ``(annotated, level_counts, high_danger)``:
        ``annotated`` is the plotted first result (``None`` when ``frame`` is
        ``None``), ``level_counts`` maps ``"Level 0"``..``"Level 4"`` to the
        number of boxes of that class, and ``high_danger`` is true when any
        box has class 3 or 4.
    """
    tally = {f"Level {i}": 0 for i in range(5)}
    if frame is None:
        return None, tally, False

    prediction = model(frame)[0]
    annotated = prediction.plot()
    danger = False
    # Best effort: any failure while reading the boxes stops the tally and
    # whatever was counted so far is returned.
    try:
        for box in prediction.boxes:
            class_id = int(box.cls[0])
            label = f"Level {class_id}"
            if label in tally:
                tally[label] += 1
            if class_id in (3, 4):
                danger = True
    except Exception:
        pass
    return annotated, tally, danger
| # --------------------------- | |
| # SESSION STATE INIT | |
| # --------------------------- | |
# Default value for every session key this script uses; a key is only
# written when it is not already present, so values survive reruns.
# The dict literal is rebuilt on each run, so the list/dict defaults are
# fresh objects.
_SESSION_DEFAULTS = {
    "processing": False,
    "tmp_video_path": None,
    "paused": False,
    "zip_ready": False,
    "zip_data": None,
    "zip_filename": None,
    "webcam_active": False,
    "webcam_cap": None,
    "webcam_frames": [],
    "last_webcam_frame": None,
    "last_webcam_counts": {},
    "report_log": [],
    "webcam_frame_counter": 0,
    "last_frame_bytes": None,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
| # --------------------------- | |
| # UPLOAD FILE PROCESSING | |
| # --------------------------- | |
# Start a new analysis: copy the upload to a temp file (OpenCV needs a path)
# and reset all per-analysis state.
if input_type == "Upload File" and analyze_btn and uploaded_file:
    # Drop the temp copy left by an earlier analysis that never reached its
    # own cleanup (e.g. Analyze clicked again mid-video).
    previous_path = st.session_state.tmp_video_path
    if previous_path and os.path.exists(previous_path):
        try:
            os.remove(previous_path)
        except OSError:
            pass  # best effort; a leftover temp file is not fatal

    ext = uploaded_file.name.split('.')[-1].lower()
    # Save file to temp. The context manager closes the handle even if the
    # write fails; getvalue() returns the full buffer regardless of the
    # current stream position.
    with tempfile.NamedTemporaryFile(delete=False, suffix="." + ext) as tfile:
        tfile.write(uploaded_file.getvalue())

    st.session_state.tmp_video_path = tfile.name
    st.session_state.processing = True
    st.session_state.paused = False
    # Clear results of the previous analysis so they cannot leak into this one.
    st.session_state.zip_ready = False
    st.session_state.zip_data = None
    st.session_state.zip_filename = None
    st.session_state.last_frame_bytes = None
    st.session_state.report_log = []
| # Process uploaded file | |
# Runs whenever `processing` is set and a temp file exists. Because the flag
# lives in session state, a rerun triggered by the Pause/Resume buttons
# re-enters this branch.
if st.session_state.processing and st.session_state.tmp_video_path:
    path = st.session_state.tmp_video_path
    ext = path.split('.')[-1].lower()

    # IMAGE PROCESSING
    if ext in ["jpg", "jpeg", "png"]:
        frame = cv2.imread(path)
        if frame is None:
            st.error("Could not read image file")
        else:
            annotated, counts, high_danger = detect_and_count(frame)
            update_levels(counts, high_danger)
            display_area.image(annotated, channels="BGR", caption="Detection Result", width='stretch')
            if high_danger:
                status_area.error("🚨 HIGH DANGER DETECTED!")
            _, enc = cv2.imencode('.jpg', annotated)
            with download_area:
                st.download_button("⬇️ Download Detected Image", enc.tobytes(),
                                   "detected_image.jpg", "image/jpeg", width='stretch')
        st.session_state.processing = False
        # The temp copy is not needed once the image has been analysed.
        try:
            os.remove(path)
        except OSError:
            pass
        st.session_state.tmp_video_path = None

    # VIDEO PROCESSING
    elif ext in ["mp4", "avi", "mov"]:
        cap = cv2.VideoCapture(path)
        if not cap.isOpened():
            st.error("Could not open video file")
            st.session_state.processing = False
        else:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            fps = max(int(cap.get(cv2.CAP_PROP_FPS)), 20)
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Stride comes from the settings number input (key "frame_skip");
            # clamp so a bad value can never produce a zero step.
            frame_skip = max(1, int(st.session_state.get("frame_skip", 1)))
            output_fps = max(1, int(fps / frame_skip))
            output_path = os.path.join(tempfile.gettempdir(), f"flood_output_{timestamp}.mp4")
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            video_writer = cv2.VideoWriter(output_path, fourcc, output_fps, (width, height))
            if not video_writer.isOpened():
                # Release resources and clear the flag first; otherwise every
                # later rerun would re-enter this branch and fail again.
                cap.release()
                video_writer.release()
                st.session_state.processing = False
                st.error("❌ VideoWriter failed to open — check codec or path.")
                st.stop()

            frame_idx = 0
            processed_count = 0
            # Frames 0, skip, 2*skip, ... are visited, i.e. ceil(total / skip).
            total_to_process = max(1, -(-total_frames // frame_skip))
            last_frame_bytes = None
            last_counts, last_danger = {}, False

            # Every (re)entry starts again at frame 0 with a new output file,
            # so the report must restart too or the CSV gets duplicate rows.
            # NOTE(review): frame_idx is not persisted, so Resume restarts the
            # video from the first frame — confirm whether that is intended.
            st.session_state.report_log = []

            # Create controls and placeholders BEFORE the loop
            with controls_area:
                colA, colB, colC = st.columns(3)
                with colA:
                    pause_btn = st.button("⏸️ Pause", key="pause_btn", width='stretch')
                with colB:
                    resume_btn = st.button("▶️ Resume", key="resume_btn", width='stretch')
                with colC:
                    download_placeholder = st.empty()

            if pause_btn:
                st.session_state.paused = True
                st.session_state.pause_loop_counter = 0
            if resume_btn:
                st.session_state.paused = False
                st.session_state.pause_loop_counter = 0

            # Create centered progress bar
            # NOTE(review): rendered at page level (outside controls_area) —
            # confirm the intended placement.
            prog_col1, prog_col2, prog_col3 = st.columns([0.5, 3, 0.5])
            with prog_col2:
                progress_bar = st.progress(0.0)

            # Make sure the paused branch always finds these keys. Each key is
            # checked on its own because last_frame_bytes is pre-initialised
            # in the session-state defaults.
            if "last_frame_bytes" not in st.session_state:
                st.session_state.last_frame_bytes = None
            if "last_counts" not in st.session_state:
                st.session_state.last_counts = {}
            if "last_danger" not in st.session_state:
                st.session_state.last_danger = False

            finished = False
            try:
                # --- Main Loop ---
                while cap.isOpened() and frame_idx < total_frames:
                    # Save last processed frame info to session state
                    if last_frame_bytes:
                        st.session_state.last_frame_bytes = last_frame_bytes
                        st.session_state.last_counts = last_counts
                        st.session_state.last_danger = last_danger

                    # --- Pause handling ---
                    # While paused, keep re-rendering the last frame and spin.
                    # NOTE(review): the repeated Streamlit calls here are
                    # presumably what lets a Resume click interrupt this run —
                    # verify before reducing them.
                    if st.session_state.paused:
                        if st.session_state.last_frame_bytes:
                            np_img = np.frombuffer(st.session_state.last_frame_bytes, np.uint8)
                            paused_frame = cv2.imdecode(np_img, cv2.IMREAD_COLOR)
                            display_area.image(paused_frame, channels="BGR", width='stretch')
                            update_levels(st.session_state.last_counts, st.session_state.last_danger)
                            status_area.info("⏸️ Video Paused ")
                            with download_placeholder.container():
                                st.download_button(
                                    "⬇️ Download Paused Frame",
                                    st.session_state.last_frame_bytes,
                                    file_name=f"paused_frame_{frame_idx}.jpg",
                                    mime="image/jpeg",
                                    # a fresh key per iteration avoids duplicate-key errors
                                    key=f"paused_dl_{st.session_state.pause_loop_counter}",
                                    width='stretch'
                                )
                            st.session_state.pause_loop_counter += 1
                        else:
                            status_area.info("No frame available yet.")
                        time.sleep(0.3)
                        continue

                    # Show download button during processing
                    if st.session_state.last_frame_bytes:
                        with download_placeholder.container():
                            st.download_button(
                                "⬇️ Download Current Frame",
                                st.session_state.last_frame_bytes,
                                file_name=f"frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}.jpg",
                                mime="image/jpeg",
                                key=f"current_dl_{processed_count}",
                                width='stretch'
                            )
                    else:
                        download_placeholder.empty()

                    ret, frame = cap.read()
                    if not ret:
                        break

                    annotated, counts, high_danger = detect_and_count(frame)
                    if annotated is not None:
                        update_levels(counts, high_danger)
                        display_area.image(annotated, channels="BGR", width='stretch')
                        _, enc = cv2.imencode('.jpg', annotated)
                        last_frame_bytes = enc.tobytes()
                        last_counts, last_danger = counts, high_danger
                        video_writer.write(annotated)
                        processed_count += 1
                        with prog_col2:
                            progress_bar.progress(min(processed_count / total_to_process, 1.0))
                        if high_danger:
                            status_area.markdown(f"<div class='centered-status'><div class='status-text' style='color:#FF0000;'>🚨 HIGH DANGER! Frame {processed_count}/{total_to_process}</div></div>", unsafe_allow_html=True)
                        else:
                            status_area.markdown(f"<div class='centered-status'><div class='status-text'>▶️ Processing frame {processed_count}/{total_to_process}</div></div>", unsafe_allow_html=True)
                        # Report rows use 1-based frame numbers.
                        st.session_state.report_log.append([
                            frame_idx + 1,
                            counts.get("Level 0", 0),
                            counts.get("Level 1", 0),
                            counts.get("Level 2", 0),
                            counts.get("Level 3", 0),
                            counts.get("Level 4", 0)
                        ])

                    # Jump to the next frame to analyse.
                    frame_idx += frame_skip
                    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
                    time.sleep(0.03)
                finished = True
            finally:
                # Always release the capture and writer, including when the
                # loop is left through an exception.
                cap.release()
                video_writer.release()
                if not finished:
                    # A partial output file is useless; do not leave it behind.
                    try:
                        os.remove(output_path)
                    except OSError:
                        pass

            status_area.success("✅ Video processing finished!")
            st.session_state.processing = False

            # --- ZIP Export --- (annotated video + per-frame CSV report)
            zip_buffer = BytesIO()
            with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
                if os.path.exists(output_path):
                    zipf.write(output_path, arcname="detected_video.mp4")
                csv_stream = StringIO()
                csv_writer = csv.writer(csv_stream)
                csv_writer.writerow(["Frame No", "Level 0", "Level 1", "Level 2", "Level 3", "Level 4"])
                csv_writer.writerows(st.session_state.report_log)
                zipf.writestr("flood_level_report.csv", csv_stream.getvalue())
            zip_buffer.seek(0)
            st.session_state.zip_data = zip_buffer.getvalue()
            st.session_state.zip_filename = f"flood_analysis_{timestamp}.zip"
            st.session_state.zip_ready = True

            # Best-effort removal of the temp files; only filesystem errors
            # are tolerated here.
            for leftover in (output_path, path):
                try:
                    if os.path.exists(leftover):
                        os.remove(leftover)
                except OSError:
                    pass
            st.session_state.tmp_video_path = None

    else:
        # Unknown extension: clear the flag so reruns do not loop here forever.
        st.error(f"Unsupported file type: .{ext}")
        st.session_state.processing = False
| # --- ZIP DOWNLOAD (left panel) --- | |
# Once a video analysis has produced an archive, put its download button
# into the left-panel placeholder.
if st.session_state.zip_ready:
    download_area.download_button(
        label="📦 Download ZIP (Video + Report)",
        data=st.session_state.zip_data,
        file_name=st.session_state.zip_filename,
        mime="application/zip",
        key="zip_download_button",
        use_container_width=True,
    )