File size: 7,483 Bytes
a0c5b0b
d739c25
 
 
b9cb794
a0c5b0b
b9cb794
 
 
 
 
 
18e189f
b9cb794
 
 
689811f
18e189f
9b94ed9
a0c5b0b
689811f
18e189f
 
9ee846a
 
a0c5b0b
 
 
 
d739c25
a0c5b0b
 
 
 
d739c25
a0c5b0b
 
9ee846a
d739c25
a0c5b0b
18e189f
 
a0c5b0b
18e189f
 
 
 
 
a0c5b0b
 
9ee846a
18e189f
 
9ee846a
 
18e189f
 
7ae1738
a0c5b0b
 
 
b748ffa
a0c5b0b
 
b748ffa
 
7ae1738
18e189f
 
b748ffa
 
a0c5b0b
b748ffa
18e189f
a0c5b0b
 
 
18e189f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ae1738
18e189f
9ee846a
 
18e189f
 
9ee846a
18e189f
 
9ee846a
 
a0c5b0b
18e189f
b748ffa
a0c5b0b
 
18e189f
9ee846a
 
 
d09e840
9ee846a
 
 
 
18e189f
b9cb794
 
18e189f
 
 
 
9ee846a
 
b748ffa
 
a0c5b0b
b748ffa
 
 
 
18e189f
 
b748ffa
18e189f
 
a0c5b0b
 
9ee846a
18e189f
a0c5b0b
 
 
d739c25
 
18e189f
 
 
 
 
a0c5b0b
b9cb794
 
 
9ee846a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import cv2
import numpy as np
import gradio as gr
import torch
from mouse_tracker import MouseTrackerAnalyzer
from huggingface_hub import hf_hub_download

# Detect whether we are running inside a Hugging Face Spaces environment.
# The ability to import the `spaces` package is used as the signal.
try:
    import spaces
except ImportError:
    is_spaces = False
    print("在本地环境运行")
else:
    is_spaces = True
    print("检测到 Hugging Face Spaces 环境")

# Global configuration
model_base_name = "fst-v1.3-n"  # base model name, without the file suffix
total_frames = 0  # frame count of the most recently selected video (written by select_video)

# Build the model file path from a suffix
def get_model_file_path(model_suffix):
    """Return the relative path of the model file for ``model_suffix``.

    The result is ``./`` followed by the module-level ``model_base_name``
    and the given suffix (for example ``".onnx"``).
    """
    file_name = "{}{}".format(model_base_name, model_suffix)
    return "./" + file_name

