NaveenKumar5 commited on
Commit
bcb590b
·
verified ·
1 Parent(s): 575387f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +64 -98
src/streamlit_app.py CHANGED
@@ -1,117 +1,83 @@
 
1
  import streamlit as st
2
- from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
3
- import torch
4
- import numpy as np
5
  import cv2
6
- from PIL import Image
7
  import tempfile
8
- import os
 
 
 
 
 
 
9
 
10
- # Load model once
11
- @st.cache_resource(show_spinner=True)
 
12
  def load_model():
13
- model_id = "NaveenKumar5/Solar_panel_fault_detection"
14
  extractor = AutoFeatureExtractor.from_pretrained(model_id)
15
  model = AutoModelForObjectDetection.from_pretrained(model_id)
16
- model.eval()
17
  return extractor, model
18
 
19
  extractor, model = load_model()
 
20
 
21
- def detect_faults(image: Image.Image):
22
- inputs = extractor(images=image, return_tensors="pt")
23
- with torch.no_grad():
24
- outputs = model(**inputs)
25
- target_sizes = torch.tensor([image.size[::-1]])
26
- results = extractor.post_process_object_detection(outputs, target_sizes=target_sizes)[0]
27
-
28
- return results
29
-
30
- def draw_boxes_and_heatmap(image: Image.Image, results):
31
- image_np = np.array(image).copy()
32
 
33
- # Create heatmap mask
34
- heatmap_mask = np.zeros((image_np.shape[0], image_np.shape[1]), dtype=np.uint8)
35
 
