File size: 7,483 Bytes
a0c5b0b d739c25 b9cb794 a0c5b0b b9cb794 18e189f b9cb794 689811f 18e189f 9b94ed9 a0c5b0b 689811f 18e189f 9ee846a a0c5b0b d739c25 a0c5b0b d739c25 a0c5b0b 9ee846a d739c25 a0c5b0b 18e189f a0c5b0b 18e189f a0c5b0b 9ee846a 18e189f 9ee846a 18e189f 7ae1738 a0c5b0b b748ffa a0c5b0b b748ffa 7ae1738 18e189f b748ffa a0c5b0b b748ffa 18e189f a0c5b0b 18e189f 7ae1738 18e189f 9ee846a 18e189f 9ee846a 18e189f 9ee846a a0c5b0b 18e189f b748ffa a0c5b0b 18e189f 9ee846a d09e840 9ee846a 18e189f b9cb794 18e189f 9ee846a b748ffa a0c5b0b b748ffa 18e189f b748ffa 18e189f a0c5b0b 9ee846a 18e189f a0c5b0b d739c25 18e189f a0c5b0b b9cb794 9ee846a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 |
import os
import cv2
import numpy as np
import gradio as gr
import torch
from mouse_tracker import MouseTrackerAnalyzer
from huggingface_hub import hf_hub_download
# Detect the Hugging Face Spaces runtime by probing for the `spaces` package
# (presumably only installed on Spaces); the result selects launch/GPU behaviour.
try:
    import spaces
except ImportError:
    is_spaces = False
    print("在本地环境运行")
else:
    is_spaces = True
    print("检测到 Hugging Face Spaces 环境")

# Global configuration
total_frames = 0  # frame count of the most recently loaded video (rewritten by select_video)
model_base_name = "fst-v1.3-n"  # model file name without its extension/suffix
# Build the model file path from the selected suffix.
def get_model_file_path(model_suffix):
    """Return the relative path of the model weights for *model_suffix*.

    The path is ``./`` followed by the module-level ``model_base_name`` and the
    suffix, e.g. ``get_model_file_path(".onnx")`` -> ``"./fst-v1.3-n.onnx"``.
    """
    template = "./{}{}"
    return template.format(model_base_name, model_suffix)
