| import torch |
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
| from torchvision import transforms |
| from transformers import ViTForImageClassification |
| import os |
| import pandas as pd |
| from sklearn.model_selection import train_test_split |
| from sklearn.metrics import accuracy_score, precision_score, confusion_matrix, f1_score, average_precision_score |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from sklearn.metrics import recall_score |
| from vit_model_traning import labeling, CustomDataset |
|
|
| |
def display_video(video_url):
    """Return an HTML snippet that embeds and autoplays an MP4 video.

    The snippet is a ``<div id="video-container">`` that starts hidden
    (``display: none``) and wraps a 640x480 ``<video controls autoplay>``
    element. An inline ``<script>`` then switches the container to
    ``display: block``.

    Args:
        video_url: Value placed in the ``src`` attribute of the ``<source>``
            tag. It is converted with ``str()`` and HTML-escaped, so characters
            such as ``"``, ``&``, ``<`` and ``>`` cannot break out of the
            attribute.

    Returns:
        str: The HTML markup.
    """
    import html  # function-local so this helper needs no file-level import

    # Escape before interpolating. An unescaped double quote in the URL would
    # terminate the src attribute early and corrupt (or inject into) the markup.
    safe_url = html.escape(str(video_url), quote=True)
    return f'''
<div id="video-container" style="display: none;">
<video width="640" height="480" controls autoplay>
<source src="{safe_url}" type="video/mp4">
Your browser does not support the video tag.
</video>
</div>
<script>
document.getElementById('video-container').style.display = 'block';
</script>
'''
|
|
def shuffle_and_split_data(dataframe, test_size=0.2, random_state=59):
    """Shuffle ``dataframe`` and split it into training and validation parts.

    The rows are first permuted with ``DataFrame.sample(frac=1)`` and the
    index is reset to ``0..n-1``. The result is then passed to scikit-learn's
    ``train_test_split`` with the same seed.

    Args:
        dataframe: The pandas DataFrame to split.
        test_size: Forwarded unchanged as ``train_test_split``'s ``test_size``.
        random_state: Seed used for both the permutation and the split.

    Returns:
        tuple: ``(train_df, val_df)``.
    """
    permuted = dataframe.sample(frac=1, random_state=random_state)
    permuted = permuted.reset_index(drop=True)

    train_part, val_part = train_test_split(
        permuted,
        test_size=test_size,
        random_state=random_state,
    )
    return train_part, val_part
|
|
def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` and gather labels, predictions and scores.

    Args:
        model: Image-classification model whose output exposes ``.logits``.
            It is switched to eval mode here.
        loader: Iterable yielding ``(images, labels)`` tensor batches.
        device: ``torch.device`` the image batches are moved to.

    Returns:
        tuple[list, list, list]: The true labels, the arg-max predicted labels,
        and the softmax probability of class index 1 for each sample.
    """
    model.eval()
    true_labels, predicted_labels, positive_scores = [], [], []
    with torch.no_grad():
        for images, labels in loader:
            logits = model(images.to(device)).logits
            # Class-1 probability is the continuous score that
            # average_precision_score ranks on. Class 1 matches sklearn's
            # default pos_label; which of real/fake it means depends on
            # labeling() (confirm).
            probabilities = torch.softmax(logits, dim=1)
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(logits.argmax(dim=1).cpu().numpy())
            positive_scores.extend(probabilities[:, 1].cpu().numpy())
    return true_labels, predicted_labels, positive_scores


def _report_metrics(true_labels, predicted_labels, positive_scores):
    """Print binary-classification metrics and show a confusion-matrix heatmap."""
    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels)
    # AP summarises the precision/recall curve, so it needs ranking scores.
    # Hard 0/1 predictions collapse that curve to a single operating point.
    ap = average_precision_score(true_labels, positive_scores)
    recall = recall_score(true_labels, predicted_labels)
    cm = confusion_matrix(true_labels, predicted_labels)

    print(f"Test Accuracy: {accuracy:.2%}")
    print(f"Precision: {precision:.2%}")
    print(f"F1 Score: {f1:.2%}")
    print(f"Average Precision: {ap:.2%}")
    print(f"Recall: {recall:.2%}")

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.xlabel('Predicted Labels')
    plt.ylabel('True Labels')
    plt.title('Confusion Matrix')
    plt.show()


if __name__ == "__main__":
    # Fall back to CPU so the script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Pretrained ViT backbone with its classification head replaced by a
    # 2-output linear layer. The fine-tuned weights are loaded further down.
    model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224').to(device)
    model.classifier = nn.Linear(model.config.hidden_size, 2).to(device)

    # NOTE(review): no Normalize step here. This must match the preprocessing
    # used when trained_model.pth was produced (see vit_model_traning); confirm.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    test_real_folder = 'datasets/test_set/real/'
    test_fake_folder = 'datasets/test_set/fake/'
    test_set = labeling(test_real_folder, test_fake_folder)
    test_dataset = CustomDataset(test_set, transform=preprocess)
    test_loader = DataLoader(test_dataset, batch_size=32)

    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    model.load_state_dict(torch.load('trained_model.pth', map_location=device))

    # Use a raw string. In a normal literal, "\U" starts a \UXXXXXXXX escape
    # (a SyntaxError here) and "\0" becomes a NUL character. The value must
    # also not contain its own quotes, because it goes into an HTML attribute.
    # NOTE(review): browsers usually need a file:/// URL rather than a bare
    # Windows path; confirm.
    video_url = r'C:\Users\litav\Downloads\0001-0120.mp4'
    video_html = display_video(video_url)
    # Print once. The original also printed inside the batch loop, which
    # repeated the same HTML for every batch.
    print(video_html)

    true_labels, predicted_labels, positive_scores = _collect_predictions(
        model, test_loader, device
    )
    _report_metrics(true_labels, predicted_labels, positive_scores)
|
|