| import os |
| import uuid |
| import time |
| import psutil |
| import torch |
| import cv2 |
| import shutil |
| from models.qwen import Qwen2VL |
| from models.gemma import Gemma |
| from models.minicpm import MiniCPM |
| from models.lfm import LFM2 |
| from video_processor import extract_frames, FrameSamplingMethod |
| import argparse |
| import json |
| import logging |
| from tqdm import tqdm |
# Root directory for per-request scratch folders. Each process_video call writes
# its extracted frames to TEMP_VIDEO_DIR/<request_id>/ and deletes that folder
# when it finishes.
TEMP_VIDEO_DIR = "temp_videos"


def _resolve_sampling_method(sampling_method_str, request_id):
    """Map a sampling-method name to a ``FrameSamplingMethod`` member.

    Only "CONTENT_AWARE" and "UNIFORM" are recognised. Any other name falls
    back to ``FrameSamplingMethod.CONTENT_AWARE`` and a warning is logged, so a
    request such as "RANDOM" is not served with a different method unnoticed.
    """
    sampling_method_map = {
        "CONTENT_AWARE": FrameSamplingMethod.CONTENT_AWARE,
        "UNIFORM": FrameSamplingMethod.UNIFORM,
    }
    sampling_method = sampling_method_map.get(sampling_method_str)
    if sampling_method is None:
        logging.warning(
            f"[{request_id}] Unsupported sampling method '{sampling_method_str}', "
            "falling back to CONTENT_AWARE."
        )
        sampling_method = FrameSamplingMethod.CONTENT_AWARE
    return sampling_method


def _save_frames(frames, frame_dir):
    """Write each frame as ``frame_XXXX.jpg`` into *frame_dir*; return absolute paths.

    Raises:
        OSError: if OpenCV reports that a frame could not be written.
    """
    frame_paths = []
    for i, frame in enumerate(frames):
        frame_path = os.path.abspath(os.path.join(frame_dir, f"frame_{i:04d}.jpg"))
        # cv2.imwrite reports failure through its return value, not an exception;
        # without this check the model would be handed paths that do not exist.
        if not cv2.imwrite(frame_path, frame):
            raise OSError(f"Failed to write frame {i} to {frame_path}")
        frame_paths.append(frame_path)
    return frame_paths