# Extract a specific frame from a video.
def extract_frame(video_path, frame_num):
    """Read one frame from *video_path* and return it in RGB channel order.

    Args:
        video_path: Path to the video file. A falsy value returns None immediately.
        frame_num: Zero-based index of the frame to read. It may arrive as a
            float (e.g. from a UI slider), so it is coerced to a non-negative int.

    Returns:
        The frame converted from OpenCV's BGR order to RGB via ``cv2.cvtColor``.
        Returns None if the path is empty, the file cannot be opened, or the
        frame cannot be read.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, max(int(frame_num), 0))
        ret, frame = cap.read()
    finally:
        # Release the capture on every path: early return, failed read, or exception.
        cap.release()
    if not ret or frame is None:
        return None
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Select (load) a video file.
def select_video(video_file, model_suffix):
    """Load *video_file*, return its first frame, and reset the frame-range sliders.

    Side effect: sets the module-level ``total_frames`` to the frame count
    reported by OpenCV (``CAP_PROP_FRAME_COUNT``).

    Args:
        video_file: Path of the uploaded video; falsy when the input was cleared.
        model_suffix: Currently selected model extension (e.g. ".onnx"). It is
            only used to show the model file name in the status message.

    Returns:
        ``(first_frame_rgb, status_message, start_slider, end_slider)``. On
        failure the frame is None and both sliders are collapsed to 0.
    """
    global total_frames
    if not video_file:
        return None, "请选择视频文件", gr.Slider(0,0,0), gr.Slider(0,0,0)
    # Use a single capture for both the frame count and the first frame, and
    # always release it.
    cap = cv2.VideoCapture(video_file)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        ret, frame = cap.read()
    finally:
        cap.release()
    if not ret or frame is None:
        return None, "无法读取视频帧", gr.Slider(0,0,0), gr.Slider(0,0,0)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    # The reported frame count may be 0 (or negative) for some containers;
    # keep the slider range non-negative.
    last_frame = max(total_frames - 1, 0)
    start = gr.Slider(minimum=0, maximum=last_frame, value=0, step=1)
    end = gr.Slider(minimum=0, maximum=last_frame, value=last_frame, step=1)
    status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {os.path.basename(get_model_file_path(model_suffix))}"
    return frame_rgb, status, start, end
# Preview a single frame of the selected video.
def preview_frame(video_file, frame_num):
    """Return ``(image, status)`` for frame *frame_num* of *video_file*.

    The image is None, with an explanatory status, when no video is selected
    or the frame cannot be read.
    """
    if video_file:
        image = extract_frame(video_file, frame_num)
        if image is not None:
            return image, f"帧 {frame_num}"
        return None, "无法读取指定帧"
    return None, "请先选择视频文件"
# Analysis implementation.
def _resolve_model_path(model_suffix):
    """Locate the model weights for *model_suffix*, downloading them on Spaces if missing.

    Args:
        model_suffix: Model file extension selected in the UI (e.g. ".onnx").

    Returns:
        ``(model_path, error)``, where ``error`` is None on success.
        - On Spaces, if the Hub download fails, ``error`` is a user-facing
          message and ``model_path`` is the missing local path.
        - Locally, if the file is missing, a warning is printed and the local
          path is still returned so the analyzer can attempt to load it.
    """
    model_path = get_model_file_path(model_suffix)
    if os.path.exists(model_path):
        return model_path, None
    if not is_spaces:
        print(f"警告: 本地未找到模型文件 {model_path}")
        return model_path, None
    try:
        # TODO: repo_id is a placeholder; point it at the repository hosting the weights.
        downloaded_path = hf_hub_download(
            repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME",
            filename=f"weights/{model_base_name}{model_suffix}"
        )
    except Exception as err:
        # Surface the failure instead of continuing with a path that does not exist.
        print(f"下载模型失败: {model_path} ({err})")
        return model_path, f"下载模型失败: {model_path}"
    return downloaded_path, None


def _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
    """Run the mouse-tracking analysis over a frame range of *video*.

    Writes ``<name>_out.mp4`` and ``<name>_results.csv`` next to the input video.
    When the analyzer produced results, it also generates a time-series plot.

    Args:
        video: Path of the input video.
        model_suffix: Model file extension (e.g. ".onnx").
        conf: Confidence threshold passed to ``MouseTrackerAnalyzer``.
        iou: IoU threshold passed to ``MouseTrackerAnalyzer``.
        max_det: Maximum detections passed to ``MouseTrackerAnalyzer``
            (coerced to int).
        start_frame: First frame index of the range passed to ``process_video``
            (coerced to int).
        end_frame: Last frame index passed to ``process_video`` (coerced to int).
            Whether the end is inclusive is decided by ``process_video`` —
            TODO confirm.
        threshold: Value assigned to ``analyzer.struggle_threshold``.

    Returns:
        ``(output_video_path, plot_path, status_message)``. On validation or
        model-resolution errors both paths are None and the status explains why.
        Exceptions raised by the analyzer propagate to the caller.
    """
    if not video:
        return None, None, "请选择视频文件"
    # Frame indices are integers, but UI sliders may deliver floats.
    start_frame = int(start_frame)
    end_frame = int(end_frame)
    if start_frame >= end_frame:
        return None, None, "起始帧必须小于结束帧"

    # Outputs are written next to the input video.
    video_dir = os.path.dirname(video)
    video_name = os.path.splitext(os.path.basename(video))[0]
    output_path = os.path.join(video_dir, f"{video_name}_out.mp4")
    csv_path = os.path.join(video_dir, f"{video_name}_results.csv")

    model_path, model_error = _resolve_model_path(model_suffix)
    if model_error:
        return None, None, model_error

    analyzer = MouseTrackerAnalyzer(
        model_path=model_path,
        conf=conf,
        iou=iou,
        max_det=int(max_det),
        verbose=True
    )
    analyzer.struggle_threshold = threshold
    analyzer.process_video(
        video_path=video,
        output_path=output_path,
        start_frame=start_frame,
        end_frame=end_frame,
        # Progress is only logged to stdout.
        callback=lambda prog, frm, res: print(f"进度: {prog}% 检测: {len(res)} 项")
    )
    analyzer.save_results(csv_path)

    # Only plot when the analyzer produced results.
    plot_path = analyzer.generate_time_series_plot() if analyzer.results else None
    status = f"分析完成。视频: {output_path}, CSV: {csv_path}"
    if plot_path:
        status += f", 图表: {plot_path}"
    return output_path, plot_path, status
# UI entry point for the analysis. On Hugging Face Spaces the same function is
# wrapped with spaces.GPU(duration=120) to request GPU time for each call.
def start_analysis(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
    """Forward all arguments unchanged to ``_start_analysis_impl`` and return its result."""
    return _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold)


if is_spaces:
    start_analysis = spaces.GPU(duration=120)(start_analysis)
# Build the Gradio interface.
def create_interface():
    """Build and return the Gradio Blocks app for the forced-swim struggle analysis.

    Layout:
        - Left column: video input, model-format dropdown, a read-only device
          box, detection and struggle-threshold sliders, frame-range sliders,
          and the preview/analyse buttons.
        - Right column: a "预览" (preview) tab and a "结果" (results) tab.

    Returns:
        The ``gr.Blocks`` app. It is not launched here.
    """
    with gr.Blocks(title="鼠强迫游泳挣扎度分析") as app:
        gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
        with gr.Row():
            with gr.Column(scale=1):
                video_input = gr.Video(label="输入视频")
                model_format = gr.Dropdown(
                    label="模型格式",
                    choices=[".onnx", ".engine", ".pt", ".mlpackage"],
                    value=".onnx",
                    interactive=True
                )
                # Read-only display of whether CUDA is available at build time.
                device_info = gr.Textbox(
                    label="系统信息",
                    value=f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}",
                    interactive=False
                )
                conf = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="置信度阈值")
                iou = gr.Slider(0.1, 0.9, value=0.45, step=0.05, label="IoU阈值")
                max_det = gr.Slider(1, 50, value=20, step=1, label="最大检测数")
                threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="挣扎阈值")
                # Placeholder range; select_video's returned sliders replace it
                # once a video is loaded (see the video_input.change binding below).
                start_frame = gr.Slider(0, 999999, value=0, step=1, label="起始帧")
                end_frame = gr.Slider(0, 999999, value=999999, step=1, label="结束帧")
                preview_btn = gr.Button("预览帧")
                start_btn = gr.Button("开始分析", variant="primary")
            with gr.Column(scale=2):
                with gr.Tab("预览"):
                    preview_image = gr.Image(label="预览图像", type="numpy", height=400)
                    status_text = gr.Textbox(label="状态", interactive=False)
                with gr.Tab("结果"):
                    output_video = gr.Video(label="分析结果视频")
                    result_plot = gr.Image(label="挣扎分数时间序列")
                    result_status = gr.Textbox(label="分析状态", interactive=False)
        # Event wiring; the model-format dropdown is passed to both select_video and start_analysis.
        video_input.change(select_video, inputs=[video_input, model_format], outputs=[preview_image, status_text, start_frame, end_frame])
        # The preview button shows the frame currently selected on the start-frame slider.
        preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
        start_btn.click(
            start_analysis,
            inputs=[video_input, model_format, conf, iou, max_det, start_frame, end_frame, threshold],
            outputs=[output_video, result_plot, result_status]
        )
    return app
if __name__ == "__main__":
    # Clear proxy settings before launching, presumably so local HTTP requests
    # (e.g. Gradio's startup check) are not routed through a proxy. Both
    # spellings are cleared because urllib-based clients honour lower- and
    # upper-case proxy variables.
    for key in ('http_proxy', 'https_proxy', 'all_proxy',
                'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY'):
        os.environ.pop(key, None)
    device_label = 'GPU' if torch.cuda.is_available() else 'CPU'
    print(f"设备: {device_label}")
    print(f"默认模型路径: {get_model_file_path('.onnx')}")
    app = create_interface()
    if is_spaces:
        app.launch()
    else:
        # Local run: listen on all interfaces, port 7860, without a public share link.
        app.launch(server_name="0.0.0.0", server_port=7860, share=False)
|