Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
import streamlit as st
|
| 2 |
import torch
|
| 3 |
-
from transformers import
|
| 4 |
import cv2
|
| 5 |
import numpy as np
|
| 6 |
import tempfile
|
|
@@ -19,23 +19,35 @@ st.write("Upload a thermal video (MP4) to detect thermal, dust, and power genera
|
|
| 19 |
|
| 20 |
# UI controls for optimization parameters
|
| 21 |
st.sidebar.header("Analysis Settings")
|
| 22 |
-
frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=
|
| 23 |
-
batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=16, value=
|
| 24 |
resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
|
| 25 |
resize_width = 640 if resize_enabled else None
|
|
|
|
| 26 |
|
| 27 |
# Load model and processor
|
| 28 |
@st.cache_resource
|
| 29 |
-
def load_model():
|
| 30 |
warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
|
| 31 |
logging.set_verbosity_error()
|
| 32 |
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
processor, model, device = load_model()
|
| 41 |
|
|
@@ -50,7 +62,6 @@ def resize_frame(frame, width=None):
|
|
| 50 |
# Function to process a batch of frames
|
| 51 |
async def detect_faults_batch(frames, processor, model, device):
|
| 52 |
try:
|
| 53 |
-
# Resize frames if enabled
|
| 54 |
frames = [resize_frame(frame, resize_width) for frame in frames]
|
| 55 |
inputs = processor(images=frames, return_tensors="pt").to(device)
|
| 56 |
with torch.no_grad():
|
|
@@ -87,7 +98,6 @@ async def detect_faults_batch(frames, processor, model, device):
|
|
| 87 |
annotated_frames.append(annotated_frame)
|
| 88 |
all_faults.append(faults)
|
| 89 |
|
| 90 |
-
# Clear GPU memory
|
| 91 |
if torch.cuda.is_available():
|
| 92 |
torch.cuda.empty_cache()
|
| 93 |
|
|
@@ -109,7 +119,6 @@ async def process_video(video_path, frame_skip, batch_size):
|
|
| 109 |
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 110 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 111 |
|
| 112 |
-
# Adjust output size if resizing
|
| 113 |
out_width = resize_width if resize_width else frame_width
|
| 114 |
out_height = int(out_width * frame_height / frame_width) if resize_width else frame_height
|
| 115 |
|
|
@@ -132,7 +141,6 @@ async def process_video(video_path, frame_skip, batch_size):
|
|
| 132 |
break
|
| 133 |
|
| 134 |
if frame_count % frame_skip != 0:
|
| 135 |
-
# Resize frame for output if needed
|
| 136 |
frame = resize_frame(frame, resize_width)
|
| 137 |
out.write(frame)
|
| 138 |
frame_count += 1
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
import torch
|
| 3 |
+
from transformers import YolosImageProcessor, YolosForObjectDetection
|
| 4 |
import cv2
|
| 5 |
import numpy as np
|
| 6 |
import tempfile
|
|
|
|
| 19 |
|
| 20 |
# UI controls for optimization parameters
|
| 21 |
st.sidebar.header("Analysis Settings")
|
| 22 |
+
frame_skip = st.sidebar.slider("Frame Skip (higher = faster, less thorough)", min_value=1, max_value=30, value=15)
|
| 23 |
+
batch_size = st.sidebar.slider("Batch Size (adjust for hardware)", min_value=1, max_value=16, value=12)
|
| 24 |
resize_enabled = st.sidebar.checkbox("Resize Frames (faster processing)", value=True)
|
| 25 |
resize_width = 640 if resize_enabled else None
|
| 26 |
+
quantize_model = st.sidebar.checkbox("Quantize Model (faster on CPU)", value=False)
|
| 27 |
|
| 28 |
# Load model and processor
|
| 29 |
@st.cache_resource
|
| 30 |
+
def load_model(quantize=quantize_model):
|
| 31 |
warnings.filterwarnings("ignore", message="Some weights of the model checkpoint.*were not used")
|
| 32 |
logging.set_verbosity_error()
|
| 33 |
|
| 34 |
+
try:
|
| 35 |
+
processor = YolosImageProcessor.from_pretrained("hustvl/yolos-tiny")
|
| 36 |
+
model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny")
|
| 37 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 38 |
+
model.to(device)
|
| 39 |
+
|
| 40 |
+
# Apply dynamic quantization for CPU if enabled
|
| 41 |
+
if quantize and device.type == "cpu":
|
| 42 |
+
model = torch.quantization.quantize_dynamic(
|
| 43 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
model.eval()
|
| 47 |
+
return processor, model, device
|
| 48 |
+
except Exception as e:
|
| 49 |
+
st.error(f"Failed to load model: {str(e)}. Please check your internet connection or clear the cache (~/.cache/huggingface/hub).")
|
| 50 |
+
raise
|
| 51 |
|
| 52 |
processor, model, device = load_model()
|
| 53 |
|
|
|
|
| 62 |
# Function to process a batch of frames
|
| 63 |
async def detect_faults_batch(frames, processor, model, device):
|
| 64 |
try:
|
|
|
|
| 65 |
frames = [resize_frame(frame, resize_width) for frame in frames]
|
| 66 |
inputs = processor(images=frames, return_tensors="pt").to(device)
|
| 67 |
with torch.no_grad():
|
|
|
|
| 98 |
annotated_frames.append(annotated_frame)
|
| 99 |
all_faults.append(faults)
|
| 100 |
|
|
|
|
| 101 |
if torch.cuda.is_available():
|
| 102 |
torch.cuda.empty_cache()
|
| 103 |
|
|
|
|
| 119 |
fps = int(cap.get(cv2.CAP_PROP_FPS))
|
| 120 |
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 121 |
|
|
|
|
| 122 |
out_width = resize_width if resize_width else frame_width
|
| 123 |
out_height = int(out_width * frame_height / frame_width) if resize_width else frame_height
|
| 124 |
|
|
|
|
| 141 |
break
|
| 142 |
|
| 143 |
if frame_count % frame_skip != 0:
|
|
|
|
| 144 |
frame = resize_frame(frame, resize_width)
|
| 145 |
out.write(frame)
|
| 146 |
frame_count += 1
|