# Page banner captured with this listing -- Spaces status: Sleeping
import hmac
import os
import tempfile
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from fastapi import FastAPI
from PIL import Image
from tqdm import tqdm
from ultralytics import YOLO
# Credentials: the user name is fixed, the password is read from the
# APP_PASSWORD environment variable.
APP_USERNAME = "admin"
APP_PASSWORD = os.getenv("APP_PASSWORD", "default_password")
if "APP_PASSWORD" not in os.environ:
    # The historical fallback is kept so existing deployments still start,
    # but running with a publicly known password must be visible in the logs.
    print("警告: 未设置环境变量 APP_PASSWORD,正在使用默认密码,请勿用于公开部署")

# NOTE(review): `app` is not referenced anywhere else in the visible file;
# presumably kept for an external ASGI entry point -- confirm before removing.
app = FastAPI()

# Run inference on the GPU when torch reports one, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"使用设备: {device}")

# Weights are loaded once at import time and shared by every request.
model = YOLO('kunin-mice-pose.v0.1.5n.pt')
print("模型加载完成")
# Authentication state
class AuthState:
    """Mutable holder for the "has somebody logged in" flag.

    NOTE(review): only one module-level instance is created below, so the
    flag is per process rather than per browser session -- a successful login
    appears to unlock processing for every client of this process; confirm
    whether that is intended.
    """

    def __init__(self) -> None:
        # Starts locked; `login` flips this to True on valid credentials.
        self.is_logged_in: bool = False


auth_state = AuthState()
def login(username, password):
    """Check the submitted credentials and switch the visible panel.

    Args:
        username: Text entered in the user-name box (``None`` is treated as
            an empty string).
        password: Text entered in the password box (``None`` is treated as
            an empty string).

    Returns:
        A 3-tuple ``(login_panel_update, main_panel_update, message)``.
        On success the login panel is hidden, the main panel is shown and
        ``auth_state.is_logged_in`` is set to ``True``; on failure the
        visibility is the other way round and the flag is left untouched.
    """
    # compare_digest needs bytes for non-ASCII text and runs in time that does
    # not depend on where the first mismatch is, unlike `==`.
    supplied_user = (username or "").encode("utf-8")
    supplied_password = (password or "").encode("utf-8")

    # Evaluate both comparisons before branching so a wrong user name does
    # not short-circuit the password check.
    user_ok = hmac.compare_digest(supplied_user, APP_USERNAME.encode("utf-8"))
    password_ok = hmac.compare_digest(supplied_password, APP_PASSWORD.encode("utf-8"))

    if user_ok and password_ok:
        auth_state.is_logged_in = True
        return gr.update(visible=False), gr.update(visible=True), "登录成功"
    return gr.update(visible=True), gr.update(visible=False), "用户名或密码错误"
# Gaussian kernel used to turn per-pixel hit counts into a density map.
_HEATMAP_KERNEL_SIZE = 31
_HEATMAP_SIGMA = 10


