Spaces:
Build error
Build error
import argparse
import os
import pickle

import cv2
import numpy as np
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
def load_model_vgg19():
    """Return an ImageNet-pretrained VGG19 truncated to emit 4096-d features.

    torchvision's VGG classifier is
    ``Linear(25088, 4096) -> ReLU -> Dropout -> Linear(4096, 4096) -> ReLU
    -> Dropout -> Linear(4096, 1000)``. Only its first four modules are kept,
    so the network ends at the second fully connected layer (before its ReLU)
    and a forward pass returns a ``(batch, 4096)`` tensor.

    The model is switched to eval mode so the remaining Dropout layer is
    inactive and the extracted features are deterministic.

    Returns:
        torch.nn.Module: the truncated VGG19 in eval mode.
    """
    weights_enum = getattr(models, "VGG19_Weights", None)
    if weights_enum is not None:
        # torchvision >= 0.13: ``pretrained=True`` is deprecated in favour of
        # ``weights``; IMAGENET1K_V1 is the checkpoint ``pretrained=True``
        # used to load.
        vgg19 = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision without the weights-enum API.
        vgg19 = models.vgg19(pretrained=True)

    # Slicing an nn.Sequential yields a new nn.Sequential holding the same
    # modules; this is equivalent to rebuilding it from children()[-1][:4],
    # but does not depend on the order of the model's children.
    vgg19.classifier = vgg19.classifier[:4]
    vgg19.eval()
    return vgg19
# Per-frame preprocessing applied before the VGG19 forward pass:
# resize so the shorter side is 256 px, take the central 224x224 crop,
# convert to a float tensor in [0, 1], then normalise each channel with the
# ImageNet mean / std values used by torchvision's pretrained models.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def extract_frames_and_features(model_vgg19, video_path, frame_interval=16):
    """Sample every ``frame_interval``-th frame of a video and embed it.

    Frames are counted from 1, so the sampled frames are the
    ``frame_interval``-th, ``2 * frame_interval``-th, ... (the very first
    frame is not sampled). Each sampled frame is converted from OpenCV's BGR
    layout to RGB, preprocessed with the module-level ``transform`` and passed
    through ``model_vgg19``. The resulting matrix's shape is printed.

    Args:
        model_vgg19: feature extractor, e.g. from ``load_model_vgg19()``.
            Presumably already in eval mode. Inputs are moved to the device of
            its parameters.
        video_path: path to a video file readable by ``cv2.VideoCapture``.
        frame_interval: sampling stride in frames (default 16).

    Returns:
        numpy.ndarray: one row per sampled frame, stacked along axis 0.

    Raises:
        ValueError: if ``frame_interval`` < 1, the video cannot be opened, or
            no frame was sampled (e.g. the video has fewer than
            ``frame_interval`` frames).
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

    # Feed inputs to whatever device the model lives on. Fall back to CPU for
    # a parameter-less module.
    first_param = next(model_vgg19.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"cannot open video file: {video_path}")

    feature_list = []
    frame_count = 0
    try:
        with torch.no_grad():
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                # 1-based count: keeps frames frame_interval, 2*frame_interval, ...
                if frame_count % frame_interval != 0:
                    continue
                # OpenCV decodes frames as BGR; PIL / the preprocessing expect RGB.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                input_batch = transform(Image.fromarray(frame_rgb)).unsqueeze(0).to(device)
                # Drop the batch dimension of the single-frame output.
                output = model_vgg19(input_batch)[0]
                feature_list.append(output.cpu().numpy())
    finally:
        # Release the decoder even if preprocessing or the forward pass fails.
        cap.release()

    if not feature_list:
        raise ValueError(
            f"no frames sampled from {video_path} "
            f"(read {frame_count} frames, frame_interval={frame_interval})"
        )

    feature_matrix = np.stack(feature_list)
    print(feature_matrix.shape)
    return feature_matrix
# # Example: extract features for a single video and pickle them.
# # Path to the video file
# video_path = "/mnt/data10t/dazuoye/GROUP2024-GEN6/FakeSV/code/C3D_Feature_Extractor/raw_video/douyin_6571001202379590925.mp4"
# # Extract sampled frames and their features (the model must be passed in)
# model_vgg19 = load_model_vgg19()
# video_features = extract_frames_and_features(model_vgg19, video_path)
# # Save the feature matrix as a .pkl file,
# # creating the output directory first if it does not exist
# output_dir = 'outputs'
# if not os.path.exists(output_dir):
#     os.makedirs(output_dir)
# output_file_path = os.path.join(output_dir, 'video_features.pkl')
# with open(output_file_path, 'wb') as f:
#     pickle.dump(video_features, f)
# print("Video features saved to", output_file_path)
def process_video(model_vgg19, video_name, input_folder, output_folder):
    """Extract VGG19 features for one video and pickle them to disk.

    The output file lives in ``output_folder`` and is named after the video
    with its final extension replaced by ``.pkl``
    (e.g. ``clip.v2.mp4`` -> ``clip.v2.pkl``). A confirmation line with the
    output path is printed.

    Args:
        model_vgg19: feature extractor passed to
            ``extract_frames_and_features``.
        video_name: file name of the video, relative to ``input_folder``.
        input_folder: directory containing the video.
        output_folder: directory for the ``.pkl`` file; created if missing.

    Note:
        The output is written with ``pickle``; only load such files from
        trusted sources, since unpickling can execute arbitrary code.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)
    video_path = os.path.join(input_folder, video_name)
    video_features = extract_frames_and_features(model_vgg19, video_path)
    # splitext strips only the final extension. The former split('.')[0]
    # truncated names containing dots ("a.b.mp4" -> "a").
    stem = os.path.splitext(video_name)[0]
    output_file_path = os.path.join(output_folder, stem + '.pkl')
    with open(output_file_path, 'wb') as f:
        pickle.dump(video_features, f)
    print("Video features saved to", output_file_path)
if __name__ == "__main__":
    # Command-line entry point. With no arguments it reproduces the original
    # hard-coded single-video run; each path can be overridden on the CLI.
    parser = argparse.ArgumentParser(
        description="Extract VGG19 features from sampled frames of a video."
    )
    parser.add_argument(
        "--input-folder",
        default='./FakeVD/code/C3D_Feature_Extractor/raw_video',
        help="directory containing the video",
    )
    parser.add_argument(
        "--output-folder",
        default='./FakeVD/code/VGG19/outputs',
        help="directory for the resulting .pkl feature file",
    )
    parser.add_argument(
        "--video-name",
        default='douyin_6571001202379590925.mp4',
        help="video file name inside --input-folder",
    )
    args = parser.parse_args()

    model_vgg19 = load_model_vgg19()
    process_video(model_vgg19, args.video_name, args.input_folder, args.output_folder)