# FST / app.py — Hugging Face Space by Hakureirm (commit bee9d8f, verified; file size 12.7 kB).
# (Header condensed from HF web-page chrome: "raw / history blame" links removed.)
import os
import cv2
import numpy as np
import gradio as gr
import tempfile
import torch
from mouse_tracker import MouseTrackerAnalyzer
import huggingface_hub
from huggingface_hub import hf_hub_download
# Detect whether we are running inside a Hugging Face Spaces environment.
# The `spaces` package is only installed on Spaces, so a failed import
# means we are running locally; `is_spaces` gates GPU decoration and the
# launch configuration below.
try:
    import spaces
    is_spaces = True
    print("检测到Hugging Face Spaces环境")
except ImportError:
    is_spaces = False
    print("在本地环境运行")
# Module-level state shared across the Gradio callbacks.
analyzer = None          # MouseTrackerAnalyzer instance, (re)created per analysis run
video_file_path = None   # path of the currently selected input video
model_file_path = "./fst-v1.2-n.engine"  # hard-coded model file path
total_frames = 0         # frame count of the selected video
output_path = None       # path of the annotated output video
# Extract a specific frame from a video.
def extract_frame(video_path, frame_num):
    """Return frame ``frame_num`` of the video at ``video_path`` as an RGB
    image, or ``None`` on any failure.

    Args:
        video_path: Path to a video file readable by OpenCV. A falsy value
            returns ``None`` immediately.
        frame_num: 0-based index of the frame to seek to.

    Returns:
        ``numpy.ndarray`` in RGB channel order, or ``None`` if the path is
        empty, the video cannot be opened, or the frame cannot be read.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        # Seek directly to the requested frame before reading.
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        ret, frame = cap.read()
    finally:
        # Always release the capture handle — the original leaked it when
        # isOpened() was false, and would also leak if read() raised.
        cap.release()
    if not ret:
        return None
    # OpenCV decodes to BGR; Gradio expects RGB.
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Handle selection of a new input video in the UI.
def select_video(video_file):
    """Load the chosen video into module state and prime the UI.

    Returns a 4-tuple of (first frame as RGB, status message, start-frame
    slider, end-frame slider). On failure the image is ``None`` and both
    sliders are zeroed.
    """
    global video_file_path, total_frames

    def _failure(message):
        # Uniform failure result: no preview image and zeroed frame sliders.
        return (
            None,
            message,
            gr.Slider(minimum=0, maximum=0, value=0),
            gr.Slider(minimum=0, maximum=0, value=0),
        )

    if not video_file:
        return _failure("请选择视频文件")
    video_file_path = video_file
    # Probe the video for its frame count and grab the first frame.
    capture = cv2.VideoCapture(video_file_path)
    if not capture.isOpened():
        # NOTE(review): the capture handle is not released on this path,
        # matching the original behavior.
        return _failure("无法打开视频文件")
    total_frames = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    ok, frame_bgr = capture.read()
    capture.release()
    if not ok:
        return _failure("无法读取视频帧")
    # Convert BGR -> RGB for display.
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    # Re-range the frame sliders to the video's length.
    last = total_frames - 1
    start_slider = gr.Slider(minimum=0, maximum=last, value=0, step=1)
    end_slider = gr.Slider(minimum=0, maximum=last, value=last, step=1)
    model_status = f"使用模型: {os.path.basename(model_file_path)}"
    status = f"视频加载成功,总帧数: {total_frames}. {model_status}"
    return frame_rgb, status, start_slider, end_slider
# Preview a single frame from the selected video.
def preview_frame(video_file, frame_num):
    """Return (frame image, status text) for the requested frame, or
    (None, error text) when no video is selected or the frame cannot
    be decoded."""
    if not video_file:
        return None, "请先选择视频文件"
    image = extract_frame(video_file, frame_num)
    return (None, "无法读取指定帧") if image is None else (image, f"帧 {frame_num}")
# Start-analysis entry point.
# On Hugging Face Spaces the callback is wrapped with the @spaces.GPU
# decorator so the Space is granted GPU time for the call; locally the
# undecorated variant is used. Both delegate to _start_analysis_impl.
if is_spaces:
    @spaces.GPU(duration=120)  # request GPU resources for up to 120 seconds
    def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
        return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
else:
    def start_analysis(video, conf, iou, max_det, start_frame, end_frame, threshold):
        return _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold)
# Actual analysis implementation (shared by both start_analysis variants).
def _start_analysis_impl(video, conf, iou, max_det, start_frame, end_frame, threshold):
    """Run the mouse-tracking analysis over a frame range of ``video``.

    Writes an annotated video and a CSV of per-frame results next to the
    input file, and generates a time-series plot of struggle scores.

    Args:
        video: path to the input video file.
        conf: detection confidence threshold passed to the analyzer.
        iou: IoU threshold for filtering overlapping detections.
        max_det: maximum detections per frame.
        start_frame: first frame to analyze (must be < end_frame).
        end_frame: last frame to analyze.
        threshold: struggle-classification score threshold.

    Returns:
        (output video path or None, plot path or None, status message).
        On any exception, returns (None, None, error message) — the
        traceback is printed rather than re-raised.
    """
    global analyzer, output_path, model_file_path
    if not video:
        return None, None, "请选择视频文件"
    if start_frame >= end_frame:
        return None, None, "起始帧必须小于结束帧"
    # Derive the output video and CSV paths from the input location.
    video_name = os.path.splitext(os.path.basename(video))[0]
    output_path = os.path.join(os.path.dirname(video), f"{video_name}_out.mp4")
    csv_path = os.path.join(os.path.dirname(video), f"{video_name}_results.csv")
    try:
        # Report the compute device in use.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f"使用设备: {device}")
        # Make sure the model file exists; on Spaces, fall back to
        # downloading it from the Hub.
        if not os.path.exists(model_file_path):
            if is_spaces:
                try:
                    print(f"尝试从Hugging Face Hub下载模型: {os.path.basename(model_file_path)}")
                    model_file_path = hf_hub_download(
                        repo_id="YOUR_HF_USERNAME/YOUR_REPO_NAME",  # placeholder — TODO: replace with the real repo id
                        filename="weights/fst-v1.2-n.onnx"
                    )
                    print(f"模型已下载到: {model_file_path}")
                except Exception as e:
                    # Best-effort download; the analyzer below will surface
                    # the failure if the model is still missing.
                    print(f"从Hub下载模型失败: {str(e)}")
            else:
                print(f"警告: 模型文件 {model_file_path} 不存在!")
        # Create the analyzer for this run (stored in the module global).
        analyzer = MouseTrackerAnalyzer(
            model_path=model_file_path,
            conf=conf,
            iou=iou,
            max_det=max_det,
            verbose=True  # enable verbose logging
        )
        analyzer.struggle_threshold = threshold
        # Progress callback invoked by process_video per processed frame.
        def progress_update(progress, frame, results):
            print(f"处理进度: {progress}%, 检测到对象数: {len(results)}")
        print(f"处理视频: {video}")
        print(f"输出路径: {output_path}")
        print(f"参数: conf={conf}, iou={iou}, max_det={max_det}, threshold={threshold}")
        print(f"使用模型: {model_file_path}")
        # Analyze the selected frame range and write the annotated video.
        results = analyzer.process_video(
            video_path=video,
            output_path=output_path,
            start_frame=start_frame,
            end_frame=end_frame,
            callback=progress_update
        )
        # Persist per-frame results to CSV.
        print(f"保存结果到CSV: {csv_path}")
        analyzer.save_results(csv_path)
        print(f"结果已保存到CSV,共 {len(analyzer.results)} 帧数据")
        # Generate the struggle-score time-series plot (None if no results
        # or plot generation failed).
        print("生成时间序列图...")
        if len(analyzer.results) == 0:
            print("警告: 没有可用于绘图的结果!")
            plot_path = None
        else:
            plot_path = analyzer.generate_time_series_plot()
            if plot_path and os.path.exists(plot_path):
                print(f"图表已生成并保存到: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB")
            else:
                print(f"生成图表失败或图表文件不存在!")
                plot_path = None
        # Diagnostics: report output-file sizes and debug artifacts.
        if os.path.exists(output_path):
            file_size = os.path.getsize(output_path) / (1024 * 1024)  # MB
            print(f"输出视频大小: {file_size:.2f}MB")
        # Debug frame is written by the analyzer next to the output video
        # (presumably — TODO confirm against MouseTrackerAnalyzer).
        debug_frame_path = os.path.join(os.path.dirname(output_path), "debug_frame.jpg")
        if os.path.exists(debug_frame_path):
            print(f"调试帧保存在: {debug_frame_path}")
        if plot_path and os.path.exists(plot_path):
            print(f"图表文件存在于: {plot_path}, 大小: {os.path.getsize(plot_path)/1024:.2f}KB")
        # Assemble the user-facing status message.
        status_message = "分析完成。"
        if os.path.exists(output_path):
            status_message += f"视频已保存。"
        else:
            status_message += "警告: 未找到输出视频。"
        if plot_path and os.path.exists(plot_path):
            status_message += f" 时间序列图已生成。"
        else:
            status_message += " 警告: 生成时间序列图失败。"
        status_message += f" 结果已保存到: {csv_path}"
        return output_path, plot_path, status_message
    except Exception as e:
        import traceback
        traceback.print_exc()
        return None, None, f"处理错误: {str(e)}"
# Build the Gradio interface.
def create_interface():
    """Construct and return the Gradio Blocks app.

    Layout: a left column with the video input, system info, detection
    parameters, frame-range sliders and action buttons; a right column
    with a preview tab and a results tab. Event handlers wire the inputs
    to select_video / preview_frame / start_analysis.
    """
    with gr.Blocks(title="鼠强迫游泳挣扎度分析 - 对象跟踪") as app:
        gr.Markdown("# 鼠强迫游泳测试挣扎度分析 (对象跟踪)")
        with gr.Row():
            with gr.Column(scale=1):
                # Video selection only — model selection is fixed.
                video_input = gr.Video(label="输入视频")
                # Show the model in use and the compute device.
                device_info = "GPU" if torch.cuda.is_available() else "CPU"
                model_info = gr.Textbox(
                    label="系统信息",
                    value=f"使用模型: {os.path.basename(model_file_path)} | 计算设备: {device_info}",
                    interactive=False
                )
                # Detection parameters.
                with gr.Row():
                    conf = gr.Slider(minimum=0.1, maximum=0.9, value=0.25, step=0.05, label="置信度阈值")
                    iou = gr.Slider(minimum=0.1, maximum=0.9, value=0.45, step=0.05, label="IoU阈值")
                with gr.Row():
                    max_det = gr.Slider(minimum=1, maximum=50, value=20, step=1, label="最大检测数")
                    threshold = gr.Slider(minimum=0, maximum=1, value=0.3, step=0.01, label="挣扎阈值")
                # Frame-range selection (re-ranged by select_video on load).
                start_frame = gr.Slider(minimum=0, maximum=999999, value=0, step=1, label="起始帧")
                end_frame = gr.Slider(minimum=0, maximum=999999, value=999999, step=1, label="结束帧")
                # Action buttons.
                preview_btn = gr.Button("预览帧")
                start_btn = gr.Button("开始分析", variant="primary")
            with gr.Column(scale=2):
                # Display area.
                with gr.Tab("预览"):
                    # Frame preview image and status line.
                    preview_image = gr.Image(label="预览图像", type="numpy", height=400)
                    status_text = gr.Textbox(label="状态", interactive=False)
                    gr.Markdown("""
