import torch
from model import SimpleCNN
from PIL import Image
from torchvision import transforms

# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"

model = SimpleCNN()
# NOTE(review): torch.load unpickles the checkpoint file, which can execute
# arbitrary code; only load checkpoints from a trusted source. On torch >= 1.13
# passing weights_only=True restricts loading to tensors/primitive containers.
state = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state)
# Put the weights on the selected device. `device` was previously computed but
# never used, so inference always ran on the CPU even when CUDA was available.
model.to(device)
model.eval()

# Preprocessing: resize to 224x224, then convert the PIL image to a float
# tensor of shape (C, H, W) (values scaled to [0, 1] for common image modes).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Class names, indexed by the position of the model's output logit.
labels = ["no_crack", "crack"]


def predict(image: Image.Image) -> dict:
    """Classify a single PIL image with the loaded model.

    Args:
        image: Input image. NOTE(review): ToTensor produces one channel per
            image band, so the channel count follows ``image.mode`` (e.g. RGBA
            yields 4, L yields 1); confirm what SimpleCNN expects and convert
            (e.g. ``image.convert("RGB")``) in the caller if needed.

    Returns:
        A dict with ``"label"``, the entry of ``labels`` with the highest
        softmax probability, and ``"score"``, that probability as a Python float.
    """
    # Add a batch dimension and place the input on the same device as the model.
    img = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img)
        # Softmax over the class dimension; [0] selects the single batch item.
        probs = torch.softmax(logits, dim=1)[0]
    idx = probs.argmax().item()
    return {
        "label": labels[idx],
        "score": float(probs[idx]),
    }