File size: 7,424 Bytes
f8ba0eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import uuid
import time
import psutil
import torch
import cv2
import shutil
from models.qwen import Qwen2VL
from models.gemma import Gemma
from models.minicpm import MiniCPM
from models.lfm import LFM2
from video_processor import extract_frames, FrameSamplingMethod
import argparse
import json
import logging
from tqdm import tqdm
# Relative scratch directory: process_video() writes each request's extracted frames
# to TEMP_VIDEO_DIR/<request_id>/ and deletes that subdirectory when it finishes.
TEMP_VIDEO_DIR = "temp_videos"
def process_video(model, video_path, prompt, sampling_method_str="CONTENT_AWARE", sampling_rate=5):
    """
    Run inference on a single video with a text prompt; return the model output plus metrics.

    Frames are extracted with ``extract_frames`` and written as JPEG files to a
    per-request directory ``TEMP_VIDEO_DIR/<request_id>/``, which is deleted again
    before returning. The list of absolute frame paths is then passed to
    ``model.generate(frame_paths, prompt)``.

    Args:
        model: model wrapper whose ``generate(frame_paths, prompt)`` returns a dict
            containing at least ``tokens_per_second`` and ``peak_gpu_memory_mb``.
        video_path (str): path to the video file.
        prompt (str): text prompt.
        sampling_method_str (str): name of a FrameSamplingMethod member, matched
            case-insensitively; unknown names fall back to CONTENT_AWARE with a warning.
        sampling_rate (int): sampling rate or threshold, forwarded to ``extract_frames``.

    Returns:
        dict: the dict returned by ``model.generate`` (updated in place) with these
        keys added: ``inference_time`` (seconds since the request started, so it
        includes frame extraction and saving, not only generation), ``cpu_usage``
        (system-wide percent), ``cpu_core_utilization`` (per-CPU percents) and
        ``request_id``.

    Raises:
        FileNotFoundError: if ``video_path`` does not exist.
        ValueError: if no frames could be extracted.
        OSError: if a frame could not be written to disk.
    """
    request_start_time = time.time()
    request_id = str(uuid.uuid4())
    # psutil.cpu_percent(interval=None) reports usage since the *previous* call, so prime
    # both the aggregate and the per-CPU reading here; the values read after inference
    # then cover exactly this request instead of an arbitrary earlier span.
    psutil.cpu_percent(interval=None)
    psutil.cpu_percent(interval=None, percpu=True)
    logging.info(f"[{request_id}] Processing video: '{video_path}', Prompt: '{prompt}'")

    # Validate the input file
    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")

    if not video_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        logging.warning(f"[{request_id}] File '{video_path}' may not be a video file.")

    sampling_method = _resolve_sampling_method(sampling_method_str, request_id)

    # Per-request scratch directory so each request's frame files are isolated
    temp_frame_dir = os.path.join(TEMP_VIDEO_DIR, request_id)
    os.makedirs(temp_frame_dir, exist_ok=True)

    try:
        logging.info(f"[{request_id}] Extracting frames using method: {sampling_method.value}, rate/threshold: {sampling_rate}")

        frames = extract_frames(video_path, sampling_method, sampling_rate)
        if not frames:
            raise ValueError(f"Could not extract any frames from the video: {video_path}")

        logging.info(f"[{request_id}] Extracted {len(frames)} frames successfully. Saving to temporary files...")

        frame_paths = _save_frames(frames, temp_frame_dir)
        # Only the file paths are needed from here on; drop the decoded frames so
        # they are not held in memory during inference.
        del frames

        logging.info(f"[{request_id}] {len(frame_paths)} frames saved to {temp_frame_dir}")

        # Run inference
        output = model.generate(frame_paths, prompt)

        logging.info(f"[{request_id}] Tokens per second: {output['tokens_per_second']}, Peak GPU memory MB: {output['peak_gpu_memory_mb']}")

        inference_end_time = time.time()
        elapsed = inference_end_time - request_start_time
        cpu_usage = psutil.cpu_percent(interval=None)
        cpu_core_utilization = psutil.cpu_percent(interval=None, percpu=True)
        logging.info(f"[{request_id}] Inference time: {elapsed:.2f} seconds, CPU usage: {cpu_usage}%, CPU core utilization: {cpu_core_utilization}")

        # Attach performance metrics to the model output
        output["inference_time"] = elapsed
        output["cpu_usage"] = cpu_usage
        output["cpu_core_utilization"] = cpu_core_utilization
        output["request_id"] = request_id

        return output

    except Exception as e:
        logging.error(f"[{request_id}] An error occurred during processing: {str(e)}", exc_info=True)
        raise
    finally:
        # Remove the temporary frames. A cleanup failure is only logged, so it
        # cannot mask an exception raised in the try block above.
        if os.path.exists(temp_frame_dir):
            try:
                shutil.rmtree(temp_frame_dir)
            except OSError as cleanup_error:
                logging.warning(f"[{request_id}] Failed to remove temporary frame directory {temp_frame_dir}: {cleanup_error}")
            else:
                logging.info(f"[{request_id}] Cleaned up temporary frame directory: {temp_frame_dir}")


