|
|
from transformers import ViTImageProcessor |
|
|
import torch |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
|
|
|
def load_image_vit(image_file, processor):
    """Load an image from disk and turn it into model-ready pixel values.

    Args:
        image_file: Anything ``PIL.Image.open`` accepts (a path or a
            binary file-like object).
        processor: Image processor callable invoked as
            ``processor(images=..., return_tensors="pt")``; expected to
            return a mapping with a ``"pixel_values"`` entry (e.g. a
            Hugging Face ``ViTImageProcessor``).

    Returns:
        The ``"pixel_values"`` entry of the processor output, presumably a
        PyTorch tensor of shape (1, C, H, W) -- confirm against the
        processor configuration.
    """
    # Open inside a context manager so the underlying file handle is
    # released. convert() loads the pixel data and returns a new image, so
    # the RGB copy stays valid after the source image is closed.
    with Image.open(image_file) as source:
        image = source.convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    return inputs["pixel_values"]
|
|
|
|
|
def predict_toxicity_vit(model, inputs, device):
    """Run the classifier on preprocessed pixel values and score each class.

    Side effect: the model is switched to evaluation mode via
    ``model.eval()``. Inference runs without gradient tracking.

    Args:
        model: Module whose call result exposes a ``.logits`` tensor,
            presumably shaped (batch, num_classes).
        inputs: Tensor of pixel values; it is moved to ``device`` before
            the forward pass.
        device: Target device for ``inputs`` (e.g. ``"cpu"``).

    Returns:
        A ``(class_index, probabilities)`` tuple. ``class_index`` is the
        Python int of the arg-max class. ``probabilities`` is a NumPy array
        of softmax scores for the first row of the batch. Assumes a batch of
        one: ``.item()`` raises if more than one prediction is produced.
    """
    model.eval()
    with torch.no_grad():
        batch_on_device = inputs.to(device)
        raw_scores = model(batch_on_device).logits
        class_probs = torch.nn.functional.softmax(raw_scores, dim=1)
        winner = class_probs.argmax(dim=1)
        chosen_index = winner.item()
        first_row_probs = class_probs[0].cpu().numpy()
        return chosen_index, first_row_probs
|
|
|
|
|
def get_label(prediction):
    """Translate a predicted class index into a display label.

    Class index 1 is reported as "Toxic"; any other value is "Non-Toxic".
    """
    if prediction == 1:
        return "Toxic"
    return "Non-Toxic"