def _filter_trajectories(positions, width, height, max_jump_distance=100, window_size=5):
    """Drop out-of-frame points, bridge implausible jumps and smooth a track.

    Args:
        positions: Sequence of ``[x, y]`` pixel coordinates in frame order.
        width: Frame width; points with ``x`` outside ``[0, width)`` are dropped.
        height: Frame height; points with ``y`` outside ``[0, height)`` are dropped.
        max_jump_distance: Largest accepted distance (pixels) between two
            consecutive accepted points.
        window_size: Length of the moving-average window.

    Returns:
        A list of ``[x, y]`` points. Inputs shorter than three points are
        returned unchanged.
    """
    if len(positions) < 3:
        return positions

    def in_frame(point):
        return 0 <= point[0] < width and 0 <= point[1] < height

    def distance_between(a, b):
        return np.sqrt((a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2)

    filtered = []
    last_valid = None
    for index, pos in enumerate(positions):
        if not in_frame(pos):
            continue
        if last_valid is None:
            filtered.append(pos)
            last_valid = pos
            continue

        distance = distance_between(pos, last_valid)
        if distance <= max_jump_distance:
            filtered.append(pos)
            last_valid = pos
            continue

        # Jump detected: look ahead for the first in-frame point that is
        # again within reach of the last accepted one.
        next_valid = next(
            (
                candidate
                for candidate in positions[index:]
                if in_frame(candidate)
                and distance_between(candidate, last_valid) <= max_jump_distance
            ),
            None,
        )
        if next_valid is None:
            # No recovery point: the outlier is simply discarded.
            continue

        # Bridge the gap with linearly interpolated points.
        # NOTE(review): iteration resumes at the next index, so the points
        # between the outlier and `next_valid` are visited again; that
        # original behaviour is kept unchanged.
        steps = max(2, int(distance / max_jump_distance))
        for step in range(1, steps):
            alpha = step / steps
            filtered.append([
                int(last_valid[0] * (1 - alpha) + next_valid[0] * alpha),
                int(last_valid[1] * (1 - alpha) + next_valid[1] * alpha),
            ])
        filtered.append(next_valid)
        last_valid = next_valid

    if len(filtered) < window_size:
        return filtered

    # Centered moving average; the first and last `half` points are copied
    # through unsmoothed. The tail slice is computed from the length because
    # `-window_size // 2` evaluates to -3 for a window of 5, which used to
    # append one point too many.
    half = window_size // 2
    smoothed = list(filtered[:half])
    for center in range(half, len(filtered) - half):
        window = filtered[center - half:center + half + 1]
        smoothed.append([
            int(np.mean([p[0] for p in window])),
            int(np.mean([p[1] for p in window])),
        ])
    smoothed.extend(filtered[len(filtered) - half:])
    return smoothed


def _build_report(detection_info, frame_count, process_seconds, conf_threshold, max_det):
    """Format the textual summary of the run, tolerating zero detections."""
    # Confidences are stored as percentage strings per detection; parse them
    # back exactly as before so the histogram bins stay the same.
    confidences = [
        float(det['confidence'].strip('%')) / 100
        for info in detection_info
        for det in info['detections']
    ]
    hist, bins = np.histogram(confidences, bins=5)
    total = len(confidences)

    histogram_lines = []
    for i in range(len(hist)):
        share = hist[i] / total * 100 if total else 0.0
        histogram_lines.append(
            f"置信度 {bins[i]:.2f}-{bins[i+1]:.2f}: {int(hist[i]):3d}个检测 ({share:.1f}%)"
        )
    confidence_report = "\n".join(histogram_lines)

    # Guard the aggregates: max()/min() raise on an empty list.
    counts = [info['count'] for info in detection_info]
    mean_count = float(np.mean(counts)) if counts else 0.0
    max_count = max(counts) if counts else 0
    min_count = min(counts) if counts else 0

    return (
        "视频分析报告:\n"
        "参数设置:\n"
        f"- 置信度阈值: {conf_threshold:.2f}\n"
        f"- 最大检测数量: {max_det}\n"
        f"- 处理时长: {process_seconds}秒\n"
        "分析结果:\n"
        f"- 处理帧数: {frame_count}\n"
        f"- 平均每帧检测到的老鼠数: {mean_count:.1f}\n"
        f"- 最大检测数: {max_count}\n"
        f"- 最小检测数: {min_count}\n"
        "置信度分布:\n"
        f"{confidence_report}\n"
    )


def _draw_trajectory(positions, width, height):
    """Render the tracked path on a white canvas and return the BGR image."""
    canvas = np.full((height, width, 3), 255, dtype=np.uint8)
    if len(positions) <= 1:
        return canvas

    # Plain Python int tuples are what the OpenCV drawing calls expect.
    points = [
        (int(x), int(y))
        for x, y in _filter_trajectories(positions, width, height)
    ]
    if not points:
        # Every point was filtered out (e.g. all out of frame).
        return canvas

    # Path segments with a colour gradient from start to end.
    last_index = len(points) - 1
    for i in range(last_index):
        ratio = i / last_index
        color = (int((1 - ratio) * 255), 50, int(ratio * 255))
        cv2.line(canvas, points[i], points[i + 1], color, 2)

    # Start marker and end marker.
    cv2.circle(canvas, points[0], 8, (0, 255, 0), -1)
    cv2.circle(canvas, points[-1], 8, (0, 0, 255), -1)

    # Roughly twenty direction arrows along the path.
    arrow_interval = max(len(points) // 20, 1)
    for i in range(0, len(points) - arrow_interval, arrow_interval):
        cv2.arrowedLine(
            canvas, points[i], points[i + arrow_interval],
            (100, 100, 100), 1, tipLength=0.2,
        )
    return canvas


def _render_heatmap(hit_counts):
    """Turn per-pixel hit counts into a colour heat map (BGR image)."""
    height, width = hit_counts.shape
    if not np.any(hit_counts):
        # Nothing was tracked: return a blank image instead of failing later.
        return np.full((height, width, 3), 255, dtype=np.uint8)

    # Gaussian blur is linear, so blurring the accumulated counts once gives
    # the same density as blurring a one-hot image per frame and summing.
    density = cv2.GaussianBlur(
        hit_counts, (_HEATMAP_KERNEL_SIZE, _HEATMAP_KERNEL_SIZE), _HEATMAP_SIGMA
    )
    normalized = cv2.normalize(density, None, 0, 255, cv2.NORM_MINMAX)
    colored = cv2.applyColorMap(normalized.astype(np.uint8), cv2.COLORMAP_JET)
    alpha = 0.7
    # Blend with white to soften the colour map.
    return cv2.addWeighted(colored, alpha, np.full_like(colored, 255), 1 - alpha, 0)


def process_video(video_path, process_seconds=20, conf_threshold=0.2, max_det=8):
    """Run pose detection on a video and produce the analysis artefacts.

    Args:
        video_path: Path of the uploaded video file.
        process_seconds: Number of seconds to analyse; a falsy or
            non-positive value means the whole video.
        conf_threshold: Minimum detection confidence passed to the model.
        max_det: Maximum number of detections per frame.

    Returns:
        A 4-tuple ``(annotated_video_path, trajectory_image_path,
        heatmap_image_path, report_text)``. When processing cannot start
        (not logged in, no video, unreadable video) the three paths are
        ``None`` and the last element carries the message, so the tuple
        always matches the four bound outputs.
    """
    print("开始处理视频...")
    if not auth_state.is_logged_in:
        return None, None, None, "请先登录"
    if not video_path:
        return None, None, None, "请先上传视频"

    print("读取视频信息...")
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None, None, None, "无法打开视频文件"
        # Keep the fractional frame rate (e.g. 29.97) so the written video
        # plays at the source speed.
        fps = float(cap.get(cv2.CAP_PROP_FPS))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        source_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        cap.release()

    if width <= 0 or height <= 0:
        return None, None, None, "无法打开视频文件"
    if not np.isfinite(fps) or fps <= 0:
        # Some containers report no frame rate; use a neutral fallback so the
        # writer and the frame budget stay usable.
        fps = 30.0

    limit_frames = bool(process_seconds) and process_seconds > 0
    total_frames = max(1, int(process_seconds * fps)) if limit_frames else source_frames
    print(f"视频信息: {width}x{height} @ {fps:.2f}fps, 总帧数: {total_frames}")

    # The temp file is created only after validation so early exits do not
    # leave stray files behind.
    print("创建临时输出文件...")
    with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmp_file:
        output_path = tmp_file.name

    # Scale the annotation line width with the frame size.
    base_size = min(width, height)
    line_thickness = max(1, int(base_size * 0.002))

    frame_count = 0
    detection_info = []
    all_positions = []
    hit_counts = np.zeros((height, width), dtype=np.float32)

    print("初始化视频写入器...")
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    progress_bar = tqdm(total=total_frames if total_frames > 0 else None, desc="处理帧")
    try:
        print("开始YOLO推理...")
        results = model.predict(
            source=video_path,
            device=device,
            conf=conf_threshold,
            save=False,
            show=False,
            stream=True,
            line_width=line_thickness,
            boxes=True,
            show_labels=True,
            show_conf=True,
            vid_stride=1,
            max_det=int(max_det),
            retina_masks=True,
            verbose=False
        )

        print("处理检测结果...")
        for r in results:
            frame = r.plot()

            keypoints = getattr(r, 'keypoints', None)
            if keypoints is not None:
                kpts = keypoints.data
                if isinstance(kpts, torch.Tensor):
                    kpts = kpts.cpu().numpy()
                # Only frames with exactly one detection carrying 8 keypoints
                # contribute to the track; the first keypoint is the position.
                if kpts.shape == (1, 8, 3):
                    x, y = int(kpts[0, 0, 0]), int(kpts[0, 0, 1])
                    all_positions.append([x, y])
                    if 0 <= x < width and 0 <= y < height:
                        hit_counts[y, x] += 1

            frame_info = {
                "frame": frame_count + 1,
                "count": len(r.boxes),
                "detections": []
            }
            for box in r.boxes:
                conf = float(box.conf[0])
                cls = int(box.cls[0])
                frame_info["detections"].append({
                    "class": r.names[cls],
                    "confidence": f"{conf:.2%}"
                })
            detection_info.append(frame_info)

            video_writer.write(frame)
            frame_count += 1
            progress_bar.update(1)
            if limit_frames and frame_count >= total_frames:
                break
    finally:
        # Always finalise the output file and the progress bar, even when
        # inference fails part-way through.
        progress_bar.close()
        video_writer.release()
    print("视频处理完成")

    print("生成分析报告...")
    report = _build_report(detection_info, frame_count, process_seconds, conf_threshold, max_det)

    print("生成轨迹图...")
    trajectory_img = _draw_trajectory(all_positions, width, height)

    print("生成热力图...")
    heatmap_colored = _render_heatmap(hit_counts)

    print("保存结果图像...")
    # splitext only touches the extension, unlike str.replace('.mp4', ...),
    # which would also rewrite a '.mp4' occurring elsewhere in the path.
    base_path, _ = os.path.splitext(output_path)
    trajectory_path = base_path + '_trajectory.png'
    heatmap_path = base_path + '_heatmap.png'
    cv2.imwrite(trajectory_path, trajectory_img)
    cv2.imwrite(heatmap_path, heatmap_colored)

    print("处理完成!")
    return output_path, trajectory_path, heatmap_path, report
# Build the Gradio interface.
# NOTE(review): the nesting below was reconstructed from a listing whose
# indentation was lost -- confirm the layout against the deployed app.

# Help text shown under the processing controls.
_USAGE_NOTES = """
### 使用说明
1. 上传视频文件
2. 设置处理参数:
- 处理时长:需要分析的视频时长(秒)
- 置信度阈值:检测的置信度要求(越高越严格)
- 最大检测数量:每帧最多检测的目标数量
3. 等待处理完成
4. 查看检测结果视频和分析报告
### 注意事项
- 支持常见视频格式(mp4, avi 等)
- 建议视频分辨率不超过 1920x1080
- 处理时间与视频长度和分辨率相关
- 置信度建议范围:0.2-0.5
- 最大检测数量建议根据实际场景设置
"""

with gr.Blocks() as demo:
    gr.Markdown("# 🐭 小鼠行为分析 (Mice Behavior Analysis)")

    # Login panel: visible at start, hidden by `login` on success.
    with gr.Group() as login_interface:
        username = gr.Textbox(label="用户名")
        password = gr.Textbox(label="密码", type="password")
        login_button = gr.Button("登录")
        login_msg = gr.Textbox(label="消息", interactive=False)

    # Main panel: hidden at start, revealed by `login` on success.
    with gr.Group(visible=False) as main_interface:
        gr.Markdown("上传视频来检测和分析小鼠行为 | Upload a video to detect and analyze mice behavior")
        with gr.Row():
            # Left column: input video and processing parameters.
            with gr.Column():
                video_input = gr.Video(label="输入视频")
                process_seconds = gr.Number(label="处理时长(秒,0表示处理整个视频)", value=20)
                conf_threshold = gr.Slider(
                    label="置信度阈值",
                    info="越高越严格,建议范围0.2-0.5",
                    minimum=0.1, maximum=1.0, step=0.05, value=0.2,
                )
                max_det = gr.Slider(
                    label="最大检测数量",
                    info="每帧最多检测的目标数量",
                    minimum=1, maximum=10, step=1, value=1,
                )
                process_btn = gr.Button("开始处理")
            # Right column: annotated video, the two images and the report.
            with gr.Column():
                video_output = gr.Video(label="检测结果")
                with gr.Row():
                    trajectory_output = gr.Image(label="运动轨迹")
                    heatmap_output = gr.Image(label="热力图")
                report_output = gr.Textbox(label="分析报告")
        gr.Markdown(_USAGE_NOTES)

    # Event wiring: `login` returns three values, `process_video` four.
    login_button.click(
        fn=login,
        inputs=[username, password],
        outputs=[login_interface, main_interface, login_msg],
    )
    process_btn.click(
        fn=process_video,
        inputs=[video_input, process_seconds, conf_threshold, max_det],
        outputs=[video_output, trajectory_output, heatmap_output, report_output],
    )

if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)