Thompson001 commited on
Commit
9f7f6b4
·
verified ·
1 Parent(s): 054e48e

Delete inference_utils.py

Browse files
Files changed (1) hide show
  1. inference_utils.py +0 -44
inference_utils.py DELETED
@@ -1,44 +0,0 @@
1
- import cv2
2
- import torch
3
- import numpy as np
4
- from PIL import Image
5
- from torchvision import transforms
6
- from models.deepcrack_model import DeepCrackModel
7
- from cv2_utils import getContours
8
-
9
- def create_model(opt, cp_path='pretrained_net_G.pth'):
10
- model = DeepCrackModel(opt)
11
- checkpoint = torch.load(cp_path, map_location="cpu")
12
-
13
- if hasattr(model.netG, "module"):
14
- model.netG.module.load_state_dict(checkpoint, strict=False)
15
- else:
16
- model.netG.load_state_dict(checkpoint, strict=False)
17
-
18
- model.eval()
19
- return model
20
-
21
-
22
- def preprocess(img: Image.Image):
23
- transform = transforms.Compose([
24
- transforms.Resize((256, 256)),
25
- transforms.ToTensor(),
26
- transforms.Normalize((0.5,), (0.5,))
27
- ])
28
- return transform(img).unsqueeze(0)
29
-
30
-
31
- def inference(model, img: Image.Image):
32
- tensor = preprocess(img)
33
-
34
- model.set_input({"image": tensor, "label": torch.zeros_like(tensor), "A_paths": ""})
35
- model.test()
36
-
37
- visuals = model.get_current_visuals()
38
- fused = visuals["fused"].detach().cpu().numpy()[0]
39
- confidence = float(fused.max())
40
-
41
- fused_img = (fused * 255).astype("uint8")
42
- contour_img = getContours(fused_img, np.array(img), img.size[1], img.size[0], "px", confidence)
43
-
44
- return contour_img, confidence