| import os |
| import cv2 |
| import numpy as np |
| import gradio as gr |
| import tempfile |
| import torch |
| from mouse_tracker import MouseTrackerAnalyzer |
| import huggingface_hub |
| from huggingface_hub import hf_hub_download |
|
|
| |
# Detect whether the app runs on Hugging Face Spaces: the optional ``spaces``
# package decides which branch we take (and whether GPU wrapping is applied).
try:
    import spaces
except ImportError:
    is_spaces = False
    print("在本地环境运行")
else:
    is_spaces = True
    print("检测到Hugging Face Spaces环境")


# Module-level state shared between the Gradio callbacks.
analyzer = None                          # MouseTrackerAnalyzer built by the most recent analysis run
video_file_path = None                   # path of the most recently selected video
model_file_path = "./fst-v1.2-n.engine"  # local model file; may be replaced by a Hub download path
total_frames = 0                         # frame count reported for the selected video
output_path = None                       # annotated output video of the most recent run
|
|
| |
def extract_frame(video_path, frame_num):
    """Return frame ``frame_num`` of ``video_path`` as an RGB array, or None.

    None is returned when no path is given, when OpenCV cannot open the
    file, or when the requested frame cannot be decoded.
    """
    if not video_path:
        return None

    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return None

    # Seek straight to the requested frame, then decode exactly one frame.
    capture.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
    ok, bgr = capture.read()
    capture.release()

    # OpenCV decodes to BGR; convert to RGB for display in the preview image.
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) if ok else None
|
|
| |
def select_video(video_file):
    """Load a newly selected video and reset the frame-range sliders.

    Opens ``video_file`` with OpenCV, records its path and reported frame
    count in the module-level ``video_file_path`` / ``total_frames`` globals,
    and decodes the first frame for the preview pane.

    Args:
        video_file: Video path delivered by the ``gr.Video`` input; falsy when
            the input was cleared.

    Returns:
        ``(first_frame_rgb, status_message, start_slider, end_slider)``.
        On any failure the image is None and both sliders collapse to 0.
    """
    global video_file_path, total_frames

    def _collapsed_sliders():
        # Both range sliders pinned to 0 when there is no usable video.
        return (
            gr.Slider(minimum=0, maximum=0, value=0),
            gr.Slider(minimum=0, maximum=0, value=0),
        )

    if not video_file:
        return (None, "请选择视频文件") + _collapsed_sliders()

    video_file_path = video_file

    cap = cv2.VideoCapture(video_file_path)
    try:
        if not cap.isOpened():
            return (None, "无法打开视频文件") + _collapsed_sliders()

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        ret, first_frame = cap.read()
    finally:
        # Release the capture on every path, including the early return.
        cap.release()

    if not ret:
        return (None, "无法读取视频帧") + _collapsed_sliders()

    first_frame_rgb = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)

    # CAP_PROP_FRAME_COUNT is only what the container reports and can be 0
    # (or negative) even though a frame was decoded; clamp so the sliders
    # never receive a negative maximum/value.
    last_frame = max(total_frames - 1, 0)
    start_slider = gr.Slider(minimum=0, maximum=last_frame, value=0, step=1)
    end_slider = gr.Slider(minimum=0, maximum=last_frame, value=last_frame, step=1)

    model_status = f"使用模型: {os.path.basename(model_file_path)}"
    return first_frame_rgb, f"视频加载成功,总帧数: {total_frames}. {model_status}", start_slider, end_slider
|
|
| |
def preview_frame(video_file, frame_num):
    """Render one frame of the selected video for the preview pane.

    Returns ``(rgb_frame, caption)``. The frame is None and the caption holds
    an error message when no video is selected or the frame cannot be read.
    """
    if video_file:
        image = extract_frame(video_file, frame_num)
        if image is not None:
            return image, f"帧 {frame_num}"
        return None, "无法读取指定帧"
    return None, "请先选择视频文件"
|
|
| |
| |
def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
    """Click handler for the "开始分析" button; delegates to ``_start_analysis_impl``.

    On Hugging Face Spaces the handler is additionally wrapped with
    ``spaces.GPU(duration=120)`` (see below) to request GPU time for the call.
    """
    return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)


if is_spaces:
    # Same effect as decorating the function with @spaces.GPU(duration=120).
    start_analysis = spaces.GPU(duration=120)(start_analysis)
