Testing_p / src /streamlit_app.py
NaveenKumar5's picture
Update src/streamlit_app.py
3d8126c verified
import os
import streamlit as st
import cv2
import numpy as np
import tempfile
import torch
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from transformers import AutoModelForObjectDetection
# Redirect the Hugging Face download cache to /tmp, presumably because the
# default cache location is not writable on the hosting platform.
# NOTE(review): `transformers` is imported above this point and reads
# TRANSFORMERS_CACHE when it is imported, so this assignment on its own
# probably does not change where files are cached. Either move it above the
# imports or pass `cache_dir` to `from_pretrained`.
os.environ.update({"TRANSFORMERS_CACHE": "/tmp/huggingface"})

# Hugging Face Hub repository holding the fine-tuned fault detector.
model_id = "NaveenKumar5/Solar_panel_fault_detection"
@st.cache_resource
def load_model():
    """Download (on first use) and return the fault-detection model.

    ``st.cache_resource`` keeps the returned model alive across Streamlit
    reruns, so the weights are loaded once per server process rather than on
    every interaction.

    Returns:
        The ``AutoModelForObjectDetection`` instance for ``model_id``.
    """
    # transformers reads TRANSFORMERS_CACHE into a module-level constant when
    # the library is imported. Setting the environment variable after the
    # import therefore does not move the cache. Passing the directory
    # explicitly makes the writable /tmp location take effect regardless of
    # import order.
    cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface")
    model = AutoModelForObjectDetection.from_pretrained(model_id, cache_dir=cache_dir)
    return model
# Fetch the (cached) model and put it in inference mode.
model = load_model()
model.eval()

# --- Page header and upload widget ------------------------------------------
st.title("🔍 Solar Panel Fault Detection")
st.write("Upload an image or video to detect faults and view heatmaps.")
# `uploaded_file` stays None until the user picks a file (checked below).
uploaded_file = st.file_uploader(
    label="Upload Image or Video",
    type=["jpg", "png", "mp4", "avi"],
)
def draw_boxes(image, boxes, labels, scores, color="red", line_width=2):
    """Draw labelled bounding boxes onto ``image`` in place and return it.

    Args:
        image: PIL image to draw on; it is modified in place.
        boxes: iterable of ``(x0, y0, x1, y1)`` pixel coordinates.
        labels: one label per box, rendered with ``str`` formatting.
        scores: one confidence score per box, shown with two decimals.
        color: outline and caption colour (default ``"red"``).
        line_width: outline width in pixels (default ``2``).

    Returns:
        The same ``image`` object, for call chaining.
    """
    draw = ImageDraw.Draw(image)
    for box, label, score in zip(boxes, labels, scores):
        draw.rectangle(box, outline=color, width=line_width)
        # The caption normally sits 10 px above the box. For boxes touching
        # the top edge that would be a negative y, which clips the text, so
        # clamp it to the image.
        caption_y = max(box[1] - 10, 0)
        draw.text((box[0], caption_y), f"{label}: {score:.2f}", fill=color)
    return image
def generate_heatmap(image, boxes):
    """Build a box-overlap heatmap with the same height/width as ``image``.

    Each box adds 1 to every pixel it covers. The result is scaled so the
    most-overlapped pixel is 1.0.

    Args:
        image: array whose first two dimensions are (height, width).
        boxes: iterable of ``(x0, y0, x1, y1)`` pixel coordinates. Boxes that
            extend past the image are clipped to it.

    Returns:
        ``float32`` array of shape (height, width) with values in [0, 1].
        It is all zeros when there are no boxes or none overlap the image.
    """
    height, width = image.shape[:2]
    heatmap = np.zeros((height, width), dtype=np.float32)
    for box in boxes:
        x0, y0, x1, y1 = map(int, box)
        # Clamp before slicing: negative indices would wrap around to the
        # opposite edge instead of meaning "outside the image".
        x0, x1 = max(x0, 0), min(x1, width)
        y0, y1 = max(y0, 0), min(y1, height)
        if x0 < x1 and y0 < y1:
            heatmap[y0:y1, x0:x1] += 1
    peak = heatmap.max()
    # Guard the division. With no coverage, 0/0 would make the whole map NaN.
    if peak > 0:
        heatmap /= peak
    return heatmap
