JanviMl commited on
Commit
dd76a5e
·
verified ·
1 Parent(s): f07da63

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +8 -13
utils.py CHANGED
@@ -1,23 +1,18 @@
 
1
  import torch
2
- from torchvision import transforms
3
  from PIL import Image
4
  import numpy as np
5
 
6
- def load_image(image_path):
7
- transform = transforms.Compose([
8
- transforms.Resize((224, 224)),
9
- transforms.ToTensor(),
10
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
11
- ])
12
-
13
- image = Image.open(image_path).convert('RGB')
14
- return transform(image).unsqueeze(0)
15
 
16
- def predict_toxicity(model, image_tensor, device):
17
  model.eval()
18
  with torch.no_grad():
19
- image_tensor = image_tensor.to(device)
20
- outputs = model(image_tensor)
21
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
22
  prediction = torch.argmax(probabilities, dim=1)
23
  return prediction.item(), probabilities[0].cpu().numpy()
 
1
+ from transformers import ViTImageProcessor
2
  import torch
 
3
  from PIL import Image
4
  import numpy as np
5
 
6
+ def load_image_vit(image_file, processor):
7
+ image = Image.open(image_file).convert('RGB')
8
+ inputs = processor(images=image, return_tensors="pt")
9
+ return inputs["pixel_values"] # Shape: [1, 3, 224, 224]
 
 
 
 
 
10
 
11
+ def predict_toxicity_vit(model, inputs, device):
12
  model.eval()
13
  with torch.no_grad():
14
+ inputs = inputs.to(device)
15
+ outputs = model(inputs).logits
16
  probabilities = torch.nn.functional.softmax(outputs, dim=1)
17
  prediction = torch.argmax(probabilities, dim=1)
18
  return prediction.item(), probabilities[0].cpu().numpy()