def _resolve_sampling_method(sampling_method_str, request_id):
    """Return the FrameSamplingMethod member named by *sampling_method_str*.

    The name is upper-cased and looked up as an attribute of FrameSamplingMethod,
    so any method the enum defines (including RANDOM, which the CLI offers, if it
    exists there) is honoured. Unknown names fall back to CONTENT_AWARE, as before,
    but now with a warning instead of silently.
    """
    name = str(sampling_method_str).strip().upper()
    method = None
    # Refuse private/dunder names so arbitrary input cannot reach class internals.
    if name and not name.startswith("_"):
        method = getattr(FrameSamplingMethod, name, None)
    if method is None:
        logging.warning(f"[{request_id}] Unknown sampling method '{sampling_method_str}', falling back to CONTENT_AWARE.")
        return FrameSamplingMethod.CONTENT_AWARE
    return method


def _save_frames(frames, frame_dir):
    """Write each frame to *frame_dir* as frame_NNNN.jpg and return the absolute paths in order.

    Raises:
        OSError: if cv2.imwrite reports that a frame could not be written.
    """
    frame_paths = []
    for index, frame in enumerate(frames):
        frame_path = os.path.abspath(os.path.join(frame_dir, f"frame_{index:04d}.jpg"))
        # cv2.imwrite signals failure by returning False rather than raising.
        if not cv2.imwrite(frame_path, frame):
            raise OSError(f"Failed to write frame {index} to {frame_path}")
        frame_paths.append(frame_path)
    return frame_paths


