Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from transformers import DetrImageProcessor, DetrForObjectDetection | |
| import cv2 | |
| import numpy as np | |
| import tempfile | |
| import os | |
| import asyncio | |
| from concurrent.futures import ThreadPoolExecutor | |
| import warnings | |
| from transformers.utils import logging | |
# ---------------------------------------------------------------------------
# Page setup. st.set_page_config must be the first Streamlit call in the script.
# ---------------------------------------------------------------------------
st.set_page_config(page_title="Solar Panel Fault Detection", layout="wide")

st.title("Solar Panel Fault Detection PoC")
st.write("Upload a thermal video (MP4) to detect thermal, dust, and power generation faults.")

# ---------------------------------------------------------------------------
# Sidebar: knobs that trade analysis thoroughness for speed.
# ---------------------------------------------------------------------------
st.sidebar.header("Analysis Settings")
frame_skip = st.sidebar.slider(
    "Frame Skip (higher = faster, less thorough)",
    min_value=1,
    max_value=50,
    value=30,
)
# Default to a larger batch when a CUDA device is available.
_default_batch_size = 16 if torch.cuda.is_available() else 8
batch_size = st.sidebar.slider(
    "Batch Size (adjust for hardware)",
    min_value=1,
    max_value=32,
    value=_default_batch_size,
)
resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
# Target frame width in pixels; None means frames keep their original size.
resize_width = None
if resize_enabled:
    resize_width = 512
quantize_model = st.sidebar.checkbox("Quantize Model (faster, esp. on CPU)", value=True)
# Load model and processor
@st.cache_resource(show_spinner="Loading detection model...")
def load_model(quantize=quantize_model):
    """Load the facebook/detr-resnet-50 processor and model.

    Decorated with ``st.cache_resource`` because Streamlit re-executes the whole
    script on every widget interaction; without caching, the checkpoint would be
    reloaded (and re-quantized) on each rerun. Streamlit keys the cache on the
    arguments actually passed to the call, so callers should pass ``quantize``
    explicitly to get a different model when the sidebar checkbox changes.

    Args:
        quantize: apply dynamic int8 quantization to ``torch.nn.Linear`` layers.
            Only honoured when the model runs on CPU.

    Returns:
        ``(processor, model, device)`` with the model in eval mode on ``device``.

    Raises:
        Exception: whatever loading or quantization raised, re-raised after an
            ``st.error`` message has been shown.
    """
    # Hide the "unused weights" warning and other transformers log noise.
    warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
    logging.set_verbosity_error()
    try:
        processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
        model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        # Dynamic quantization is only applied for CPU inference.
        if quantize and device.type == "cpu":
            model = torch.quantization.quantize_dynamic(
                model, {torch.nn.Linear}, dtype=torch.qint8
            )
        model.eval()
        return processor, model, device
    except Exception as e:
        st.error(f"Failed to load model: {str(e)}. Check internet connection or cache (~/.cache/huggingface/hub).")
        raise


# Pass the setting explicitly: it is part of the cache key only when passed.
processor, model, device = load_model(quantize=quantize_model)
# Function to resize frame
def resize_frame(frame, width=None):
    """Resize ``frame`` to ``width`` pixels wide, preserving its aspect ratio.

    Args:
        frame: image array where ``frame.shape[0]`` is the height and
            ``frame.shape[1]`` the width.
        width: target width in pixels, or ``None`` to leave the frame as is.

    Returns:
        The bilinearly resized array, or ``frame`` itself when ``width`` is None.
    """
    if width is None:
        return frame
    src_height, src_width = frame.shape[:2]
    # Use exactly the same arithmetic (width * height / width, truncated) as the
    # VideoWriter size computed in process_video. The previous width / aspect_ratio
    # form could round to a height one pixel smaller, and OpenCV's writer does not
    # write frames whose size differs from the one it was opened with.
    height = int(width * src_height / src_width)
    return cv2.resize(frame, (width, height), interpolation=cv2.INTER_LINEAR)
# Function to process a batch of frames
def _draw_fault_box(image, box, label, color):
    """Draw a labelled rectangle on ``image`` in place.

    ``box`` is ``(x1, y1, x2, y2)`` in pixels; ``color`` is a 3-tuple in the
    image's channel order.
    """
    x1, y1, x2, y2 = box
    cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
    # Keep the caption visible when the box touches the top edge of the frame.
    cv2.putText(image, label, (x1, max(y1 - 10, 10)),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)


