File size: 955 Bytes
fa5803b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
# utils.py
import torch
from torchvision import transforms
from PIL import Image
# Side length (pixels) of the square input the model receives; the test
# transform resizes every image to IMG_SIZE x IMG_SIZE. Change it only if the
# model was trained at a different resolution.
IMG_SIZE = 224

# Human-readable labels; predict_image returns class_names[i] for the
# highest-scoring output index i, so this order presumably has to match the
# label order used in training — confirm before reordering.
# NOTE(review): 'platic' looks like a typo for 'plastic'. It is kept verbatim
# because it is the string handed back to callers (and may mirror a dataset
# folder name); confirm with consumers before renaming.
class_names = [
    'battery',
    'biological',
    'cardboard',
    'clothes',
    'glass',
    'metal',
    'paper',
    'platic',
    'shoes',
    'trash',
]
# Inference-time preprocessing. It is meant to mirror the transform applied to
# the test split during training.
# NOTE(review): there is no Normalize step here; confirm the training pipeline
# also omitted it, otherwise predictions will be skewed.
test_transform = transforms.Compose(
    [
        # Force a fixed square size regardless of the source aspect ratio.
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        # PIL image -> float tensor (C x H x W) scaled to [0, 1].
        transforms.ToTensor(),
    ]
)
def load_model(weights_path, device):
    """Load a pickled model object from disk and ready it for inference.

    Args:
        weights_path: Path to a file holding a whole pickled model object
            (e.g. written with ``torch.save(model, path)``). A bare
            ``state_dict`` will not work, because ``.to()`` and ``.eval()``
            are called on whatever is loaded.
        device: Device that loaded tensors are mapped to and that the model is
            moved onto (e.g. ``"cpu"`` or ``torch.device("cuda")``).

    Returns:
        The loaded model, placed on ``device`` and switched to eval mode.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, and unpickling can execute code. Only pass files from trusted
    # sources. It is required here because a full model object, not just
    # weights, is stored.
    net = torch.load(weights_path, map_location=device, weights_only=False)
    net.to(device)
    # Put layers such as dropout / batch-norm into inference behaviour.
    net.eval()
    return net
def predict_image(model, image, device):
    """Classify one PIL image and return the predicted label string.

    Args:
        model: Callable returning per-class scores for a batch; presumably
            shaped ``(1, len(class_names))`` for a single image — confirm.
        image: PIL image. It is converted to RGB before preprocessing, so
            grayscale/RGBA inputs are accepted.
        device: Device the input tensor is moved to (should match the model).

    Returns:
        ``class_names[i]``, where ``i`` is the index of the highest score
        along dimension 1 of the model output.
    """
    rgb = image.convert("RGB")
    # Preprocess, add a leading batch dimension of size 1, move to device.
    batch = test_transform(rgb).unsqueeze(0).to(device)
    with torch.no_grad():  # pure inference: no autograd graph needed
        scores = model(batch)
        # torch.max over dim 1 gives (values, indices); keep only the index.
        best = torch.max(scores, 1).indices
        label_idx = best.item()
    return class_names[label_idx]
|