Spaces: Running on Zero
| import argparse | |
| import time | |
| import math | |
| import os | |
| import shutil | |
| from joblib import load | |
| import cv2 | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, Dataset | |
| from thop import profile | |
| from torchvision import models, transforms | |
| from extractor.visualise_vit_layer import VitGenerator | |
| from relax_vqa import get_deep_feature, process_video_feature, process_patches, get_frame_patches, flow_to_rgb, merge_fragments, concatenate_features | |
| from extractor.vf_extract import process_video_residual | |
| from model_regression import Mlp, preprocess_data | |
def fix_state_dict(state_dict):
    """Normalise checkpoint keys before loading into a plain (non-parallel) model.

    Strips the 'module.' prefix that torch.nn.DataParallel adds to every
    parameter name, and drops the 'n_averaged' bookkeeping entry written by
    SWA-averaged checkpoints.

    Args:
        state_dict: mapping of parameter name -> tensor, as returned by torch.load.

    Returns:
        A new dict with cleaned keys; values are passed through untouched.
    """
    prefix = 'module.'
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
        if key != 'n_averaged'
    }
def preprocess_data(X, y=None, imp=None, scaler=None):
    """Sanitise a feature matrix (and optional targets) for the MLP regressor.

    Converts X to a tensor (on CUDA when available), zeroes out NaN/Inf
    entries, and optionally applies pre-fitted sklearn-style imputer/scaler
    transforms. Note: this intentionally shadows the `preprocess_data`
    imported from model_regression at the top of the file.

    Args:
        X: feature matrix — torch.Tensor, numpy array, or nested sequence.
        y: optional targets — torch.Tensor, numpy array, or sequence.
        imp: optional fitted imputer exposing .transform(ndarray).
        scaler: optional fitted scaler exposing .transform(ndarray).

    Returns:
        Tuple (X, y, imp, scaler): X as a cleaned torch.Tensor, y as a
        flattened tensor (or None when absent/empty), and the transforms
        passed through unchanged.
    """
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X, device='cuda' if torch.cuda.is_available() else 'cpu')
    # Zero-out NaN/Inf so neither the sklearn transforms nor the MLP see them.
    X = torch.where(torch.isnan(X) | torch.isinf(X), torch.tensor(0.0, device=X.device), X)
    if imp is not None or scaler is not None:
        # sklearn transformers operate on numpy arrays, so round-trip via CPU.
        X_np = X.cpu().numpy()
        if imp is not None:
            X_np = imp.transform(X_np)
        if scaler is not None:
            X_np = scaler.transform(X_np)
        X = torch.from_numpy(X_np).to(X.device)
    if y is not None:
        # Convert first, then test emptiness with numel(): the original
        # `y.size > 0` only worked for numpy arrays (a torch.Tensor's `.size`
        # is a method, and plain lists have no `.size` at all).
        if not isinstance(y, torch.Tensor):
            y = torch.tensor(y, device=X.device)
        if y.numel() > 0:
            # squeeze() collapses a single-element y to a 0-d scalar tensor,
            # matching the original behaviour.
            y = y.reshape(-1).squeeze()
        else:
            y = None
    return X, y, imp, scaler
def load_model(config, device, input_features=35203):
    """Build the MLP quality regressor and load its trained checkpoint.

    The checkpoint path depends on config: fine-tuned per-dataset weights when
    config['is_finetune'] is true, otherwise the LSVQ-trained base weights.

    Args:
        config: dict with keys 'is_finetune', 'save_path', 'video_type',
            'train_data_name', 'select_criteria'.
        device: torch.device the model and weights are mapped onto.
        input_features: width of the concatenated feature vector
            (default 35203 matches the ResNet50+ViT pipeline in this file).

    Returns:
        The Mlp model with checkpoint weights loaded.

    Raises:
        RuntimeError: when the checkpoint keys do not match the architecture.
            (Previously this was printed and swallowed, silently returning a
            randomly-initialised model that produced meaningless scores.)
    """
    network_name = 'relaxvqa'
    model = Mlp(input_features=input_features, out_features=1, drop_rate=0.2, act_layer=nn.GELU).to(device)
    if config['is_finetune']:
        model_path = os.path.join(config['save_path'], f"fine_tune_model/{config['video_type']}_{network_name}_{config['select_criteria']}_fine_tuned_model.pth")
    else:
        model_path = os.path.join(config['save_path'], f"{config['train_data_name']}_{network_name}_{config['select_criteria']}_trained_median_model_param_onLSVQ_TEST.pth")
    print("Loading model from:", model_path)
    state_dict = torch.load(model_path, map_location=device)
    # Checkpoints may come from DataParallel/SWA training; normalise key names.
    fixed_state_dict = fix_state_dict(state_dict)
    try:
        model.load_state_dict(fixed_state_dict)
    except RuntimeError as e:
        print(e)
        raise  # do not continue with an untrained model
    return model
