Saiky2k's picture
Update app.py
75b70fe verified
raw
history blame
16.5 kB
# app.py
import streamlit as st
from PIL import Image
import cv2
import numpy as np
import torch
import tempfile
import os
import requests
from io import BytesIO
# Cấu hình trang
st.set_page_config(page_title="Phát hiện người và độ sâu", layout="wide")
# Tạo module độ sâu đơn giản
class DepthEstimator:
def __init__(self):
self.model = None
self.processor = None
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
def load_model(self):
if self.model is None:
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
self.processor = AutoImageProcessor.from_pretrained("vinvino02/glpn-nyu")
self.model = AutoModelForDepthEstimation.from_pretrained("vinvino02/glpn-nyu")
self.model.to(self.device)
self.model.eval()
return self.model, self.processor
def predict_depth(self, image):
model, processor = self.load_model()
# Chuẩn bị đầu vào
if isinstance(image, np.ndarray):
# Chuyển từ OpenCV (BGR) sang RGB
if image.shape[2] == 3:
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image)
else:
pil_image = image
inputs = processor(images=pil_image, return_tensors="pt").to(self.device)
# Dự đoán độ sâu
with torch.no_grad():
outputs = model(**inputs)
predicted_depth = outputs.predicted_depth
# Chuẩn hóa độ sâu để hiển thị tốt hơn
depth_min = torch.min(predicted_depth)
depth_max = torch.max(predicted_depth)
normalized_depth = (predicted_depth - depth_min) / (depth_max - depth_min)
normalized_depth = normalized_depth * 10 # Nhân với 10 để có giá trị mét hợp lý hơn
# Chuyển đổi sang mảng numpy
depth_map = normalized_depth.squeeze().cpu().numpy()
return depth_map
# Tải và cache mô hình YOLO
@st.cache_resource
def load_yolo_model():
from ultralytics import YOLO
model = YOLO("yolov8n.pt")
return model
# Phát hiện người trong ảnh
def detect_people(image, confidence_threshold=0.5):
yolo_model = load_yolo_model()
results = yolo_model(image, conf=confidence_threshold)
person_boxes = []
for result in results:
boxes = result.boxes.xyxy.cpu().numpy()
classes = result.boxes.cls.cpu().numpy()
confs = result.boxes.conf.cpu().numpy()
for box, cls, conf in zip(boxes, classes, confs):
if result.names[int(cls)] == "person" and conf > confidence_threshold:
x1, y1, x2, y2 = map(int, box[:4])
person_boxes.append((x1, y1, x2, y2, conf))
return person_boxes
# Xử lý ảnh
def process_image(image, confidence=0.5):
# Tạo bản sao của ảnh để vẽ lên
display_image = image.copy()
# Phát hiện người
person_boxes = detect_people(image, confidence)
# Ước tính độ sâu
depth_estimator = DepthEstimator()
depth_map = depth_estimator.predict_depth(image)
# Tạo bản đồ màu độ sâu
depth_colormap = create_depth_colormap(depth_map)
# Vẽ khung giới hạn và thông tin độ sâu
for x1, y1, x2, y2, conf in person_boxes:
# Vẽ khung giới hạn
cv2.rectangle(display_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
# Tính toán độ sâu tại vị trí trung tâm
center_x = (x1 + x2) // 2
center_y = (y1 + y2) // 2
# Đảm bảo tọa độ nằm trong giới hạn
center_x = min(center_x, depth_map.shape[1] - 1) if center_x < depth_map.shape[1] else depth_map.shape[1] // 2
center_y = min(center_y, depth_map.shape[0] - 1) if center_y < depth_map.shape[0] else depth_map.shape[0] // 2
depth_value = depth_map[center_y, center_x]
# Vẽ nhãn độ sâu
text = f"Độ sâu: {depth_value:.2f}m ({conf:.2f})"
draw_label(display_image, text, (x1, y1))
return display_image, depth_colormap, len(person_boxes)
# Xử lý video
def process_video(video_path, confidence=0.5, progress_bar=None, progress_text=None):
# Mở video
cap = cv2.VideoCapture(video_path)
# Lấy thuộc tính video
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
fps = cap.get(cv2.CAP_PROP_FPS)
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
# Tạo tệp đầu ra
temp_output_dir = tempfile.mkdtemp()
output_video_path = os.path.join(temp_output_dir, "detection_depth.mp4")
# Thiết lập writer
fourcc = cv2.VideoWriter_fourcc(*'XVID')
out = cv2.VideoWriter(output_video_path, fourcc, fps, (width * 2, height))
# Đối tượng phát hiện và ước tính độ sâu
depth_estimator = DepthEstimator()
# Biến đếm
frame_counter = 0
person_count = 0
# Tạo cột để hiển thị khung hình
preview_col1, preview_col2 = st.columns(2)
detection_placeholder = preview_col1.empty()
depth_placeholder = preview_col2.empty()
try:
while cap.isOpened():
ret, frame = cap.read()
if not ret:
break
frame_counter += 1
# Cập nhật tiến trình
if progress_bar:
progress = int(frame_counter / total_frames * 100)
progress_bar.progress(progress)
if frame_counter % 10 == 0 and progress_text:
progress_text.text(f"Đang xử lý: {frame_counter}/{total_frames} khung hình")
# Phát hiện người
person_boxes = detect_people(frame, confidence)
person_count += len(person_boxes)
# Ước tính độ sâu (chỉ xử lý mỗi 5 khung hình để tăng tốc độ)
if frame_counter % 5 == 0 or frame_counter == 1:
depth_map = depth_estimator.predict_depth(frame)
# Tạo bản đồ màu độ sâu
depth_colormap = create_depth_colormap(depth_map)
# Vẽ khung giới hạn và thông tin độ sâu
for x1, y1, x2, y2, conf in person_boxes:
# Vẽ khung giới hạn
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
# Tính toán độ sâu tại vị trí trung tâm
center_x = (x1 + x2) // 2
center_y = (y1 + y2) // 2
# Đảm bảo tọa độ nằm trong giới hạn
center_x = min(center_x, depth_map.shape[1] - 1) if center_x < depth_map.shape[1] else depth_map.shape[1] // 2
center_y = min(center_y, depth_map.shape[0] - 1) if center_y < depth_map.shape[0] else depth_map.shape[0] // 2
depth_value = depth_map[center_y, center_x]
# Vẽ nhãn độ sâu
text = f"Độ sâu: {depth_value:.2f}m ({conf:.2f})"
draw_label(frame, text, (x1, y1))
# Ghép hai khung hình lại với nhau
combined_frame = np.hstack((frame, cv2.cvtColor(depth_colormap, cv2.COLOR_RGB2BGR)))
# Ghi khung hình
out.write(combined_frame)
# Hiển thị khung hình trong Streamlit (cập nhật mỗi 5 khung hình)
if frame_counter % 5 == 0:
detection_placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), caption="Phát hiện người", use_column_width=True)
depth_placeholder.image(depth_colormap, caption="Bản đồ độ sâu", use_column_width=True)
finally:
# Giải phóng tài nguyên
cap.release()
out.release()
# Tính trung bình số người phát hiện được
avg_persons = person_count / frame_counter if frame_counter > 0 else 0
return output_video_path, avg_persons
# Hàm tiện ích
def create_depth_colormap(depth_map):
# Chuẩn hóa độ sâu từ 0-1
normalized = (depth_map - np.min(depth_map)) / (np.max(depth_map) - np.min(depth_map))
# Đảo ngược (gần = màu ấm, xa = màu lạnh)
inv_depth = 1 - normalized
# Chuyển đổi sang bản đồ màu
colored = cv2.applyColorMap((inv_depth * 255).astype(np.uint8), cv2.COLORMAP_TURBO)
# Chuyển đổi từ BGR sang RGB
return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
def draw_label(image, text, position):
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.7
font_thickness = 2
text_size = cv2.getTextSize(text, font, font_scale, font_thickness)[0]
x, y = position
text_x = x
text_y = y - 10
rect_x1 = text_x - 5
rect_y1 = text_y - text_size[1] - 5
rect_x2 = text_x + text_size[0] + 5
rect_y2 = text_y + 5
cv2.rectangle(image, (rect_x1, rect_y1), (rect_x2, rect_y2), (0, 255, 0), -1)
cv2.putText(image, text, (text_x, text_y), font, font_scale, (0, 0, 0), font_thickness)
# Giao diện người dùng chính
def main():
st.title("Phát hiện người và Ước tính độ sâu")
# Sidebar với tùy chọn
st.sidebar.header("Tùy chọn")
confidence = st.sidebar.slider("Ngưỡng tin cậy", 0.0, 1.0, 0.5)
# Chọn chế độ: Ảnh hoặc Video
mode = st.sidebar.radio("Chế độ", ["Ảnh", "Video"])
# Chọn nguồn: Tải lên hoặc Mẫu
source = st.sidebar.radio("Nguồn", ["Tải lên", "Mẫu"])
if mode == "Ảnh":
if source == "Tải lên":
uploaded_file = st.file_uploader("Tải lên ảnh", type=['jpg', 'jpeg', 'png'])
if uploaded_file is not None:
image = Image.open(uploaded_file)
image = np.array(image)
# Chuyển đổi sang RGB nếu là RGBA
if image.shape[2] == 4:
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
st.image(image, caption="Ảnh đã tải lên", use_column_width=True)
if st.button("Xử lý Ảnh"):
with st.spinner("Đang xử lý ảnh..."):
result_image, depth_colormap, person_count = process_image(image, confidence)
st.success(f"Phát hiện {person_count} người trong ảnh")
col1, col2 = st.columns(2)
col1.image(result_image, caption="Kết quả phát hiện", use_column_width=True)
col2.image(depth_colormap, caption="Bản đồ độ sâu", use_column_width=True)
else:
# Sử dụng ảnh mẫu
st.info("Đang sử dụng ảnh mẫu...")
sample_img_url = "https://storage.googleapis.com/sfr-vision-language-research/DINO/ground_truth_images/000000014439.jpg"
try:
response = requests.get(sample_img_url)
image = Image.open(BytesIO(response.content))
image = np.array(image)
st.image(image, caption="Ảnh mẫu", use_column_width=True)
if st.button("Xử lý Ảnh"):
with st.spinner("Đang xử lý ảnh..."):
result_image, depth_colormap, person_count = process_image(image, confidence)
st.success(f"Phát hiện {person_count} người trong ảnh")
col1, col2 = st.columns(2)
col1.image(result_image, caption="Kết quả phát hiện", use_column_width=True)
col2.image(depth_colormap, caption="Bản đồ độ sâu", use_column_width=True)
except Exception as e:
st.error(f"Không thể tải ảnh mẫu: {e}")
else:
# Chế độ Video
if source == "Tải lên":
uploaded_file = st.file_uploader("Tải lên video", type=['mp4', 'avi', 'mov'])
if uploaded_file is not None:
# Lưu tệp đã tải lên vào thư mục tạm thời
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
temp_file.write(uploaded_file.read())
video_path = temp_file.name
temp_file.close()
st.video(video_path)
if st.button("Xử lý Video"):
progress_bar = st.progress(0)
progress_text = st.empty()
with st.spinner("Đang xử lý video..."):
output_path, avg_persons = process_video(video_path, confidence, progress_bar, progress_text)
st.success(f"Xử lý video hoàn tất! Trung bình phát hiện {avg_persons:.1f} người/khung hình")
st.video(output_path)
# Nút tải xuống
with open(output_path, 'rb') as file:
st.download_button(
label="Tải xuống video kết quả",
data=file,
file_name="detection_depth_result.mp4",
mime="video/mp4"
)
# Xóa tệp tạm thời
os.unlink(video_path)
else:
# Sử dụng video mẫu
st.info("Đang sử dụng video mẫu...")
sample_video_url = "https://huggingface.co/spaces/Nupoor/SampleVideoDataset/resolve/main/pexels-richard-de-souza-1635985.mp4"
try:
response = requests.get(sample_video_url)
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
temp_file.write(response.content)
video_path = temp_file.name
temp_file.close()
st.video(video_path)
if st.button("Xử lý Video"):
progress_bar = st.progress(0)
progress_text = st.empty()
with st.spinner("Đang xử lý video..."):
output_path, avg_persons = process_video(video_path, confidence, progress_bar, progress_text)
st.success(f"Xử lý video hoàn tất! Trung bình phát hiện {avg_persons:.1f} người/khung hình")
st.video(output_path)
# Nút tải xuống
with open(output_path, 'rb') as file:
st.download_button(
label="Tải xuống video kết quả",
data=file,
file_name="detection_depth_result.mp4",
mime="video/mp4"
)
# Xóa tệp tạm thời
os.unlink(video_path)
except Exception as e:
st.error(f"Không thể tải video mẫu: {e}")
# Thông tin
st.sidebar.header("Thông tin")
st.sidebar.markdown("""
**Mô hình sử dụng:**
- Phát hiện người: YOLOv8n
- Ước tính độ sâu: GLPN-NYU
**Cách sử dụng:**
1. Chọn chế độ (Ảnh/Video)
2. Chọn nguồn (Tải lên/Mẫu)
3. Điều chỉnh ngưỡng tin cậy
4. Nhấn nút xử lý
""")
# Thiết lập requirements.txt
def create_requirements():
return """
streamlit==1.30.0
numpy==1.24.3
Pillow==10.0.0
opencv-python-headless==4.8.0.76
torch==2.0.1
torchvision==0.15.2
transformers==4.35.2
ultralytics==8.0.43
requests==2.31.0
"""
if __name__ == "__main__":
main()