# crack-classifier-500 / inference.py
# Uploaded by Thompson001 with huggingface_hub (commit 1bc9e9e).
import torch
from model import SimpleCNN
from PIL import Image
from torchvision import transforms
# Load model.
# `device` selects CUDA when available; the model is moved there below and
# predict() moves each input tensor to the same device.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN()
# SECURITY NOTE(review): torch.load unpickles "pytorch_model.bin", which can
# execute arbitrary code if the file is untrusted. On torch >= 1.13 consider
# torch.load(..., weights_only=True) (the default from torch 2.6 onward).
# Weights are loaded onto the CPU first, then the whole model is moved.
state = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state)
model.to(device)
model.eval()

# Preprocessing: resize to a fixed 224x224 and convert to a float tensor.
# NOTE(review): no Normalize step is applied here; this must match the
# preprocessing used at training time — confirm against the training code.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Class names, indexed by the position of the model's highest-scoring output.
labels = ["no_crack", "crack"]


def predict(image: Image.Image) -> dict:
    """Classify a single PIL image.

    Args:
        image: Input image. Non-RGB images (grayscale, palette, RGBA) are
            converted to RGB before preprocessing.

    Returns:
        A dict with ``"label"``, the entry of ``labels`` with the highest
        softmax probability, and ``"score"``, that probability as a float.
    """
    # Assumes SimpleCNN expects 3-channel RGB input — TODO confirm against
    # model.py. Without this, RGBA/grayscale inputs yield 4/1-channel tensors.
    if image.mode != "RGB":
        image = image.convert("RGB")
    # Add a leading batch dimension of size 1 and place it on the model's device.
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img)
    # Presumably logits are (batch, num_classes): softmax over the class axis,
    # then keep the single row for our batch of one.
    probs = torch.softmax(logits, dim=1)[0]
    idx = probs.argmax().item()
    return {
        "label": labels[idx],
        "score": float(probs[idx]),
    }