def evaluate_video_quality(config, resnet50, vit, model_mlp, device):
    """Run the full RelaxVQA pipeline on one video and return its quality score.

    Pipeline: sample frame pairs -> extract ResNet50 layer-stack and ViT
    pooling features from full frames -> build residual/optical-flow fragment
    features -> temporally average and concatenate -> impute/scale -> MLP.

    Args:
        config: dict of CLI options (see parse_arguments); may also carry an
            optional 'video_path' override.
        resnet50: pretrained torchvision ResNet50 feature extractor.
        vit: VitGenerator feature extractor.
        model_mlp: trained Mlp regressor (see load_model).
        device: torch.device for all tensor work.

    Returns:
        Tuple (score, run_time_seconds). For konvid_1k / youtube_ugc without
        fine-tuning the score is rescaled to the 1-5 range; otherwise the raw
        model output is returned.
    """
    is_finetune = config['is_finetune']
    save_path = config['save_path']
    video_type = config['video_type']
    video_name = config['video_name']
    framerate = config['framerate']
    # Temporary working directory for sampled frames; deleted at the end.
    sampled_fragment_path = os.path.join("../video_sampled_frame/sampled_frame/", "test_sampled_fragment")
    video_path = config.get("video_path")
    if video_path is None:
        # Default location by dataset convention: YouTube-UGC ships .mkv, others .mp4.
        if video_type == 'youtube_ugc':
            video_path = f'./ugc_original_videos/{video_name}.mkv'
        else:
            video_path = f'./ugc_original_videos/{video_name}.mp4'
    target_size = 224
    patch_size = 16
    # 14*14 = 196 patches cover a 224x224 frame; top_n selects all of them.
    top_n = int((target_size / patch_size) * (target_size / patch_size))
    # sampled video frames
    start_time = time.time()
    # Returns consecutive frame pairs (frames[i], frames_next[i]); presumably
    # BGR uint8 arrays from cv2 — TODO confirm in vf_extract.
    frames, frames_next = process_video_residual(video_type, video_name, framerate, video_path, sampled_fragment_path)
    # get ResNet50 layer-stack features and ViT pooling features
    all_frame_activations_resnet = []
    all_frame_activations_vit = []
    # get fragments ResNet50 features and ViT features
    all_frame_activations_sampled_resnet = []
    all_frame_activations_merged_resnet = []
    all_frame_activations_sampled_vit = []
    all_frame_activations_merged_vit = []
    batch_size = 64  # Define the number of frames to process in parallel
    for i in range(0, len(frames_next), batch_size):
        batch_frames = frames[i:i + batch_size]
        # cv2 decodes BGR; the deep extractors expect RGB input.
        batch_rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in batch_frames]
        batch_frames_next = frames_next[i:i + batch_size]
        batch_tensors = torch.stack([transforms.ToTensor()(frame) for frame in batch_frames]).to(device)
        batch_rgb_tensors = torch.stack([transforms.ToTensor()(frame_rgb) for frame_rgb in batch_rgb_frames]).to(device)
        batch_tensors_next = torch.stack([transforms.ToTensor()(frame_next) for frame_next in batch_frames_next]).to(device)
        # compute residuals (absolute frame difference between consecutive frames)
        residuals = torch.abs(batch_tensors_next - batch_tensors)
        # calculate optical flows (Farneback dense flow on grayscale pairs)
        batch_gray_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) for frame in batch_frames]
        batch_gray_frames_next = [cv2.cvtColor(frame_next, cv2.COLOR_BGR2GRAY) for frame_next in batch_frames_next]
        # Defensive: cv2.calcOpticalFlowFarneback needs numpy arrays, not tensors.
        batch_gray_frames = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames]
        batch_gray_frames_next = [frame.cpu().numpy() if isinstance(frame, torch.Tensor) else frame for frame in batch_gray_frames_next]
        flows = [cv2.calcOpticalFlowFarneback(batch_gray_frames[j], batch_gray_frames_next[j], None, 0.5, 3, 15, 3, 5, 1.2,0) for j in range(len(batch_gray_frames))]
        for j in range(batch_tensors.size(0)):
            '''sampled video frames'''
            frame_tensor = batch_tensors[j].unsqueeze(0)
            frame_rgb_tensor = batch_rgb_tensors[j].unsqueeze(0)
            # frame_next_tensor = batch_tensors_next[j].unsqueeze(0)
            frame_number = i + j + 1  # 1-based frame index used in saved filenames
            # ResNet50 layer-stack features
            activations_dict_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_rgb_tensor, frame_number, resnet50, device, 'layerstack')
            all_frame_activations_resnet.append(activations_dict_resnet)
            # ViT pooling features
            activations_dict_vit, _, _ = get_deep_feature('vit', video_name, frame_rgb_tensor, frame_number, vit, device, 'pool')
            all_frame_activations_vit.append(activations_dict_vit)
            '''residual video frames'''
            residual = residuals[j].unsqueeze(0)
            flow = flows[j]
            # NOTE(review): assumes process_video_residual wrote this frame to
            # disk under this naming scheme — confirm in vf_extract.
            original_path = os.path.join(sampled_fragment_path, f'{video_name}_{frame_number}.png')
            # Frame Differencing: select/save the most salient residual patches.
            residual_frag_path, diff_frag, positions = process_patches(original_path, 'frame_diff', residual, patch_size, target_size, top_n)
            # Frame fragment: crop the same patch positions out of the raw frame.
            frame_patches = get_frame_patches(frame_tensor, positions, patch_size, target_size)
            # Optical Flow: visualise the flow field as RGB, then fragment it.
            opticalflow_rgb = flow_to_rgb(flow)
            opticalflow_rgb_tensor = transforms.ToTensor()(opticalflow_rgb).unsqueeze(0).to(device)
            opticalflow_frag_path, flow_frag, _ = process_patches(original_path, 'optical_flow', opticalflow_rgb_tensor, patch_size, target_size, top_n)
            merged_frag = merge_fragments(diff_frag, flow_frag)
            # fragments ResNet50 features
            sampled_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, frame_patches, frame_number, resnet50, device, 'layerstack')
            merged_frag_activations_resnet, _, _ = get_deep_feature('resnet50', video_name, merged_frag, frame_number, resnet50, device, 'pool')
            all_frame_activations_sampled_resnet.append(sampled_frag_activations_resnet)
            all_frame_activations_merged_resnet.append(merged_frag_activations_resnet)
            # fragments ViT features
            sampled_frag_activations_vit,_, _ = get_deep_feature('vit', video_name, frame_patches, frame_number, vit, device, 'pool')
            merged_frag_activations_vit, _, _ = get_deep_feature('vit', video_name, merged_frag, frame_number, vit, device, 'pool')
            all_frame_activations_sampled_vit.append(sampled_frag_activations_vit)
            all_frame_activations_merged_vit.append(merged_frag_activations_vit)
    print(f'video frame number: {len(all_frame_activations_resnet)}')
    # Collapse per-frame activation dicts into per-video feature tensors.
    averaged_frames_resnet = process_video_feature(all_frame_activations_resnet, 'resnet50', 'layerstack')
    averaged_frames_vit = process_video_feature(all_frame_activations_vit, 'vit', 'pool')
    # print("ResNet50 layer-stacking feature shape:", averaged_frames_resnet.shape)
    # print("ViT pooling feature shape:", averaged_frames_vit.shape)
    averaged_frames_sampled_resnet = process_video_feature(all_frame_activations_sampled_resnet, 'resnet50', 'layerstack')
    averaged_frames_merged_resnet = process_video_feature(all_frame_activations_merged_resnet, 'resnet50', 'pool')
    averaged_combined_feature_resnet = concatenate_features(averaged_frames_sampled_resnet, averaged_frames_merged_resnet)
    # print("Sampled fragments ResNet50 features shape:", averaged_frames_sampled_resnet.shape)
    # print("Merged fragments ResNet50 features shape:", averaged_frames_merged_resnet.shape)
    averaged_frames_sampled_vit = process_video_feature(all_frame_activations_sampled_vit, 'vit', 'pool')
    averaged_frames_merged_vit = process_video_feature(all_frame_activations_merged_vit, 'vit', 'pool')
    averaged_combined_feature_vit = concatenate_features(averaged_frames_sampled_vit, averaged_frames_merged_vit)
    # print("Sampled fragments ViT features shape:", averaged_frames_sampled_vit.shape)
    # print("Merged fragments ResNet50 features shape:", averaged_frames_merged_vit.shape)
    # remove tmp folders
    shutil.rmtree(sampled_fragment_path)
    # concatenate features: mean over the temporal axis of each stream, then
    # join into a single (1, D) vector for the MLP.
    combined_features = torch.cat([torch.mean(averaged_frames_resnet, dim=0), torch.mean(averaged_frames_vit, dim=0),
                                   torch.mean(averaged_combined_feature_resnet, dim=0), torch.mean(averaged_combined_feature_vit, dim=0)], dim=0).view(1, -1)
    # Pre-fitted, per-dataset imputer/scaler saved at training time.
    imputer = load(f'{save_path}/scaler/{video_type}_imputer.pkl')
    scaler = load(f'{save_path}/scaler/{video_type}_scaler.pkl')
    X_test_processed, _, _, _ = preprocess_data(combined_features, None, imp=imputer, scaler=scaler)
    feature_tensor = X_test_processed
    # evaluation for test video
    model_mlp.eval()
    with torch.no_grad():
        with torch.cuda.amp.autocast():
            prediction = model_mlp(feature_tensor)
        predicted_score = prediction.item()
    # print(f"Raw Predicted Quality Score: {predicted_score}")
    run_time = time.time() - start_time
    if not is_finetune:
        if video_type in ['konvid_1k', 'youtube_ugc']:
            # Linear map from the model's 1-100 output range onto a 1-5 MOS scale.
            scaled_prediction = ((predicted_score - 1) / (99 / 4)) + 1.0
            # print(f"Scaled Predicted Quality Score (1-5): {scaled_prediction}")
            return scaled_prediction, run_time
        else:
            scaled_prediction = predicted_score
            return scaled_prediction, run_time
    else:
        # Fine-tuned models are trained on the target dataset's own scale.
        return predicted_score, run_time