def process_video(model, video_path, prompt, sampling_method_str="CONTENT_AWARE", sampling_rate=5):
    """Run *model* on frames sampled from *video_path* and return its output plus metrics.

    The steps are:

    1. Extract frames with ``extract_frames``.
    2. Write them as JPEGs to a per-request scratch directory under
       ``TEMP_VIDEO_DIR``. The directory is always removed afterwards.
    3. Call ``model.generate(frame_paths, prompt)`` with the absolute frame paths.

    Args:
        model: Object with ``generate(frame_paths, prompt)``. It must return a dict
            that contains at least ``tokens_per_second`` and ``peak_gpu_memory_mb``
            (both are read for logging).
        video_path (str): Path to the video file.
        prompt (str): Text prompt.
        sampling_method_str (str): "CONTENT_AWARE" or "UNIFORM". Any other value
            falls back to CONTENT_AWARE and logs a warning.
        sampling_rate (int): Sampling rate or threshold, forwarded unchanged to
            ``extract_frames``. Its exact meaning is defined there.

    Returns:
        dict: The dict returned by ``model.generate``, augmented in place with these keys:

        - ``inference_time``: seconds since the request started, so it includes
          frame extraction and saving.
        - ``cpu_usage`` and ``cpu_core_utilization``: psutil readings over this request.
        - ``request_id``.

    Raises:
        FileNotFoundError: if *video_path* does not exist.
        ValueError: if no frames could be extracted.
        OSError: if a frame could not be written to disk.
    """
    request_start_time = time.perf_counter()
    request_id = str(uuid.uuid4())
    logging.info(f"[{request_id}] Processing video: '{video_path}', Prompt: '{prompt}'")

    # With interval=None, psutil.cpu_percent() reports usage since its previous
    # call (and a meaningless 0.0 on the very first call). Sampling once here
    # makes the readings taken after inference cover exactly this request.
    psutil.cpu_percent(interval=None)
    psutil.cpu_percent(interval=None, percpu=True)

    if not os.path.exists(video_path):
        raise FileNotFoundError(f"Video file not found: {video_path}")

    if not video_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')):
        logging.warning(f"[{request_id}] File '{video_path}' may not be a video file.")

    sampling_method = _resolve_sampling_method(sampling_method_str, request_id)

    temp_frame_dir = os.path.join(TEMP_VIDEO_DIR, request_id)
    os.makedirs(temp_frame_dir, exist_ok=True)

    try:
        logging.info(f"[{request_id}] Extracting frames using method: {sampling_method.value}, rate/threshold: {sampling_rate}")

        frames = extract_frames(video_path, sampling_method, sampling_rate)
        if not frames:
            raise ValueError(f"Could not extract any frames from the video: {video_path}")

        logging.info(f"[{request_id}] Extracted {len(frames)} frames successfully. Saving to temporary files...")

        frame_paths = _save_frames(frames, temp_frame_dir)
        # The decoded frames are no longer needed once they are on disk; drop
        # them before inference to lower peak host memory.
        del frames

        logging.info(f"[{request_id}] {len(frame_paths)} frames saved to {temp_frame_dir}")

        output = model.generate(frame_paths, prompt)

        logging.info(f"[{request_id}] Tokens per second: {output['tokens_per_second']}, Peak GPU memory MB: {output['peak_gpu_memory_mb']}")

        inference_time = time.perf_counter() - request_start_time
        cpu_usage = psutil.cpu_percent(interval=None)
        cpu_core_utilization = psutil.cpu_percent(interval=None, percpu=True)
        logging.info(f"[{request_id}] Inference time: {inference_time:.2f} seconds, CPU usage: {cpu_usage}%, CPU core utilization: {cpu_core_utilization}")

        output["inference_time"] = inference_time
        output["cpu_usage"] = cpu_usage
        output["cpu_core_utilization"] = cpu_core_utilization
        output["request_id"] = request_id
        return output

    except Exception as e:
        logging.error(f"[{request_id}] An error occurred during processing: {str(e)}", exc_info=True)
        # A bare raise re-raises with the original traceback, without adding this frame.
        raise
    finally:
        # Cleanup is best-effort: an error here must not mask the exception
        # (or result) produced by the try block above.
        if os.path.exists(temp_frame_dir):
            try:
                shutil.rmtree(temp_frame_dir)
                logging.info(f"[{request_id}] Cleaned up temporary frame directory: {temp_frame_dir}")
            except OSError as cleanup_error:
                logging.warning(f"[{request_id}] Failed to remove temporary frame directory {temp_frame_dir}: {cleanup_error}")
|
|
|
|
def _load_model(model_path):
    """Instantiate the model wrapper whose family name occurs in *model_path*.

    The match is a case-insensitive substring test, tried in the order
    qwen, gemma, minicpm, lfm. The first hit wins.

    Raises:
        ValueError: if *model_path* names none of the supported families.
    """
    lowered_path = model_path.lower()
    model_families = (
        ("qwen", Qwen2VL),
        ("gemma", Gemma),
        ("minicpm", MiniCPM),
        ("lfm", LFM2),
    )
    for family, model_cls in model_families:
        if family in lowered_path:
            return model_cls(model_path)
    # Fail with a clear message instead of leaving the caller without a model.
    raise ValueError(
        f"Unsupported model path '{model_path}': expected it to contain one of "
        f"{', '.join(family for family, _ in model_families)}"
    )