async def detect_faults_batch(frames, processor, model, device):
    """Run DETR on a batch of frames and flag detected regions by brightness.

    Each frame is first resized with ``resize_frame(frame, resize_width)``. For
    every detection kept by ``post_process_object_detection`` (threshold 0.9),
    the mean pixel intensity inside the box decides the fault:
    mean > 200 marks a "Thermal Fault" (drawn in (255, 0, 0)), and
    mean < 100 marks a "Dust Fault" (drawn in (0, 255, 0)).
    Either one also sets "Power Generation Fault" for that frame. The colours
    are red/green for RGB input, which is what process_video passes in.

    Although declared ``async``, the body contains no ``await``. It runs to
    completion (blocking) when awaited.

    Args:
        frames: list of image arrays (RGB expected).
        processor: DETR image processor.
        model: DETR object-detection model.
        device: torch device the model lives on.

    Returns:
        ``(annotated_frames, all_faults)``: lists aligned with the (resized)
        input frames. On any error an ``st.error`` is shown and ``([], [])`` is
        returned.
    """
    try:
        frames = [resize_frame(frame, resize_width) for frame in frames]
        inputs = processor(images=frames, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
        # Post-processing expects one (height, width) pair per image.
        target_sizes = torch.tensor([frame.shape[:2] for frame in frames]).to(device)
        results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.9)

        annotated_frames = []
        all_faults = []
        for frame, result in zip(frames, results):
            faults = {"Thermal Fault": False, "Dust Fault": False, "Power Generation Fault": False}
            annotated_frame = frame.copy()
            frame_height, frame_width = frame.shape[:2]
            for box in result["boxes"]:
                x1, y1, x2, y2 = (int(v) for v in box.tolist())
                # Predicted boxes can fall partly outside the image. A negative
                # slice index would wrap around to the opposite edge, so clamp
                # the coordinates to the frame before cropping.
                x1, x2 = (min(max(v, 0), frame_width) for v in (x1, x2))
                y1, y2 = (min(max(v, 0), frame_height) for v in (y1, y2))
                roi = frame[y1:y2, x1:x2]
                if roi.size == 0:
                    # Degenerate box: np.mean would return NaN (with a warning).
                    continue
                mean_intensity = np.mean(roi)
                if mean_intensity > 200:
                    faults["Thermal Fault"] = True
                    _draw_fault_box(annotated_frame, (x1, y1, x2, y2), "Thermal Fault", (255, 0, 0))
                elif mean_intensity < 100:
                    faults["Dust Fault"] = True
                    _draw_fault_box(annotated_frame, (x1, y1, x2, y2), "Dust Fault", (0, 255, 0))
            if faults["Thermal Fault"] or faults["Dust Fault"]:
                faults["Power Generation Fault"] = True
            annotated_frames.append(annotated_frame)
            all_faults.append(faults)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return annotated_frames, all_faults
    except Exception as e:
        st.error(f"Error during fault detection: {str(e)}")
        return [], []
