File size: 744 Bytes
dd76a5e
8b2f9fc
 
 
 
dd76a5e
 
 
 
8b2f9fc
dd76a5e
8b2f9fc
 
dd76a5e
 
8b2f9fc
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from transformers import ViTImageProcessor
import torch
from PIL import Image
import numpy as np

def load_image_vit(image_file, processor):
    """Load an image and convert it to model-ready pixel values.

    The image is opened with Pillow, converted to RGB and passed through
    ``processor`` with ``return_tensors="pt"``.

    Args:
        image_file: Path or binary file object accepted by ``PIL.Image.open``.
        processor: Image processor (e.g. ``ViTImageProcessor``) that is called
            as ``processor(images=..., return_tensors="pt")``.

    Returns:
        The ``"pixel_values"`` entry produced by the processor. Presumably
        shaped ``[1, 3, H, W]`` (``[1, 3, 224, 224]`` with the default ViT
        processor config) -- confirm against the processor actually in use.
    """
    # Pillow opens files lazily and keeps the handle open until the image is
    # closed; the context manager releases a file Pillow opened from a path
    # as soon as we are done. convert() loads the pixel data and returns a new,
    # independent image, so it stays valid after the source is closed.
    with Image.open(image_file) as img:
        image = img.convert('RGB')
    inputs = processor(images=image, return_tensors="pt")
    return inputs["pixel_values"]

def predict_toxicity_vit(model, inputs, device):
    """Run one forward pass and return the top class with its probabilities.

    Switches ``model`` into evaluation mode (this mutates the model) and
    evaluates it without gradient tracking.

    Args:
        model: Module with ``eval()`` whose call result exposes ``.logits``;
            presumably a Hugging Face ViT classifier -- confirm with caller.
        inputs: Tensor fed positionally to ``model`` after moving to ``device``.
        device: Device (or device string) the inputs are moved to.

    Returns:
        ``(class_index, probs)``. ``class_index`` is the Python scalar from
        the argmax over dim 1. ``probs`` is a NumPy array holding the softmax
        probabilities of the first sample in the batch. ``.item()`` only
        succeeds when the batch holds exactly one sample.
    """
    model.eval()
    with torch.no_grad():
        logits = model(inputs.to(device)).logits
        probs = logits.softmax(dim=1)
        top_class = probs.argmax(dim=1)
    return top_class.item(), probs[0].cpu().numpy()

def get_label(prediction):
    """Map a predicted class index to a human-readable label.

    Class ``1`` (anything comparing equal to 1) is "Toxic"; every other
    value is "Non-Toxic".
    """
    if prediction == 1:
        return "Toxic"
    return "Non-Toxic"