"""Streamlit app: detect solar-panel faults in an uploaded image.

The script loads an object-detection checkpoint, runs it on an uploaded
image, draws the predicted boxes, and overlays a box-density heatmap.
Video uploads are accepted by the uploader but only produce a warning.
"""
import os

# Fix cache permission issue.
# This assignment must happen BEFORE `transformers` is imported: the library
# resolves its cache directory from the environment at import time, so setting
# the variable afterwards has no effect.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface'

import tempfile  # noqa: E402

import cv2  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import streamlit as st  # noqa: E402
import torch  # noqa: E402
from PIL import Image, ImageDraw  # noqa: E402
from transformers import AutoModelForObjectDetection  # noqa: E402

model_id = "NaveenKumar5/Solar_panel_fault_detection"

# Minimum per-query confidence for a detection to be kept.
SCORE_THRESHOLD = 0.5
# Size (width, height) every image is resized to before inference.
INPUT_SIZE = (800, 800)
# Vertical offset, in pixels, of a caption above its box.
LABEL_OFFSET = 10


@st.cache_resource
def load_model():
    """Load the detection model and switch it to inference mode.

    Decorated with ``st.cache_resource`` so the loaded object is reused
    across script reruns.

    Returns:
        The model returned by ``AutoModelForObjectDetection.from_pretrained``
        for ``model_id``, after ``eval()`` has been called on it.
    """
    model = AutoModelForObjectDetection.from_pretrained(model_id)
    model.eval()
    return model


model = load_model()

st.title("🔍 Solar Panel Fault Detection")
st.write("Upload an image or video to detect faults and view heatmaps.")

uploaded_file = st.file_uploader(
    "Upload Image or Video", type=["jpg", "png", "mp4", "avi"]
)


def draw_boxes(image, boxes, labels, scores):
    """Draw red rectangles and ``label: score`` captions onto *image*.

    Args:
        image: PIL image to draw on; it is modified in place.
        boxes: Iterable of ``[x0, y0, x1, y1]`` pixel boxes.
        labels: Iterable of class labels, one per box.
        scores: Iterable of confidence scores, one per box.

    Returns:
        The same *image* object, for convenience.
    """
    draw = ImageDraw.Draw(image)
    for box, label, score in zip(boxes, labels, scores):
        draw.rectangle(box, outline="red", width=2)
        # Clamp the caption so boxes touching the top edge stay labelled
        # instead of having their text drawn off-canvas.
        caption_y = max(box[1] - LABEL_OFFSET, 0)
        draw.text((box[0], caption_y), f"{label}: {score:.2f}", fill="red")
    return image


def generate_heatmap(image, boxes):
    """Build a box-density map with the same height and width as *image*.

    Each box adds 1 to the pixels it covers, and the result is scaled so the
    densest pixel equals 1. Box coordinates are clamped to the image bounds,
    so a negative coordinate cannot wrap around as a negative index.

    Args:
        image: Array whose first two dimensions are (height, width).
        boxes: Iterable of ``[x0, y0, x1, y1]`` pixel boxes.

    Returns:
        ``float32`` array of shape (height, width) with values in [0, 1].
        It is all zeros when no box covers any pixel.
    """
    height, width = image.shape[:2]
    heatmap = np.zeros((height, width), dtype=np.float32)
    for box in boxes:
        x0, y0, x1, y1 = map(int, box)
        x0, x1 = (min(max(v, 0), width) for v in (x0, x1))
        y0, y1 = (min(max(v, 0), height) for v in (y0, y1))
        heatmap[y0:y1, x0:x1] += 1

    peak = float(heatmap.max()) if heatmap.size else 0.0
    if peak > 0:
        # Only normalise when something was drawn; dividing by a zero peak
        # would turn the whole map into NaN.
        heatmap = np.clip(heatmap / peak, 0, 1)
    return heatmap


def preprocess_image(image):
    """Convert a PIL RGB image into a ``(1, 3, H, W)`` float tensor in [0, 1].

    The image is resized to ``INPUT_SIZE`` and scaled by 1/255.

    NOTE(review): no mean/std normalisation is applied here. Confirm whether
    this checkpoint expects its image processor's normalisation.

    Args:
        image: PIL image with three channels.

    Returns:
        ``torch.Tensor`` of dtype float32, channels first, with a leading
        batch dimension of 1.
    """
    image = image.resize(INPUT_SIZE)
    image_np = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.tensor(image_np).permute(2, 0, 1).unsqueeze(0)
    return image_tensor


def _to_pixel_boxes(boxes, width, height):
    """Convert normalised (cx, cy, w, h) boxes to int pixel [x0, y0, x1, y1]."""
    return [
        [
            int((cx - w / 2) * width),
            int((cy - h / 2) * height),
            int((cx + w / 2) * width),
            int((cy + h / 2) * height),
        ]
        for cx, cy, w, h in boxes
    ]


if uploaded_file is not None:
    if uploaded_file.type.startswith("image"):
        image = Image.open(uploaded_file).convert("RGB")
        inputs = preprocess_image(image)

        with torch.no_grad():
            outputs = model(pixel_values=inputs)

        # NOTE(review): the max/argmax below run over every logit column.
        # If this architecture reserves a trailing "no object" class, those
        # queries are kept too. Confirm against the model's post-processing.
        scores = outputs["logits"].softmax(-1)[0].max(-1).values
        keep = scores > SCORE_THRESHOLD

        boxes = outputs["pred_boxes"][0][keep].cpu().numpy()
        labels = outputs["logits"].argmax(-1)[0][keep].cpu().numpy()
        scores = scores[keep].cpu().numpy()

        image_np = np.array(image)
        height, width = image_np.shape[:2]
        # Boxes are scaled by the ORIGINAL image size, not the resized input.
        abs_boxes = _to_pixel_boxes(boxes, width, height)

        # Draw boxes and labels
        boxed_image = draw_boxes(image.copy(), abs_boxes, labels, scores)
        st.image(boxed_image, caption="Detected Faults", use_column_width=True)

        # Generate and show heatmap
        heatmap = generate_heatmap(image_np, abs_boxes)
        fig, ax = plt.subplots()
        ax.imshow(image_np)
        ax.imshow(heatmap, cmap="jet", alpha=0.5)
        ax.axis("off")
        st.pyplot(fig)
        # Release the figure; pyplot keeps figures alive until closed, so
        # they would otherwise accumulate across reruns.
        plt.close(fig)

    elif uploaded_file.type.startswith("video"):
        st.warning("Video support coming soon. For now, please upload an image.")