### 使用说明:
1. 选择一个视频文件
2. 调整参数
- 置信度阈值: 对象检测的最低置信度,较低的值会检测更多潜在对象
- IoU阈值: 用于过滤重叠检测
- 最大检测数: 每帧检测的最大对象数
- 挣扎阈值: 分类为挣扎状态的最低分数
3. 设置帧范围
4. 点击"开始分析"按钮
系统将自动跟踪小鼠并分析其挣扎行为,无需手动定义区域
""")
                with gr.Tab("结果"):
                    with gr.Row():
                        output_video = gr.Video(label="分析结果视频")
                        result_plot = gr.Image(label="挣扎分数时间序列")
                    result_status = gr.Textbox(label="分析状态", interactive=False)
        # Wire up events.
        video_input.change(select_video, inputs=[video_input], outputs=[preview_image, status_text, start_frame, end_frame])
        preview_btn.click(preview_frame, inputs=[video_input, start_frame], outputs=[preview_image, status_text])
        start_btn.click(
            start_analysis,
            inputs=[video_input, conf, iou, max_det, start_frame, end_frame, threshold],
            outputs=[output_video, result_plot, result_status]
        )
    return app
# Application entry point.
if __name__ == "__main__":
    # Remove proxy environment variables that could interfere with serving.
    for proxy_var in ('http_proxy', 'https_proxy', 'all_proxy'):
        os.environ.pop(proxy_var, None)
    # Report the compute device in use.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"使用设备: {device}")
    # Report whether the configured model file is present.
    if os.path.exists(model_file_path):
        print(f"使用模型: {model_file_path}")
    else:
        print(f"警告: 模型文件 {model_file_path} 不存在!")
    app = create_interface()
    # Spaces wires host/port itself; locally we bind explicitly.
    if is_spaces:
        app.launch()
    else:
        app.launch(server_name="0.0.0.0", server_port=7860, share=False)