36
- for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
37
- if score < 0.5:
38
- continue
39
- box = box.int().cpu().numpy()
40
- # Draw bounding box
41
- cv2.rectangle(image_np, (box[0], box[1]), (box[2], box[3]), (255, 0, 0), 2)
42
- # Put label text
43
- text = f"{label.item()} {score:.2f}"
44
- cv2.putText(image_np, text, (box[0], box[1]-10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
45
- # Fill heatmap mask
46
- heatmap_mask[box[1]:box[3], box[0]:box[2]] = 255
47
 
48
- # Create heatmap overlay (apply colormap on mask)
49
- heatmap_color = cv2.applyColorMap(heatmap_mask, cv2.COLORMAP_JET)
50
- heatmap_color = cv2.cvtColor(heatmap_color, cv2.COLOR_BGR2RGB)
 
 
 
 
51
 
52
- # Overlay heatmap with some transparency
53
- overlayed = cv2.addWeighted(image_np, 0.7, heatmap_color, 0.3, 0)
54
-
55
- return Image.fromarray(overlayed)
56
-
57
- def process_video(video_path):
58
- cap = cv2.VideoCapture(video_path)
59
- frames = []
60
- while True:
61
- ret, frame = cap.read()
62
- if not ret:
63
- break
64
- # Convert frame to PIL
65
- pil_frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
66
- results = detect_faults(pil_frame)
67
- frame_out = draw_boxes_and_heatmap(pil_frame, results)
68
- # Convert back to BGR for OpenCV video write
69
- frame_out = cv2.cvtColor(np.array(frame_out), cv2.COLOR_RGB2BGR)
70
- frames.append(frame_out)
71
- cap.release()
72
- return frames
73
-
74
- def save_video(frames, output_path, fps=20):
75
- height, width, _ = frames[0].shape
76
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
77
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
78
- for frame in frames:
79
- out.write(frame)
80
- out.release()
81
-
82
- st.title("Solar Panel Fault Detection with Heatmap")
83
-
84
- uploaded_file = st.file_uploader("Upload an image or video", type=["jpg","jpeg","png","mp4","avi"])
85
-
86
- if uploaded_file:
87
  if uploaded_file.type.startswith("image"):
88
  image = Image.open(uploaded_file).convert("RGB")
89
- st.image(image, caption="Uploaded Image", use_column_width=True)
90
- with st.spinner("Detecting faults..."):
91
- results = detect_faults(image)
92
- output_image = draw_boxes_and_heatmap(image, results)
93
- st.image(output_image, caption="Detection & Heatmap Overlay", use_column_width=True)
94
-
95
- # Show detected faults details
96
- st.subheader("Detected Faults")
97
- for box, score, label in zip(results["boxes"], results["scores"], results["labels"]):
98
- if score > 0.5:
99
- st.write(f"Fault: **{label.item()}** Confidence: {score:.2f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  elif uploaded_file.type.startswith("video"):
102
- tfile = tempfile.NamedTemporaryFile(delete=False)
103
- tfile.write(uploaded_file.read())
104
- video_path = tfile.name
105
-
106
- st.video(video_path)
107
- st.write("Processing video frames. This may take some time depending on video length.")
108
-
109
- with st.spinner("Detecting faults in video..."):
110
- frames = process_video(video_path)
111
-
112
- # Save output video
113
- output_path = "output.mp4"
114
- save_video(frames, output_path)
115
-
116
- st.video(output_path)
117
- os.remove(video_path)
 
1
+ import os
2
  import streamlit as st
 
 
 
3
  import cv2
4
+ import numpy as np
5
  import tempfile
6
+ import torch
7
+ import matplotlib.pyplot as plt
8
+ from transformers import AutoFeatureExtractor, AutoModelForObjectDetection
9
+ from PIL import Image, ImageDraw
10
+
11
+ # Fix cache permission issue
12
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface'
13
 
14
+ model_id = "NaveenKumar5/Solar_panel_fault_detection"
15
+
16
+ @st.cache_resource
17
  def load_model():
 
18
  extractor = AutoFeatureExtractor.from_pretrained(model_id)
19
  model = AutoModelForObjectDetection.from_pretrained(model_id)
 
20
  return extractor, model
21
 
22
  extractor, model = load_model()
23
+ model.eval()
24
 
25
+ st.title("🔍 Solar Panel Fault Detection")
26
+ st.write("Upload an image or video to detect faults and view heatmaps.")
 
 
 
 
 
 
 
 
 
27
 
28
+ uploaded_file = st.file_uploader("Upload Image or Video", type=["jpg", "png", "mp4", "avi"])
 
29
 
30
+ def draw_boxes(image, boxes, labels, scores):
31
+ draw = ImageDraw.Draw(image)
32
+ for box, label, score in zip(boxes, labels, scores):
33
+ draw.rectangle(box, outline="red", width=2)
34
+ draw.text((box[0], box[1] - 10), f"{label}: {score:.2f}", fill="red")
35
+ return image
 
 
 
 
 
36
 
37
+ def generate_heatmap(image, boxes):
38
+ heatmap = np.zeros((image.shape[0], image.shape[1]), dtype=np.float32)
39
+ for box in boxes:
40
+ x0, y0, x1, y1 = map(int, box)
41
+ heatmap[y0:y1, x0:x1] += 1
42
+ heatmap = np.clip(heatmap / np.max(heatmap), 0, 1)
43
+ return heatmap
44
 
45
+ if uploaded_file is not None:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  if uploaded_file.type.startswith("image"):
47
  image = Image.open(uploaded_file).convert("RGB")
48
+ inputs = extractor(images=image, return_tensors="pt")
49
+ with torch.no_grad():
50
+ outputs = model(**inputs)
51
+
52
+ scores = outputs["logits"].softmax(-1)[0].max(-1).values
53
+ keep = scores > 0.5
54
+
55
+ boxes = outputs["pred_boxes"][0][keep].cpu().numpy()
56
+ labels = outputs["logits"].argmax(-1)[0][keep].cpu().numpy()
57
+ scores = scores[keep].cpu().numpy()
58
+
59
+ image_np = np.array(image)
60
+ height, width = image_np.shape[:2]
61
+ abs_boxes = []
62
+ for box in boxes:
63
+ cx, cy, w, h = box
64
+ x0 = int((cx - w / 2) * width)
65
+ y0 = int((cy - h / 2) * height)
66
+ x1 = int((cx + w / 2) * width)
67
+ y1 = int((cy + h / 2) * height)
68
+ abs_boxes.append([x0, y0, x1, y1])
69
+
70
+ # Draw boxes and labels
71
+ boxed_image = draw_boxes(image.copy(), abs_boxes, labels, scores)
72
+ st.image(boxed_image, caption="Detected Faults", use_column_width=True)
73
+
74
+ # Generate and show heatmap
75
+ heatmap = generate_heatmap(image_np, abs_boxes)
76
+ fig, ax = plt.subplots()
77
+ ax.imshow(image_np)
78
+ ax.imshow(heatmap, cmap="jet", alpha=0.5)
79
+ ax.axis("off")
80
+ st.pyplot(fig)
81
 
82
  elif uploaded_file.type.startswith("video"):
83
+ st.warning("Video support coming soon. For now, please upload an image.")