Create pytorch
Browse filesimport torch
from torchvision import transforms
from PIL import Image
# Load trained model
model = torch.load("model.pth", map_location="cpu")
model.eval()
# Image preprocessing
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
# Load image
image = Image.open("test_image.jpg").convert("RGB")
input_tensor = transform(image).unsqueeze(0)
# Predict
with torch.no_grad():
outputs = model(input_tensor)
predicted_class = torch.argmax(outputs, dim=1)
print("Predicted class index:", predicted_class.item())