def main():
    """Command-line entry point: run one VLM over every video in a directory.

    Steps:
      1. Parse CLI arguments.
      2. Configure file logging to ``logs/<model name>/<timestamp>.log``.
      3. Load the model chosen by ``_load_model`` from ``--model_path``.
      4. Run ``process_video`` on each regular file in ``--video_dir``, in sorted
         order. A video that fails is logged and stored as
         ``{"error": "<message>"}``, so the rest of the batch still runs.
      5. Write all results to ``outputs/<model name>/<timestamp>.json`` and
         print a short summary.

    Any other exception is logged and printed rather than re-raised.
    """
    try:
        parser = argparse.ArgumentParser()
        parser.add_argument("--model_path", type=str, default="Qwen/Qwen2.5-VL-3B-Instruct-AWQ")
        parser.add_argument("--video_dir", type=str, default="videos", help="视频")
        parser.add_argument("--prompt", type=str, default="Summarize the key observable events in this 1-minute convenience store video clip. Focus strictly on the physical actions and interactions of the people. Describe only what you can see; do not interpret intentions, relationships, or work efficiency. Avoid all repetitive descriptions of the store's layout or shelves.", help="文本提示")
        parser.add_argument("--sampling_method", type=str, default="UNIFORM",
                            choices=["CONTENT_AWARE", "UNIFORM", "RANDOM"],
                            help="帧采样方法")
        parser.add_argument("--sampling_rate", type=int, default=30, help="采样率或阈值")
        args = parser.parse_args()

        # basename(normpath(...)) keeps the last path component even when a local
        # model directory has a trailing slash ("models/qwen/"). In that case
        # split('/')[-1] would produce an empty name.
        model_name = os.path.basename(os.path.normpath(args.model_path))
        log_dir = os.path.join("logs", model_name)
        output_dir = os.path.join("outputs", model_name)
        for directory in (log_dir, output_dir, TEMP_VIDEO_DIR):
            os.makedirs(directory, exist_ok=True)
        start_time = time.strftime('%Y%m%d_%H%M%S')
        log_filename = os.path.join(log_dir, f"{start_time}.log")
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S', filename=log_filename, filemode='a')

        logging.info(f"Loading model: {args.model_path}")
        model_load_start = time.perf_counter()
        model = _load_model(args.model_path)
        model_load_seconds = time.perf_counter() - model_load_start
        gpu_memory_usage = f"{torch.cuda.memory_allocated(0)/1024**2:.2f} MB" if torch.cuda.is_available() else "N/A"
        logging.info(f"Model loaded in {model_load_seconds:.2f} seconds")
        logging.info(f"GPU Memory Usage after model load: {gpu_memory_usage}")

        # Only regular files are candidates. Sorting makes the processing order,
        # and the key order of the JSON output, reproducible across runs.
        video_names = sorted(
            name for name in os.listdir(args.video_dir)
            if os.path.isfile(os.path.join(args.video_dir, name))
        )
        if not video_names:
            logging.warning(f"No files found in video directory: {args.video_dir}")

        total_output = {}
        successes = []
        failed = 0
        for video_name in tqdm(video_names):
            try:
                result = process_video(
                    model=model,
                    video_path=os.path.join(args.video_dir, video_name),
                    prompt=args.prompt,
                    sampling_method_str=args.sampling_method,
                    sampling_rate=args.sampling_rate,
                )
            except Exception as exc:
                # Record the failure and keep going, so one bad video does not
                # discard the results already computed for the others.
                logging.error(f"Failed to process video '{video_name}': {exc}")
                total_output[video_name] = {"error": str(exc)}
                failed += 1
                continue
            total_output[video_name] = result
            successes.append((video_name, result))

        output_filename = os.path.join(output_dir, f"{start_time}.json")
        with open(output_filename, 'w', encoding='utf-8') as f:
            json.dump(total_output, f, ensure_ascii=False, indent=2)

        print(f"处理完成!结果已保存到: {output_filename}")
        # Report the total over every successful video instead of just the last
        # one. This also works when the directory was empty or every video failed.
        total_inference_time = sum(result["inference_time"] for _, result in successes)
        print(f"推理时间: {total_inference_time:.2f} 秒")
        for video_name, result in successes:
            print(f"[{video_name}] 生成的内容: {result.get('generated_text', 'N/A')}")
        if failed:
            print(f"失败的视频数: {failed}/{len(video_names)}")

    except Exception as e:
        logging.error(f"处理失败: {str(e)}", exc_info=True)
        print(f"处理失败: {str(e)}")


if __name__ == "__main__":
    main()
|
|