File size: 744 Bytes
dd76a5e 8b2f9fc dd76a5e 8b2f9fc dd76a5e 8b2f9fc dd76a5e 8b2f9fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import ViTImageProcessor
import torch
from PIL import Image
import numpy as np
def load_image_vit(image_file, processor):
    """Load an image file and preprocess it for a ViT model.

    Args:
        image_file: Path or file-like object accepted by ``PIL.Image.open``.
        processor: Image processor (e.g. a ``ViTImageProcessor``) called as
            ``processor(images=..., return_tensors="pt")``.

    Returns:
        The ``"pixel_values"`` entry of the processor output. With a default
        ViT processor this is presumably shaped ``[1, 3, 224, 224]``; the
        exact spatial size depends on the processor's configuration.
    """
    # Use a context manager so the file handle Pillow opens is released.
    # convert() returns a new, fully loaded image, so it remains valid
    # after the source image is closed.
    with Image.open(image_file) as img:
        image = img.convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    return inputs["pixel_values"]
def predict_toxicity_vit(model, inputs, device):
    """Run the classifier on preprocessed pixels and return its top class.

    Side effect: switches ``model`` to evaluation mode via ``model.eval()``.
    The forward pass runs with autograd disabled.

    Args:
        model: Classifier whose call result exposes a ``.logits`` tensor;
            presumably shaped ``[batch, num_classes]``, since softmax and
            argmax are taken over dim 1.
        inputs: Pixel-value tensor, e.g. as returned by ``load_image_vit``.
        device: Target device passed to ``Tensor.to``.

    Returns:
        A ``(predicted_index, probabilities)`` tuple:
        - the argmax class as a Python int;
        - the softmax probabilities of the first batch row as a NumPy array.
        ``.item()`` needs a single-element prediction, so the batch is
        expected to contain exactly one image.
    """
    model.eval()
    with torch.no_grad():
        logits = model(inputs.to(device)).logits
        class_probs = logits.softmax(dim=1)
        best = class_probs.argmax(dim=1)
        return best.item(), class_probs[0].cpu().numpy()
def get_label(prediction):
    """Map a predicted class index to a human-readable label.

    Args:
        prediction: Class index, e.g. the first element returned by
            ``predict_toxicity_vit``.

    Returns:
        ``"Toxic"`` when ``prediction == 1``, otherwise ``"Non-Toxic"``
        (every value other than 1 is treated as non-toxic).
    """
    return "Toxic" if prediction == 1 else "Non-Toxic"