"""Run the SuSy TorchScript classifier on a single image patch and print its scores."""

import pandas as pd
import torch
from PIL import Image
from torchvision import transforms

# Column labels for the model's output, in output-index order.
# NOTE(review): this order must match the class order used when SuSy.pt was
# exported; pandas raises ValueError if the output width is not 6.
CLASSES = [
    "authentic",
    "dalle-3-images",
    "diffusiondb",
    "midjourney-images",
    "midjourney_tti",
    "realisticSDXL",
]


def load_patch(path):
    """Load an image file as a batched float tensor scaled to [0, 1].

    Args:
        path: Path to the image file.

    Returns:
        A tensor of shape (1, 3, H, W) with values in [0, 1].
    """
    # Context manager closes the underlying file handle once pixels are read.
    with Image.open(path) as img:
        # Normalize mode so RGBA / grayscale PNGs still yield 3 channels;
        # for an image that is already RGB this is just a copy.
        # NOTE(review): assumes the model expects 3-channel RGB input — confirm.
        rgb = img.convert("RGB")
    # PILToTensor yields uint8 (C, H, W); add a batch axis and scale to [0, 1].
    return transforms.PILToTensor()(rgb).unsqueeze(0) / 255.0


def predict(model, patch):
    """Run ``model`` on ``patch`` without autograd and tabulate the output.

    Args:
        model: A loaded TorchScript module.
        patch: Input tensor as returned by ``load_patch``.

    Returns:
        A DataFrame with one row per batch item and one column per class.
        Values are the raw model outputs (whether these are probabilities or
        logits depends on the exported model — not determined here).
    """
    model.eval()
    with torch.no_grad():
        preds = model(patch)
    # .cpu() guarantees .numpy() works even if the output landed on a GPU.
    return pd.DataFrame(preds.cpu().numpy(), columns=CLASSES)


def main(model_path="SuSy.pt", patch_path="midjourney-images-example-patch0.png"):
    """Load the model and one patch, then print the per-class scores."""
    # map_location="cpu" lets a model serialized from GPU load on CPU-only hosts.
    model = torch.jit.load(model_path, map_location="cpu")
    patch = load_patch(patch_path)
    print(predict(model, patch))


if __name__ == "__main__":
    main()