def _str2bool(value):
    """Parse a boolean CLI flag value; raises on anything unrecognised."""
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'Boolean value expected, got {value!r}')


def parse_arguments(argv=None):
    """Parse command-line options for the RelaxVQA demo.

    Args:
        argv: optional argument list (defaults to sys.argv[1:]), which also
            makes this function testable without touching sys.argv.

    Returns:
        argparse.Namespace with device/model/video configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-device', type=str, default='gpu', help='cpu or gpu')
    parser.add_argument('-model_name', type=str, default='Mlp', help='Name of the regression model')
    parser.add_argument('-select_criteria', type=str, default='byrmse', help='Selection criteria')
    parser.add_argument('-train_data_name', type=str, default='lsvq_train', help='Name of the training data')
    # NOTE: was `type=bool`, under which `-is_finetune False` parsed as True
    # (any non-empty string is truthy); _str2bool parses the text properly.
    parser.add_argument('-is_finetune', type=_str2bool, default=False, help='With or without finetune')
    parser.add_argument('-save_path', type=str, default='model/', help='Path to save models')
    parser.add_argument('-video_type', type=str, default='konvid_1k', help='Type of video')
    parser.add_argument('-video_name', type=str, default='5636101558_540p', help='Name of the video')
    parser.add_argument('-framerate', type=float, default=24, help='Frame rate of the video')
    args = parser.parse_args(argv)
    return args
if __name__ == '__main__':
    args = parse_arguments()
    # The rest of the pipeline consumes options as a plain dict.
    config = vars(args)
    # Fall back to CPU transparently when CUDA is requested but unavailable.
    if config['device'] == "gpu":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device("cpu")
    print(f"Running on {'GPU' if device.type == 'cuda' else 'CPU'}")
    # load models to device: ImageNet-pretrained ResNet50 and a ViT-Base/16
    # feature generator, plus the trained MLP quality regressor.
    resnet50 = models.resnet50(pretrained=True).to(device)
    vit = VitGenerator('vit_base', 16, device, evaluate=True, random=False, verbose=True)
    model_mlp = load_model(config, device)
    # Timing loop; num_runs > 1 would average the wall-clock cost (the score
    # printed at the end is from the last run).
    total_time = 0
    num_runs = 1
    for i in range(num_runs):
        quality_prediction, run_time = evaluate_video_quality(config, resnet50, vit, model_mlp, device)
        print(f"Run {i + 1} - Time taken: {run_time:.4f} seconds")
        total_time += run_time
    average_time = total_time / num_runs
    print(f"Average running time over {num_runs} runs: {average_time:.4f} seconds")
    print("Predicted Quality Score:", quality_prediction)