def preprocess_image(image):
    """Turn an RGB PIL image into a model-ready batch tensor.

    The image is resized to 800x800 (aspect ratio is not preserved) and its
    pixel values are scaled from 0-255 to [0, 1].

    Args:
        image: PIL image with three channels (callers convert to RGB first).

    Returns:
        ``float32`` tensor of shape (1, 3, 800, 800).
    """
    # NOTE(review): no per-channel mean/std normalization is applied here.
    # Hugging Face detection processors typically normalize, so confirm this
    # matches the model's preprocessor config.
    side = 800
    resized = image.resize((side, side))
    # HWC uint8 -> HWC float32 in [0, 1]
    pixels = np.array(resized, dtype=np.float32) / 255.0
    # HWC -> CHW, then prepend a batch axis of size 1.
    return torch.tensor(pixels).permute(2, 0, 1)[None, ...]
# Minimum class probability for a prediction to be reported as a fault.
CONFIDENCE_THRESHOLD = 0.5


def _detect_faults(image, threshold=CONFIDENCE_THRESHOLD):
    """Run the detector on an RGB PIL image and keep confident predictions.

    Returns:
        ``(boxes, labels, scores)`` numpy arrays. ``boxes`` are the raw
        ``pred_boxes`` rows, treated downstream as normalized (cx, cy, w, h).
        ``labels`` are class indices and ``scores`` their probabilities.
    """
    inputs = preprocess_image(image)
    with torch.no_grad():
        outputs = model(pixel_values=inputs)
    # With a softmax head the last logit column is the DETR-style "no object"
    # class; the transformers DETR/YOLOS post-processing drops it too. If it
    # were left in, background queries could pass the threshold and be
    # reported as faults.
    # NOTE(review): assumes this checkpoint uses such a head — confirm.
    probs = outputs["logits"].softmax(-1)[0, :, :-1]
    scores, labels = probs.max(-1)
    keep = scores > threshold
    return (
        outputs["pred_boxes"][0][keep].cpu().numpy(),
        labels[keep].cpu().numpy(),
        scores[keep].cpu().numpy(),
    )


def _to_pixel_boxes(boxes, width, height):
    """Convert normalized (cx, cy, w, h) boxes to pixel (x0, y0, x1, y1) lists.

    Coordinates are clamped to the image so boxes that spill past an edge
    stay drawable and index the heatmap correctly.
    """
    def clamp(value, upper):
        return min(max(int(value), 0), upper)

    pixel_boxes = []
    for cx, cy, w, h in boxes:
        pixel_boxes.append([
            clamp((cx - w / 2) * width, width),
            clamp((cy - h / 2) * height, height),
            clamp((cx + w / 2) * width, width),
            clamp((cy + h / 2) * height, height),
        ])
    return pixel_boxes


if uploaded_file is not None:
    if uploaded_file.type.startswith("image"):
        image = Image.open(uploaded_file).convert("RGB")
        boxes, labels, scores = _detect_faults(image)

        image_np = np.array(image)
        height, width = image_np.shape[:2]
        abs_boxes = _to_pixel_boxes(boxes, width, height)
        # Show class names from the model config when available, else the id.
        id2label = getattr(model.config, "id2label", None) or {}
        label_names = [id2label.get(int(i), str(i)) for i in labels]

        # Draw boxes and labels
        boxed_image = draw_boxes(image.copy(), abs_boxes, label_names, scores)
        st.image(boxed_image, caption="Detected Faults", use_column_width=True)

        if not abs_boxes:
            # Nothing to heat-map; normalizing an empty map would give NaNs.
            st.info("No faults detected above the confidence threshold.")
        else:
            # Generate and show heatmap
            heatmap = generate_heatmap(image_np, abs_boxes)
            fig, ax = plt.subplots()
            ax.imshow(image_np)
            ax.imshow(heatmap, cmap="jet", alpha=0.5)
            ax.axis("off")
            st.pyplot(fig)
            # pyplot keeps every figure open until it is closed explicitly.
            # Without this, each Streamlit rerun leaks one figure.
            plt.close(fig)
    elif uploaded_file.type.startswith("video"):
        st.warning("Video support coming soon. For now, please upload an image.")