"""Helpers for running a ViT image classifier as a toxic / non-toxic detector."""

import numpy as np
import torch
from PIL import Image
from transformers import ViTImageProcessor


def load_image_vit(image_file, processor):
    """Load an image and turn it into model-ready pixel values.

    Args:
        image_file: Anything accepted by ``PIL.Image.open`` (a path or a
            binary file object).
        processor: Callable image processor invoked as
            ``processor(images=..., return_tensors="pt")`` whose result
            supports ``["pixel_values"]`` (presumably a
            ``ViTImageProcessor`` -- confirm against the caller).

    Returns:
        The ``"pixel_values"`` entry of the processor output. Typically a
        tensor of shape ``[1, 3, 224, 224]``; the exact size depends on the
        processor configuration.
    """
    # Image.open is lazy and keeps the underlying file open; convert()
    # forces the pixel data to load and returns an independent copy, so the
    # source image can be closed as soon as the conversion is done.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert('RGB')

    inputs = processor(images=image, return_tensors="pt")
    return inputs["pixel_values"]


def predict_toxicity_vit(model, inputs, device):
    """Classify a single preprocessed image.

    Args:
        model: A torch module whose call result exposes a ``.logits``
            attribute. It is switched to eval mode as a side effect.
        inputs: Input tensor for the model, e.g. the result of
            ``load_image_vit``. Must describe exactly one image, because
            the predicted class is extracted with ``.item()``.
        device: Device the inputs are moved to before the forward pass.
            The model itself is not moved here.

    Returns:
        A ``(prediction, probabilities)`` tuple: the index of the most
        probable class as a Python int, and the class probabilities of the
        first (only) batch entry as a NumPy array.
    """
    model.eval()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        inputs = inputs.to(device)
        outputs = model(inputs).logits
        # Normalise over the class dimension (dim 1 of [batch, classes]).
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        prediction = torch.argmax(probabilities, dim=1)
    return prediction.item(), probabilities[0].cpu().numpy()


def get_label(prediction):
    """Map a class index to its label: 1 is "Toxic", anything else "Non-Toxic"."""
    return "Toxic" if prediction == 1 else "Non-Toxic"