JanviMl commited on
Commit
8b2f9fc
·
verified ·
1 Parent(s): 5787f5d

Create utils.py

Browse files
Files changed (1) hide show
  1. utils.py +27 -0
utils.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torchvision import transforms
3
+ from PIL import Image
4
+ import numpy as np
5
+
6
+ def load_image(image_path):
7
+ transform = transforms.Compose([
8
+ transforms.Resize((224, 224)),
9
+ transforms.ToTensor(),
10
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
11
+ std=[0.229, 0.224, 0.225])
12
+ ])
13
+
14
+ image = Image.open(image_path).convert('RGB')
15
+ return transform(image).unsqueeze(0)
16
+
17
+ def predict_toxicity(model, image_tensor, device):
18
+ model.eval()
19
+ with torch.no_grad():
20
+ image_tensor = image_tensor.to(device)
21
+ outputs = model(image_tensor)
22
+ probabilities = torch.nn.functional.softmax(outputs, dim=1)
23
+ prediction = torch.argmax(probabilities, dim=1)
24
+ return prediction.item(), probabilities[0].cpu().numpy()
25
+
26
+ def get_label(prediction):
27
+ return "Toxic" if prediction == 1 else "Non-Toxic"