|
|
| |
| def _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold): |
| global analyzer, output_path, model_file_path |
| |
| if not video: |
| return None, None, "请选择视频文件" |
| |
| if start_frame >= end_frame: |
| return None, None, "起始帧必须小于结束帧" |
| |
| |
| video_name = os.path.splitext(os.path.basename(video))[0] |
| output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4") |
| csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv") |
| |
| try: |
| |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' |
| print(f"使用设备: {device}") |
| |
| |
| if not os.path.exists(model_file_path): |
| |
| if is_spaces: |
| try: |
| print(f"尝试从Hugging Face Hub下载模型: {os.path.basename(model_file_path)}") |
| model_file_path = hf_hub_download( |
| repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME", |
| filename="weights/fst-v1.2-n.onnx" |
| ) |
| print(f"模型已下载到: {model_file_path}") |
| except Exception as e: |
| print(f"从Hub下载模型失败: {str(e)}") |
| else: |
| print(f"警告: 模型文件 {model_file_path} 不存在!") |
| |
| |
| analyzer = MouseTrackerAnalyzer( |
| model_path=model_file_path, |
| conf=conf, |
| iou=iou, |
| max_det=max_det, |
| verbose=True |
| ) |
| analyzer.struggle_threshold = threshold |
| |
| |
| def progress_update(progress, frame, results): |
| print(f"处理进度: {progress}%, 检测到对象数: {len(results)}") |
| |
| print(f"处理视频: {video}") |
| print(f"输出路径: {output_path}") |
| print(f"参数: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}") |
| print(f"使用模型: {model_file_path}") |
| |
| |
| results = analyzer.process_video( |
| video_path=video, |
| output_path=output_path, |
| start_frame=start_frame, |
| end_frame=end_frame, |
| callback=progress_update |
| ) |
| |
| |
| print(f"保存结果到CSV: {csv_path}") |
| analyzer.save_results(csv_path) |
| print(f"结果已保存到CSV,共 {len(analyzer.results)} 帧数据") |
| |
| |
| print("生成时间序列图...") |
| if len(analyzer.results) == 0: |
| print("警告: 没有可用于绘图的结果!") |
| plot_path = None |
| else: |
| plot_path = analyzer.generate_time_series_plot() |
| if plot_path and os.path.exists(plot_path): |
| print(f"图表已生成并保存到: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB") |
| else: |
| print(f"生成图表失败或图表文件不存在!") |
| plot_path = None |
| |
| |
| if os.path.exists(output_path): |
| file_size = os.path.getsize(output_path) / (1024 * 1024) |
| print(f"输出视频大小: {file_size:.2f}MB") |
| |
| |
| debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg") |
| if os.path.exists(debug_frame_path): |
| print(f"调试帧保存在: {debug_frame_path}") |
| |
| if plot_path and os.path.exists(plot_path): |
| print(f"图表文件存在于: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB") |
| |
| |
| status_message = "分析完成。" |
| |
| if os.path.exists(output_path): |
| status_message += f"视频已保存。" |
| else: |
| status_message += "警告: 未找到输出视频。" |
| |
| if plot_path and os.path.exists(plot_path): |
| status_message += f" 时间序列图已生成。" |
| else: |
| status_message += " 警告: 生成时间序列图失败。" |
| |
| status_message += f" 结果已保存到: {csv_path}" |
| |
| return output_path, plot_path, status_message |
| except Exception as e: |
| import traceback |
| traceback.print_exc() |
| return None, None, f"处理错误: {str(e)}" |
|
|
| |
def create_interface():
    """Build the Gradio Blocks UI for the forced-swim struggle analysis app.

    Layout: a left column with the video input, a read-only system-info box,
    detection parameters, frame-range sliders and the action buttons; a right
    column with a "预览" (preview) tab and a "结果" (results) tab. Events are
    wired to ``select_video``, ``preview_frame`` and ``start_analysis``.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="鼠强迫游泳挣扎度分析 - 对象跟踪") as app:
        gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")

        with gr.Row():
            with gr.Column(scale=1):
                # Input video; changing it triggers select_video (wired below).
                video_input = gr.Video(label="输入视频")

                # Read-only summary of the configured model file and compute device,
                # evaluated once when the interface is built.
                device_info = "GPU" if torch.cuda.is_available() else "CPU"
                model_info = gr.Textbox(
                    label="系统信息",
                    value=f"使用模型: {os.path.basename(model_file_path)} | 计算设备: {device_info}",
                    interactive=False
                )

                # Detection parameters forwarded to start_analysis.
                with gr.Row():
                    conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="置信度阈值")
                    iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU阈值")

                with gr.Row():
                    max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="最大检测数")
                    threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="挣扎阈值")

                # Placeholder range; select_video returns new sliders for these
                # outputs once a video has been loaded.
                start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="起始帧")
                end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="结束帧")

                # Previews the frame selected by the start-frame slider (see wiring below).
                preview_btn = gr.Button("预览帧")

                start_btn = gr.Button("开始分析", variant="primary")

            with gr.Column(scale=2):

                with gr.Tab("预览"):

                    preview_image = gr.Image(label="预览图像", type="numpy", height=400)
                    status_text = gr.Textbox(label="状态", interactive=False)
                    gr.Markdown("""
                    ### 使用说明:
                    1. 选择一个视频文件
                    2. 调整参数
                       - 置信度阈值: 对象检测的最低置信度,较低的值会检测更多潜在对象
                       - IoU阈值: 用于过滤重叠检测
                       - 最大检测数: 每帧检测的最大对象数
                       - 挣扎阈值: 分类为挣扎状态的最低分数
                    3. 设置帧范围
                    4. 点击"开始分析"按钮

                    系统将自动跟踪小鼠并分析其挣扎行为,无需手动定义区域
                    """)

                with gr.Tab("结果"):
                    with gr.Row():
                        output_video = gr.Video(label="分析结果视频")
                        result_plot = gr.Image(label="挣扎分数时间序列")

                    result_status = gr.Textbox(label="分析状态", interactive=False)

        # Loading a video refreshes the preview, status and both range sliders.
        video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])

        preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])

        # start_analysis returns (video path, plot path, status message).
        start_btn.click(
            start_analysis,
            inputs=[video_input, conf, iou, max_det, start_frame, end_frame, threshold],
            outputs=[output_video, result_plot, result_status]
        )

    return app
|
|
| |
if __name__ == "__main__":
    # Drop lowercase proxy settings before starting — presumably so local
    # Gradio/localhost traffic is not routed through a proxy.
    for proxy_var in ('http_proxy', 'https_proxy', 'all_proxy'):
        os.environ.pop(proxy_var, None)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"使用设备: {device}")

    if os.path.exists(model_file_path):
        print(f"使用模型: {model_file_path}")
    else:
        print(f"警告: 模型文件 {model_file_path} 不存在!")

    app = create_interface()

    # Spaces provides its own host/port; locally bind to all interfaces on 7860.
    launch_kwargs = {} if is_spaces else {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    app.launch(**launch_kwargs)