crack-api / inference_utils.py
Thompson001's picture
Update inference_utils.py
873a70d verified
raw
history blame
1.27 kB
import cv2
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from models.deepcrack_model import DeepCrackModel
from cv2_utils import getContours
def create_model(opt, cp_path='pretrained_net_G.pth'):
    """Build a DeepCrackModel and load generator weights from a checkpoint.

    Args:
        opt: Options object forwarded unchanged to ``DeepCrackModel``.
        cp_path: Path to the checkpoint holding the generator's state dict.
            Defaults to ``'pretrained_net_G.pth'``.

    Returns:
        The model, after loading the weights into ``model.netG`` and calling
        ``model.eval()``.

    Warns:
        RuntimeWarning: if the checkpoint keys do not fully match the
        generator. Loading is non-strict, so mismatched keys are skipped and
        the affected parameters keep whatever values they had before loading.
    """
    import warnings

    model = DeepCrackModel(opt)
    # map_location="cpu" lets the checkpoint load on machines without the
    # device it was saved from.
    # NOTE(review): torch.load unpickles the file; only point cp_path at
    # trusted checkpoints.
    checkpoint = torch.load(cp_path, map_location="cpu")
    # The generator may be wrapped (e.g. by DataParallel) and expose the real
    # network as `.module`; load into the unwrapped network so key names match.
    net = model.netG.module if hasattr(model.netG, "module") else model.netG
    result = net.load_state_dict(checkpoint, strict=False)
    # strict=False never raises on key mismatches, so report them rather than
    # silently running with partially loaded weights.
    missing = list(result.missing_keys)
    unexpected = list(result.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint {cp_path!r} did not fully match the generator: "
            f"{len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            RuntimeWarning,
            stacklevel=2,
        )
    model.eval()
    return model
def preprocess(img: Image.Image, size=(256, 256), mode="RGB"):
    """Convert a PIL image into a normalised single-image batch tensor.

    Args:
        img: Input image.
        size: ``(height, width)`` passed to ``transforms.Resize``. Defaults to
            ``(256, 256)``, the previously hard-coded value.
        mode: PIL mode to convert ``img`` to before tensor conversion, or
            ``None`` to use the image as-is. Defaults to ``"RGB"`` so that
            greyscale, palette and RGBA inputs produce the same channel count
            as RGB inputs. This is presumably the 3 channels the generator
            expects; confirm against the model's input channel setting.

    Returns:
        Tensor of shape ``(1, C, height, width)``. For 8-bit images the values
        are mapped from [0, 1] to [-1, 1] by ``Normalize((0.5,), (0.5,))``.
    """
    # ToTensor emits one channel per PIL band, so without this an RGBA upload
    # becomes a 4-channel tensor and a greyscale/palette one a 1-channel tensor.
    if mode is not None and img.mode != mode:
        img = img.convert(mode)
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # Single-value mean/std broadcasts over every channel: x -> (x - 0.5) / 0.5.
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Add the leading batch dimension expected by the model.
    return transform(img).unsqueeze(0)
def inference(model, img: Image.Image):
    """Run the model on one image and draw contours from its fused output.

    Args:
        model: Object exposing ``set_input``, ``test`` and
            ``get_current_visuals`` (e.g. as returned by ``create_model``).
        img: Input PIL image.

    Returns:
        ``(contour_img, confidence)``. ``contour_img`` is whatever
        ``getContours`` returns for this image. ``confidence`` is the maximum
        value of the model's ``"fused"`` visual, as a Python float.
    """
    tensor = preprocess(img)
    # A zero tensor stands in for the label and an empty string for A_paths.
    # NOTE(review): presumably set_input expects these keys but test-time
    # inference does not use them; confirm.
    model.set_input({"image": tensor, "label": torch.zeros_like(tensor), "A_paths": ""})
    model.test()
    visuals = model.get_current_visuals()
    # preprocess builds a batch of one; take its only item.
    fused = visuals["fused"].detach().cpu().numpy()[0]
    confidence = float(fused.max())
    # Casting floats outside [0, 255] to uint8 does not saturate (negatives
    # typically wrap, and the result is platform-dependent), so clamp first.
    # In-range values are unchanged.
    # NOTE(review): if "fused" is scaled to [-1, 1] rather than [0, 1], rescale
    # with (fused + 1) / 2 before this step; confirm against
    # get_current_visuals.
    fused_img = np.clip(fused * 255, 0, 255).astype("uint8")
    # PIL's img.size is (width, height), so this passes height, then width.
    contour_img = getContours(fused_img, np.array(img), img.size[1], img.size[0], "px", confidence)
    return contour_img, confidence