|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import torch |
|
|
from mouse_tracker import MouseTrackerAnalyzer |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
# Detect whether we are running inside Hugging Face Spaces: the `spaces`
# package is only importable there. `is_spaces` selects GPU decoration and
# the launch configuration later in the file.
try:
    import spaces
except ImportError:
    is_spaces = False
    print("在本地环境运行")
else:
    is_spaces = True
    print("检测到 Hugging Face Spaces 环境")
|
|
|
|
|
|
|
|
# Base file name of the detector weights; get_model_file_path appends the
# format suffix (e.g. ".onnx") to it.
model_base_name: str = "fst-v1.3-n"

# Frame count of the most recently loaded video; rewritten by select_video.
total_frames: int = 0
|
|
|
|
|
|
|
|
def get_model_file_path(model_suffix):
    """Return the relative path of the model weights for a given format suffix.

    The weights are looked up in the current working directory under the name
    ``model_base_name`` immediately followed by ``model_suffix``.

    Args:
        model_suffix: Format suffix including the dot, e.g. ``".onnx"``.

    Returns:
        A path string of the form ``"./<model_base_name><model_suffix>"``.
    """
    file_name = "{}{}".format(model_base_name, model_suffix)
    return "./" + file_name
|
|
|
|
|
|
|
|
def extract_frame(video_path, frame_num):
    """Read a single frame from a video file and return it as an RGB array.

    Args:
        video_path: Path of the video file. Falsy values return None immediately.
        frame_num: Index of the frame to seek to. Gradio sliders may deliver a
            float, so it is truncated to an int before seeking.

    Returns:
        The frame converted from OpenCV's BGR order to RGB, or None when the
        path is empty, the video cannot be opened, or the frame cannot be read.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(frame_num))
        ret, frame = cap.read()
    finally:
        # Release the capture on every path, including early return and errors.
        cap.release()

    if not ret or frame is None:
        return None

    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
|
|
|
|
|
|
|
def select_video(video_file, model_suffix):
    """Load a newly selected video: show its first frame and rescale the frame sliders.

    Side effect: stores the video's reported frame count in the module-level
    ``total_frames``.

    Args:
        video_file: Path of the selected video (falsy when the input is cleared).
        model_suffix: Model format suffix; only used to name the model in the
            status message.

    Returns:
        ``(first_frame_rgb, status, start_slider, end_slider)``. On failure the
        image is None and both sliders are collapsed to the range 0..0.
    """
    global total_frames

    if not video_file:
        return None, "请选择视频文件", gr.Slider(0,0,0), gr.Slider(0,0,0)

    # A single capture provides both the frame count and the first frame, and
    # it is released even if reading raises.
    cap = cv2.VideoCapture(video_file)
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        ret, frame = cap.read()
    finally:
        cap.release()

    if not ret:
        return None, "无法读取视频帧", gr.Slider(0,0,0), gr.Slider(0,0,0)

    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

    # CAP_PROP_FRAME_COUNT comes from container metadata and can be 0 or
    # unreliable; clamp so the slider range is never inverted (max < min).
    last_frame = max(total_frames - 1, 0)
    start = gr.Slider(minimum=0, maximum=last_frame, value=0, step=1)
    end = gr.Slider(minimum=0, maximum=last_frame, value=last_frame, step=1)

    status = f"视频加载成功,总帧数: {total_frames}. 使用模型: {os.path.basename(get_model_file_path(model_suffix))}"

    return frame_rgb, status, start, end
|
|
|
|
|
|
|
|
def preview_frame(video_file, frame_num):
    """Fetch one frame of the selected video for the preview pane.

    Args:
        video_file: Path of the selected video; falsy when nothing is selected.
        frame_num: Frame index to display (taken from the start-frame slider).

    Returns:
        ``(image, status)`` where ``image`` is an RGB array or None and
        ``status`` is a short message for the status box.
    """
    if video_file:
        image = extract_frame(video_file, frame_num)
        if image is not None:
            return image, f"帧 {frame_num}"
        return None, "无法读取指定帧"
    return None, "请先选择视频文件"
|
|
|
|
|
|
|
|
def _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold): |
|
|
if not video: |
|
|
return None, None, "请选择视频文件" |
|
|
if start_frame >= end_frame: |
|
|
return None, None, "起始帧必须小于结束帧" |
|
|
|
|
|
video_name = os.path.splitext(os.path.basename(video))[0] |
|
|
output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4") |
|
|
csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv") |
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
model_path = get_model_file_path(model_suffix) |
|
|
if not os.path.exists(model_path): |
|
|
if is_spaces: |
|
|
try: |
|
|
model_path = hf_hub_download( |
|
|
repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME", |
|
|
filename=f"weights/{model_base_name}{model_suffix}" |
|
|
) |
|
|
except Exception: |
|
|
print(f"下载模型失败: {model_path}") |
|
|
else: |
|
|
print(f"警告: 本地未找到模型文件 {model_path}") |
|
|
|
|
|
analyzer = MouseTrackerAnalyzer( |
|
|
model_path=model_path, |
|
|
conf=conf, |
|
|
iou=iou, |
|
|
max_det=max_det, |
|
|
verbose=True |
|
|
) |
|
|
analyzer.struggle_threshold = threshold |
|
|
|
|
|
analyzer.process_video( |
|
|
video_path=video, |
|
|
output_path=output_path, |
|
|
start_frame=start_frame, |
|
|
end_frame=end_frame, |
|
|
callback=lambda prog, frm, res: print(f"进度: {prog}% 检测: {len(res)} 项") |
|
|
) |
|
|
analyzer.save_results(csv_path) |
|
|
|
|
|
plot_path = None |
|
|
if analyzer.results: |
|
|
plot_path = analyzer.generate_time_series_plot() |
|
|
status = f"分析完成。视频: {output_path}, CSV: {csv_path}" |
|
|
if plot_path: |
|
|
status += f", 图表: {plot_path}" |
|
|
return output_path, plot_path, status |
|
|
|
|
|
|
|
|
def start_analysis(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold):
    """Gradio handler for the "start analysis" button.

    Thin wrapper that forwards every argument to ``_start_analysis_impl`` and
    returns its ``(output_video_path, plot_path, status)`` tuple unchanged.
    """
    return _start_analysis_impl(video, model_suffix, conf, iou, max_det, start_frame, end_frame, threshold)


if is_spaces:
    # On Hugging Face Spaces, wrap the handler with spaces.GPU so calls get a
    # GPU allocation (duration=120 -- presumably a per-call budget in seconds).
    start_analysis = spaces.GPU(duration=120)(start_analysis)
|
|
|
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the forced-swim struggle analysis app.

    Layout: a left column with the video upload, model-format dropdown,
    device info, detection parameters, struggle threshold, start/end frame
    sliders and the preview/analyze buttons; a right column with a preview tab
    (image + status) and a results tab (annotated video, plot, status).

    Event wiring:
      * changing the video calls ``select_video``, which fills the preview
        image and status and replaces the start/end frame sliders;
      * the preview button shows the frame selected by the start-frame slider;
      * the analyze button runs ``start_analysis`` and fills the results tab.

    Returns:
        The constructed (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="鼠强迫游泳挣扎度分析") as app:
        gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
        with gr.Row():
            # Left column: inputs and controls. Component creation order
            # defines the on-screen order.
            with gr.Column(scale=1):
                video_input = gr.Video(label="输入视频")
                # Selected value is passed as `model_suffix` and appended to the
                # model base name by get_model_file_path.
                model_format = gr.Dropdown(
                    label="模型格式",
                    choices=[".onnx", ".engine", ".pt", ".mlpackage"],
                    value=".onnx",
                    interactive=True
                )
                # Read-only display of whether CUDA was available when the UI was built.
                device_info = gr.Textbox(
                    label="系统信息",
                    value=f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}",
                    interactive=False
                )
                conf = gr.Slider(0.1, 0.9, value=0.25, step=0.05, label="置信度阈值")
                iou = gr.Slider(0.1, 0.9, value=0.45, step=0.05, label="IoU阈值")
                max_det = gr.Slider(1, 50, value=20, step=1, label="最大检测数")
                threshold = gr.Slider(0, 1, value=0.3, step=0.01, label="挣扎阈值")
                # Placeholder ranges; select_video returns replacement sliders
                # sized to the loaded video.
                start_frame = gr.Slider(0, 999999, value=0, step=1, label="起始帧")
                end_frame = gr.Slider(0, 999999, value=999999, step=1, label="结束帧")
                preview_btn = gr.Button("预览帧")
                start_btn = gr.Button("开始分析", variant="primary")
            # Right column: preview and results tabs.
            with gr.Column(scale=2):
                with gr.Tab("预览"):
                    preview_image = gr.Image(label="预览图像", type="numpy", height=400)
                    status_text = gr.Textbox(label="状态", interactive=False)
                with gr.Tab("结果"):
                    output_video = gr.Video(label="分析结果视频")
                    result_plot = gr.Image(label="挣扎分数时间序列")
                    result_status = gr.Textbox(label="分析状态", interactive=False)

        # Event listeners must be registered inside the Blocks context.
        video_input.change(select_video, inputs=[video_input, model_format], outputs=[preview_image, status_text, start_frame, end_frame])
        preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
        start_btn.click(
            start_analysis,
            inputs=[video_input, model_format, conf, iou, max_det, start_frame, end_frame, threshold],
            outputs=[output_video, result_plot, result_status]
        )

    return app
|
|
|
|
|
if __name__ == "__main__":
    # Clear proxy settings, presumably so Gradio's local HTTP requests are not
    # routed through a proxy. Both spellings are removed because
    # urllib.request.getproxies (used by requests/httpx) also honours the
    # upper-case variables.
    for key in ['http_proxy', 'https_proxy', 'all_proxy',
                'HTTP_PROXY', 'HTTPS_PROXY', 'ALL_PROXY']:
        os.environ.pop(key, None)

    print(f"设备: {'GPU' if torch.cuda.is_available() else 'CPU'}")
    print(f"默认模型路径: {get_model_file_path('.onnx')}")

    app = create_interface()

    if is_spaces:
        # Spaces supplies host/port itself.
        app.launch()
    else:
        app.launch(server_name="0.0.0.0", server_port=7860, share=False)
|
|
|