import os

import cv2
import numpy as np
import gradio as gr
import torch
from mouse_tracker import MouseTrackerAnalyzer
from huggingface_hub import hf_hub_download

# Detect whether we are running inside a Hugging Face Spaces environment.
try:
    import spaces
    is_spaces = True
    print("检测到 Hugging Face Spaces 环境")
except ImportError:
    is_spaces = False
    print("在本地环境运行")

# Global configuration.
model_base_name = "fst-v1.3-n"  # Base model name, without file suffix.
total_frames = 0  # Frame count of the most recently selected video.


def get_model_file_path(model_suffix):
    """Return the local model path for the given file suffix (e.g. ".onnx")."""
    return f"./{model_base_name}{model_suffix}"


def extract_frame(video_path, frame_num):
    """Read a single frame from a video and return it as an RGB image.

    Args:
        video_path: Path of the video file; a falsy value yields ``None``.
        frame_num: Index of the frame to seek to before reading.

    Returns:
        The frame converted from BGR to RGB, or ``None`` when the path is
        empty, the video cannot be opened, or the frame cannot be read.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
    finally:
        # Always free the capture handle, even if seeking/reading raises.
        cap.release()

    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def select_video(video_file, model_suffix):
    """Handle selection of a video: load its first frame and frame range.

    Updates the module-level ``total_frames`` with the frame count reported
    by OpenCV.

    Args:
        video_file: Path of the selected video; falsy when nothing is selected.
        model_suffix: Model file suffix, used only for the status message.

    Returns:
        A tuple ``(first_frame_rgb, status, start_slider, end_slider)``. On
        failure the frame is ``None`` and both sliders are reset to 0.
    """
    global total_frames

    if not video_file:
        return None, "请选择视频文件", gr.Slider(0, 0, 0), gr.Slider(0, 0, 0)

    # Use one capture for both the frame count and the first frame, and make
    # sure it is released; a second, never-released capture would leak.
    cap = cv2.VideoCapture(video_file)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        ret, frame = cap.read()
    finally:
        cap.release()

    if not ret:
        return None, "无法读取视频帧", gr.Slider(0, 0, 0), gr.Slider(0, 0, 0)

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # Clamp so that a reported frame count of 0 cannot yield maximum < minimum.
    last_index = max(total_frames - 1, 0)
    start = gr.Slider(minimum=0, maximum=last_index, value=0, step=1)
    end = gr.Slider(minimum=0, maximum=last_index, value=last_index, step=1)

    model_name = os.path.basename(get_model_file_path(model_suffix))
    status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {model_name}"
    return frame_rgb, status, start, end


def preview_frame(video_file, frame_num):
    """Return ``(frame_rgb, status)`` for the requested frame of the video."""
    if not video_file:
        return None, "请先选择视频文件"

    frame = extract_frame(video_file, frame_num)
    if frame is None:
        return None, "无法读取指定帧"
    return frame, f"帧 {frame_num}"


def _start_analysis_impl(video, model_suffix, conf, iou, max_det,
                         start_frame, end_frame, threshold):
    """Run the tracking analysis on a video and write the result files.

    The output video and CSV are written next to the input video as
    ``<name>_out.mp4`` and ``<name>_results.csv``.

    Args:
        video: Path of the input video.
        model_suffix: Model file suffix appended to the base model name.
        conf: Confidence value passed to the analyzer.
        iou: IoU value passed to the analyzer.
        max_det: Maximum detections value passed to the analyzer.
        start_frame: First frame to process; must be below ``end_frame``.
        end_frame: Last frame to process.
        threshold: Value assigned to the analyzer's ``struggle_threshold``.

    Returns:
        A tuple ``(output_video_path, plot_path, status)``. On a validation
        or model-download failure the two paths are ``None`` and ``status``
        carries the error message.
    """
    if not video:
        return None, None, "请选择视频文件"
    if start_frame >= end_frame:
        return None, None, "起始帧必须小于结束帧"

    # Build the output paths next to the input video.
    video_dir = os.path.dirname(video)
    video_name = os.path.splitext(os.path.basename(video))[0]
    output_path = os.path.join(video_dir, f"{video_name}_out.mp4")
    csv_path = os.path.join(video_dir, f"{video_name}_results.csv")

    model_path = get_model_file_path(model_suffix)
    if not os.path.exists(model_path):
        if is_spaces:
            try:
                model_path = hf_hub_download(
                    repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME",
                    filename=f"weights/{model_base_name}{model_suffix}",
                )
            except Exception as exc:
                # UI boundary: report the failure instead of continuing with
                # a model path that is already known not to exist.
                message = f"下载模型失败: {model_path}"
                print(f"{message} ({exc})")
                return None, None, message
        else:
            # Best effort when running locally: warn and let the analyzer try.
            print(f"警告: 本地未找到模型文件 {model_path}")

    # Initialise the analyzer.
    analyzer = MouseTrackerAnalyzer(
        model_path=model_path,
        conf=conf,
        iou=iou,
        max_det=max_det,
        verbose=True,
    )
    analyzer.struggle_threshold = threshold

    # Run the analysis, printing progress for every callback invocation.
    analyzer.process_video(
        video_path=video,
        output_path=output_path,
        start_frame=start_frame,
        end_frame=end_frame,
        callback=lambda prog, frm, res: print(
            f"进度: {prog}% 检测: {len(res)} 项"
        ),
    )
    analyzer.save_results(csv_path)

    # Generate the plot only when the analyzer produced results.
    plot_path = None
    if analyzer.results:
        plot_path = analyzer.generate_time_series_plot()

    status = f"分析完成。视频: {output_path}, CSV: {csv_path}"
    if plot_path:
        status += f", 图表: {plot_path}"
    return output_path, plot_path, status


def start_analysis(video, model_suffix, conf, iou, max_det,
                   start_frame, end_frame, threshold):
    """Public entry point for the analysis; delegates to the implementation."""
    return _start_analysis_impl(video, model_suffix, conf, iou, max_det,
                                start_frame, end_frame, threshold)


# On HF Spaces, wrap the entry point with the GPU decorator.
if is_spaces:
    start_analysis = spaces.GPU(duration=120)(start_analysis)


def create_interface():
    """Build and return the Gradio Blocks application."""
    with gr.Blocks(title="鼠强迫游泳挣扎度分析") as app:
        gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")

        with gr.Row():
            with gr.Column(scale=1):
                video_input = gr.Video(label="输入视频")
                model_format = gr.Dropdown(
                    label="模型格式",
                    choices=[".onnx", ".engine", ".pt", ".mlpackage"],
                    value=".onnx",
                    interactive=True,
                )
                device_info = gr.Textbox(
                    label="系统信息",
                    value=f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}",
                    interactive=False,
                )
                conf = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="置信度阈值")
                iou = gr.Slider(0.1, 0.9, value=0.45, step=0.05, label="IoU阈值")
                max_det = gr.Slider(1, 50, value=20, step=1, label="最大检测数")
                threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="挣扎阈值")
                # Placeholder ranges; select_video replaces them once a video
                # has been loaded.
                start_frame = gr.Slider(0, 999999, value=0, step=1, label="起始帧")
                end_frame = gr.Slider(0, 999999, value=999999, step=1, label="结束帧")
                preview_btn = gr.Button("预览帧")
                start_btn = gr.Button("开始分析", variant="primary")

            with gr.Column(scale=2):
                with gr.Tab("预览"):
                    preview_image = gr.Image(label="预览图像", type="numpy", height=400)
                    status_text = gr.Textbox(label="状态", interactive=False)
                with gr.Tab("结果"):
                    output_video = gr.Video(label="分析结果视频")
                    result_plot = gr.Image(label="挣扎分数时间序列")
                    result_status = gr.Textbox(label="分析状态", interactive=False)

        # Event bindings; the model format is passed along as an input.
        video_input.change(
            select_video,
            inputs=[video_input, model_format],
            outputs=[preview_image, status_text, start_frame, end_frame],
        )
        preview_btn.click(
            preview_frame,
            inputs=[video_input, start_frame],
            outputs=[preview_image, status_text],
        )
        start_btn.click(
            start_analysis,
            inputs=[video_input, model_format, conf, iou, max_det,
                    start_frame, end_frame, threshold],
            outputs=[output_video, result_plot, result_status],
        )

    return app


if __name__ == "__main__":
    # Clear proxy settings from the environment.
    for key in ['http_proxy', 'https_proxy', 'all_proxy']:
        os.environ.pop(key, None)

    print(f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}")
    print(f"默认模型路径: {get_model_file_path('.onnx')}")

    app = create_interface()
    if is_spaces:
        app.launch()
    else:
        app.launch(server_name="0.0.0.0", server_port=7860, share=False)