# Function to process video
async def process_video(video_path, frame_skip, batch_size, max_buffered_frames=256):
    """Analyse a video and write an annotated copy to a temporary MP4 file.

    Every ``frame_skip``-th frame (starting with frame 0) is sent to
    ``detect_faults_batch`` in batches of up to ``batch_size``. The other frames
    are only resized. All frames are written in their original order, so
    frames read after an analysed frame are held in memory until its batch has
    been processed (at most ``max_buffered_frames`` frames at a time).

    Args:
        video_path: path of the input video, opened with ``cv2.VideoCapture``.
        frame_skip: analyse one frame out of every ``frame_skip`` (values < 1
            are treated as 1).
        batch_size: maximum number of analysed frames sent to the model at once.
        max_buffered_frames: cap on frames held in memory while waiting for a
            batch. Reaching it flushes a partial batch early.

    Returns:
        ``(output_path, video_faults)`` on success, where ``video_faults`` maps
        each fault name to whether any analysed frame showed it. Returns
        ``(None, None)`` on failure, after showing an ``st.error`` and removing
        the partial output file.
    """
    frame_skip = max(1, int(frame_skip))
    batch_size = max(1, int(batch_size))
    max_buffered_frames = max(1, int(max_buffered_frames))
    cap = None
    out = None
    output_path = None
    succeeded = False
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            st.error("Error: Could not open video file.")
            return None, None
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Keep fractional rates such as 29.97; int() would make playback drift.
        # Fall back to 30 fps when the container reports no rate.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            fps = 30.0
        # May be 0 for some containers; it is only used for the progress bar.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        out_width = resize_width if resize_width else frame_width
        # Must match the frame size resize_frame() produces.
        out_height = int(out_width * frame_height / frame_width) if resize_width else frame_height
        # Close the handle right away so VideoWriter can open the path by name
        # (an open NamedTemporaryFile cannot be reopened on Windows).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        # NOTE(review): browsers often cannot play "mp4v" streams, so st.video may
        # show a blank player. Consider "avc1" if the OpenCV build supports it.
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (out_width, out_height))
        if not out.isOpened():
            st.error("Error: Could not create output video file.")
            return None, None

        video_faults = {"Thermal Fault": False, "Dust Fault": False, "Power Generation Fault": False}
        # Frames waiting to be written, in source order. An int entry is the
        # index in frames_batch of a frame whose annotation is not ready yet.
        # Any other entry is an already-resized BGR frame.
        pending = []
        frames_batch = []  # RGB frames queued for detection
        written_frames = 0

        async def flush():
            """Run detection on the queued batch and write all pending frames in order."""
            nonlocal pending, frames_batch, written_frames
            annotated_frames, batch_faults = [], []
            if frames_batch:
                annotated_frames, batch_faults = await detect_faults_batch(frames_batch, processor, model, device)
            for faults in batch_faults:
                for fault in video_faults:
                    video_faults[fault] |= faults[fault]
            for item in pending:
                if isinstance(item, int):
                    if item < len(annotated_frames):
                        item = cv2.cvtColor(annotated_frames[item], cv2.COLOR_RGB2BGR)
                    else:
                        # Detection failed for this batch (the error is already
                        # shown). Keep the frame unannotated so the output stays complete.
                        item = resize_frame(cv2.cvtColor(frames_batch[item], cv2.COLOR_RGB2BGR), resize_width)
                out.write(item)
                written_frames += 1
                if total_frames > 0:
                    progress.progress(min(written_frames / total_frames, 1.0))
            pending = []
            frames_batch = []

        with st.spinner("Analyzing video..."):
            progress = st.progress(0)
            frame_count = 0
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if frame_count % frame_skip == 0:
                    pending.append(len(frames_batch))
                    frames_batch.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                else:
                    pending.append(resize_frame(frame, resize_width))
                frame_count += 1
                if len(frames_batch) >= batch_size or len(pending) >= max_buffered_frames:
                    await flush()
            # Write whatever is left (partial batch and trailing skipped frames).
            await flush()
        succeeded = True
        return output_path, video_faults
    except Exception as e:
        st.error(f"Error processing video: {str(e)}")
        return None, None
    finally:
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        # Do not leave a partial output file behind when analysis failed.
        if not succeeded and output_path and os.path.exists(output_path):
            os.unlink(output_path)
# File uploader
uploaded_file = st.file_uploader("Upload a thermal video", type=["mp4"])
if uploaded_file is not None:
    # NOTE(review): Streamlit reruns the whole script on every widget change, so
    # the video is re-analysed whenever a sidebar setting changes.
    tfile = None
    output_path = None
    try:
        # delete=False plus closing the handle lets OpenCV reopen the file by name.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tfile:
            tfile.write(uploaded_file.getvalue())
        st.video(tfile.name, format="video/mp4")
        # asyncio.run gives Streamlit's ScriptRunner thread a fresh event loop,
        # closes it afterwards, and does not leave a closed loop installed as
        # the thread's current loop.
        output_path, video_faults = asyncio.run(process_video(tfile.name, frame_skip, batch_size))
        if output_path and video_faults:
            st.subheader("Fault Detection Results")
            st.video(output_path, format="video/mp4")
            st.write("**Detected Faults in Video:**")
            for fault, detected in video_faults.items():
                status = "Detected" if detected else "Not Detected"
                color = "red" if detected else "green"
                st.markdown(f"- **{fault}**: <span style='color:{color}'>{status}</span>", unsafe_allow_html=True)
            if any(video_faults.values()):
                st.subheader("Recommendations")
                if video_faults["Thermal Fault"]:
                    st.write("- **Thermal Fault**: Inspect for damaged components or overheating issues.")
                if video_faults["Dust Fault"]:
                    st.write("- **Dust Fault**: Schedule cleaning to remove dust accumulation.")
                if video_faults["Power Generation Fault"]:
                    st.write("- **Power Generation Fault**: Investigate efficiency issues due to detected faults.")
            else:
                st.write("No faults detected. The solar panel appears to be functioning normally.")
    except Exception as e:
        st.error(f"Error handling uploaded file: {str(e)}")
    finally:
        # Remove both temporary videos whether or not analysis succeeded.
        input_path = tfile.name if tfile is not None else None
        for temp_path in (output_path, input_path):
            if temp_path and os.path.exists(temp_path):
                os.unlink(temp_path)
# ---------------------------------------------------------------------------
# Footer: a horizontal rule followed by a one-line credits note.
# ---------------------------------------------------------------------------
_FOOTER_CREDITS = "Built with Streamlit, Hugging Face Transformers, and OpenCV for Solar Panel Fault Detection PoC"
st.markdown("---")
st.write(_FOOTER_CREDITS)