# Extract a specific frame from a video
def extract_frame(video_path, frame_num):
    """Return one frame of a video as an RGB image.

    Args:
        video_path: Path of the video file. A falsy value yields ``None``.
        frame_num: Frame position, written to ``cv2.CAP_PROP_POS_FRAMES``
            before reading.

    Returns:
        The frame converted from BGR to RGB, or ``None`` when no path was
        given, the video could not be opened, or the frame could not be read.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including the "not opened"
        # early return and any exception raised while seeking/reading.
        cap.release()
    if not ret:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

# Handle selection of a video file
def select_video(video_file, model_suffix):
    """Load a newly selected video and prepare the preview and frame sliders.

    Stores the video's reported frame count in the module-level
    ``total_frames``.

    Args:
        video_file: Path of the selected video; falsy when nothing is selected.
        model_suffix: Model file suffix; only used to name the model in the
            status message.

    Returns:
        A 4-tuple ``(first_frame_rgb, status, start_slider, end_slider)``.
        When no video is given or the first frame cannot be read, the frame
        is ``None`` and both sliders are reset to the range 0..0.
    """
    global total_frames
    if not video_file:
        return None, "请选择视频文件", gr.Slider(0,0,0), gr.Slider(0,0,0)
    # Open the video once for both the frame count and the first frame, and
    # always release the handle afterwards.
    cap = cv2.VideoCapture(video_file)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret:
        return None, "无法读取视频帧", gr.Slider(0,0,0), gr.Slider(0,0,0)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # Update the sliders. Guard against a non-positive reported frame count so
    # that the slider maximum never drops below its minimum of 0.
    last_frame = max(total_frames - 1, 0)
    start = gr.Slider(minimum=0, maximum=last_frame, value=0, step=1)
    end = gr.Slider(minimum=0, maximum=last_frame, value=last_frame, step=1)
    status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {os.path.basename(get_model_file_path(model_suffix))}"
    return frame_rgb, status, start, end

# Preview a frame
def preview_frame(video_file, frame_num):
    """Return ``(image, status)`` for previewing frame ``frame_num``.

    The image is ``None`` when no video has been selected or when
    ``extract_frame`` could not produce the requested frame; the status
    string explains which of the two happened.
    """
    if video_file:
        image = extract_frame(video_file, frame_num)
        if image is not None:
            return image, f"帧 {frame_num}"
        return None, "无法读取指定帧"
    return None, "请先选择视频文件"

# 分析实现
def _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
    if not video:
        return None, None, "请选择视频文件"
    if start_frame >= end_frame:
        return None, None, "起始帧必须小于结束帧"
    # 构造路径
    video_name = os.path.splitext(os.path.basename(video))[0]
    output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
    csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model_path = get_model_file_path(model_suffix)
    if not os.path.exists(model_path):
        if is_spaces:
            try:
                model_path = hf_hub_download(
                    repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME",
                    filename=f"weights/{model_base_name}{model_suffix}"
                )
            except Exception:
                print(f"下载模型失败: {model_path}")
        else:
            print(f"警告: 本地未找到模型文件 {model_path}")
    # 初始化分析器
    analyzer = MouseTrackerAnalyzer(
        model_path=model_path,
        conf=conf,
        iou=iou,
        max_det=max_det,
        verbose=True
    )
    analyzer.struggle_threshold = threshold
    # 运行分析
    analyzer.process_video(
        video_path=video,
        output_path=output_path,
        start_frame=start_frame,
        end_frame=end_frame,
        callback=lambda prog, frm, res: print(f"进度: {prog}% 检测: {len(res)} 项")
    )
    analyzer.save_results(csv_path)
    # 生成图表
    plot_path = None
    if analyzer.results:
        plot_path = analyzer.generate_time_series_plot()
    status = f"分析完成。视频: {output_path}, CSV: {csv_path}"
    if plot_path:
        status += f", 图表: {plot_path}"
    return output_path, plot_path, status

# Analysis entry point used by the UI; GPU-decorated on HF Spaces
def start_analysis(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
    """Forward all arguments unchanged to ``_start_analysis_impl``."""
    return _start_analysis_impl(
        video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold
    )


if is_spaces:
    # Same effect as decorating the definition with @spaces.GPU(duration=120).
    start_analysis = spaces.GPU(duration=120)(start_analysis)

# Build the Gradio interface
def create_interface():
    """Construct the Gradio Blocks application and return it.

    The left column holds the inputs (video, model format, detection and
    frame-range sliders, and the two buttons); the right column holds a
    preview tab and a results tab. Event handlers are wired to
    ``select_video``, ``preview_frame`` and ``start_analysis``.
    """
    with gr.Blocks(title="鼠强迫游泳挣扎度分析") as app:
        gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
        with gr.Row():
            # Left column: inputs and controls
            with gr.Column(scale=1):
                video_in = gr.Video(label="输入视频")
                format_dd = gr.Dropdown(
                    label="模型格式",
                    choices=[".onnx", ".engine", ".pt", ".mlpackage"],
                    value=".onnx",
                    interactive=True,
                )
                # Read-only info box; no event handler refers to it.
                device_label = 'GPU' if torch.cuda.is_available() else 'CPU'
                gr.Textbox(
                    label="系统信息",
                    value=f"设备: {device_label}",
                    interactive=False,
                )
                conf_sl = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="置信度阈值")
                iou_sl = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU阈值")
                max_det_sl = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="最大检测数")
                threshold_sl = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="挣扎阈值")
                start_sl = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="起始帧")
                end_sl = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="结束帧")
                preview_button = gr.Button("预览帧")
                run_button = gr.Button("开始分析", variant="primary")
            # Right column: preview and results
            with gr.Column(scale=2):
                with gr.Tab("预览"):
                    preview_img = gr.Image(label="预览图像", type="numpy", height=400)
                    preview_status = gr.Textbox(label="状态", interactive=False)
                with gr.Tab("结果"):
                    out_video = gr.Video(label="分析结果视频")
                    out_plot = gr.Image(label="挣扎分数时间序列")
                    out_status = gr.Textbox(label="分析状态", interactive=False)
        # Event wiring; the model format is passed along with the video
        video_in.change(
            select_video,
            inputs=[video_in, format_dd],
            outputs=[preview_img, preview_status, start_sl, end_sl],
        )
        preview_button.click(
            preview_frame,
            inputs=[video_in, start_sl],
            outputs=[preview_img, preview_status],
        )
        analysis_inputs = [
            video_in, format_dd, conf_sl, iou_sl, max_det_sl,
            start_sl, end_sl, threshold_sl,
        ]
        run_button.click(
            start_analysis,
            inputs=analysis_inputs,
            outputs=[out_video, out_plot, out_status],
        )
    return app

if __name__ == "__main__":
    # Remove proxy settings from the environment before launching
    for proxy_var in ('http_proxy', 'https_proxy', 'all_proxy'):
        os.environ.pop(proxy_var, None)
    device_label = 'GPU' if torch.cuda.is_available() else 'CPU'
    print(f"设备: {device_label}")
    default_model = get_model_file_path('.onnx')
    print(f"默认模型路径: {default_model}")
    app = create_interface()
    # On Spaces use the default launch settings; locally bind to all
    # interfaces on port 7860 without a share link.
    launch_kwargs = {} if is_spaces else {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    app.launch(**launch_kwargs)