def main():
    """Command-line entry point: run every file in --video_dir through one model.

    Loads the model selected by --model_path, then calls process_video() on each
    regular file in --video_dir in sorted order. The cumulative results are rewritten
    to outputs/<model>/<start_time>.json after every video, so partial results
    survive an interruption. Logs are written to logs/<model>/<start_time>.log.
    A failure on one video is logged and recorded as {"error": ...} for that file;
    processing then continues with the remaining videos.
    """
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_path", type=str, default="Qwen/Qwen2.5-VL-3B-Instruct-AWQ")
        parser.add_argument("--video_dir", type=str, default="videos", help="视频")
        parser.add_argument("--prompt", type=str, default="Summarize the key observable events in this 1-minute convenience store video clip. Focus strictly on the physical actions and interactions of the people. Describe only what you can see; do not interpret intentions, relationships, or work efficiency. Avoid all repetitive descriptions of the store's layout or shelves.", help="文本提示")
        parser.add_argument("--sampling_method", type=str, default="UNIFORM",
                        choices=["CONTENT_AWARE", "UNIFORM", "RANDOM"],
                        help="帧采样方法")
        parser.add_argument("--sampling_rate", type=int, default=30, help="采样率或阈值")
        args = parser.parse_args()

        # --- log, output and temporary directory setup ---
        model_name = _model_short_name(args.model_path)
        log_dir = f"logs/{model_name}"
        output_dir = f"outputs/{model_name}"
        os.makedirs(log_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(TEMP_VIDEO_DIR, exist_ok=True)
        start_time = time.strftime('%Y%m%d_%H%M%S')
        log_filename = f"{log_dir}/{start_time}.log"
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', filename=log_filename, filemode='a')

        # --- load the model ---
        logging.info(f"Loading model: {args.model_path}")
        model_load_start = time.time()
        model = _load_model(args.model_path)
        model_load_end = time.time()
        gpu_memory_usage = f"{torch.cuda.memory_allocated(0)/1024**2:.2f} MB" if torch.cuda.is_available() else "N/A"
        logging.info(f"Model loaded in {model_load_end - model_load_start:.2f} seconds")
        logging.info(f"GPU Memory Usage after model load: {gpu_memory_usage}")

        # --- process the videos ---
        output_filename = f"{output_dir}/{start_time}.json"
        # os.listdir() order is arbitrary: sort for reproducible runs, and skip
        # subdirectories, which process_video cannot handle.
        video_names = sorted(
            name for name in os.listdir(args.video_dir)
            if os.path.isfile(os.path.join(args.video_dir, name))
        )
        if not video_names:
            logging.warning(f"No files found in video directory: {args.video_dir}")

        total_output = {}
        for video_name in tqdm(video_names):
            video_path = os.path.join(args.video_dir, video_name)
            try:
                result = process_video(
                    model=model,
                    video_path=video_path,
                    prompt=args.prompt,
                    sampling_method_str=args.sampling_method,
                    sampling_rate=args.sampling_rate,
                )
            except Exception as e:
                # One unreadable or unsupported file must not discard the rest of the
                # batch: record the error for this file and move on.
                logging.error(f"Failed to process '{video_path}': {e}")
                print(f"处理失败: {str(e)}")
                total_output[video_name] = {"error": str(e)}
                _save_results(output_filename, total_output)
                continue

            total_output[video_name] = result
            # Rewrite the cumulative results after every video so progress is never lost.
            _save_results(output_filename, total_output)

            print(f"处理完成!结果已保存到: {output_filename}")
            print(f"推理时间: {result['inference_time']:.2f} 秒")
            print(f"生成的内容: {result.get('generated_text', 'N/A')}")

    except Exception as e:
        logging.error(f"处理失败: {str(e)}", exc_info=True)
        print(f"处理失败: {str(e)}")


def _model_short_name(model_path):
    """Return the final path component of *model_path*, used to name log/output subdirectories.

    normpath() strips trailing separators first, so a local path like
    "./checkpoints/qwen/" yields "qwen" rather than an empty string.
    """
    return os.path.basename(os.path.normpath(model_path))


def _load_model(model_path):
    """Instantiate the model wrapper whose keyword occurs in *model_path*.

    Keywords are matched case-insensitively in this order: qwen, gemma, minicpm, lfm.

    Raises:
        ValueError: if *model_path* contains none of the known keywords.
    """
    lowered = model_path.lower()
    if "qwen" in lowered:
        return Qwen2VL(model_path)
    if "gemma" in lowered:
        return Gemma(model_path)
    if "minicpm" in lowered:
        return MiniCPM(model_path)
    if "lfm" in lowered:
        return LFM2(model_path)
    raise ValueError(
        f"Unsupported model path '{model_path}': expected it to contain one of "
        "'qwen', 'gemma', 'minicpm' or 'lfm'."
    )


def _save_results(output_filename, results):
    """Overwrite *output_filename* with *results* as pretty-printed UTF-8 JSON."""
    with open(output_filename, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)


if __name__ == "__main__":
    main()