OBA-Research committed on
Commit
68ec25f
·
verified ·
1 Parent(s): 219c40f

Upload folder using huggingface_hub

Browse files
model/config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architecture": "VAAS",
3
+ "version": "v1",
4
+ "alpha": 0.5,
5
+ "input_size": [
6
+ 224,
7
+ 224
8
+ ],
9
+ "px_checkpoint": "px_model.pth",
10
+ "fx_backbone": "google/vit-base-patch16-224",
11
+ "px_backbone": "nvidia/segformer-b1"
12
+ }
model/px_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3c8f0aea456a5175db54de8c8483ddd5b001e816fcac249d3968dcd7549603fb
3
+ size 54798133
model/ref_stats.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09287fa16965a465e7b71a19d43c9eca95f2a086af4428d47e963ff230da432e
3
+ size 1845
vaas/inference/pipeline.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import warnings
from typing import Dict, Union

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from huggingface_hub import hf_hub_download

from vaas.fx.fx_model import FxViT
from vaas.px.px_model import PatchConsistencySegformer
from vaas.fusion.hybrid_score import compute_scores
from vaas.inference.utils import load_ref_stats, load_px_checkpoint

# Silence noisy third-party warnings for clean inference output.
warnings.filterwarnings("ignore")

from transformers.utils import logging as hf_logging
hf_logging.set_verbosity_error()
21
+
22
class VAASPipeline:
    """End-to-end VAAS anomaly-scoring pipeline.

    Wraps the two model branches visible in this package — a
    ``PatchConsistencySegformer`` (Px) and an ``FxViT`` feature extractor
    (Fx) — and fuses their outputs via ``compute_scores`` into feature,
    patch, and hybrid scores plus a spatial anomaly map.
    """

    def __init__(
        self,
        model_px,
        model_fx,
        mu_ref,
        sigma_ref,
        device,
        transform,
        alpha=0.5,
    ):
        """Assemble a pipeline from already-constructed components.

        Args:
            model_px: Px-branch model (patch consistency).
            model_fx: Fx-branch model (ViT feature extractor).
            mu_ref: reference feature mean; tensor or array-like
                (converted to a tensor on ``device``).
            sigma_ref: reference feature std/covariance stats; tensor or
                array-like (converted to a tensor on ``device``).
            device: ``torch.device`` (or device string) for models and stats.
            transform: torchvision transform applied by ``compute_scores``.
            alpha: fusion weight between the Fx and Px scores.
        """
        self.device = device

        self.model_px = model_px.to(device)
        self.model_fx = model_fx.to(device)

        # Accept either tensors or plain array-likes for the stats.
        self.mu_ref = self._as_tensor(mu_ref, device)
        self.sigma_ref = self._as_tensor(sigma_ref, device)

        self.transform = transform
        self.alpha = alpha

        # Inference-only object: freeze both branches in eval mode.
        self.model_px.eval()
        self.model_fx.eval()

    @staticmethod
    def _as_tensor(value, device):
        """Return *value* on *device*, converting to a tensor if needed."""
        if torch.is_tensor(value):
            return value.to(device)
        return torch.tensor(value, device=device)

    @staticmethod
    def _default_transform():
        """224x224 ImageNet-normalized preprocessing shared by both loaders."""
        return T.Compose(
            [
                T.Resize((224, 224)),
                T.ToTensor(),
                T.Normalize(
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225),
                ),
            ]
        )

    @classmethod
    def from_checkpoint(
        cls,
        checkpoint_dir: str,
        device: Union[str, torch.device] = "cpu",
        alpha: float = 0.5,
    ):
        """Build a pipeline from a local training checkpoint directory.

        Expects ``best_model_px.pth`` and ``ref_stats.pth`` inside
        *checkpoint_dir* (see ``vaas.inference.utils``).

        Raises:
            FileNotFoundError: if either checkpoint file is missing.
        """
        if isinstance(device, str):
            device = torch.device(device)

        model_px = PatchConsistencySegformer()
        model_fx = FxViT()

        load_px_checkpoint(model_px, checkpoint_dir)
        mu_ref, sigma_ref = load_ref_stats(checkpoint_dir)

        # __init__ handles .to(device) and .eval() for both models.
        return cls(
            model_px=model_px,
            model_fx=model_fx,
            mu_ref=mu_ref,
            sigma_ref=sigma_ref,
            device=device,
            transform=cls._default_transform(),
            alpha=alpha,
        )

    @classmethod
    def from_pretrained(
        cls,
        repo_id: str,
        device: str = "cpu",
        alpha: float = 0.5,
    ):
        """Build a pipeline from a Hugging Face Hub repository.

        Downloads ``model/px_model.pth`` (Px weights, a raw state dict)
        and ``model/ref_stats.pth`` (reference statistics) from *repo_id*.

        Args:
            repo_id: Hub repository id, e.g. ``"org/model"``.
            device: device string or ``torch.device``.
            alpha: fusion weight between Fx and Px scores.
        """
        # Normalize early for consistency with from_checkpoint.
        if isinstance(device, str):
            device = torch.device(device)

        px_path = hf_hub_download(
            repo_id=repo_id,
            filename="model/px_model.pth",
        )
        ref_path = hf_hub_download(
            repo_id=repo_id,
            filename="model/ref_stats.pth",
        )

        model_px = PatchConsistencySegformer()
        # NOTE(review): torch.load on a downloaded pickle can execute
        # arbitrary code; prefer weights_only=True if the checkpoint
        # format permits it.
        state = torch.load(px_path, map_location="cpu")
        model_px.load_state_dict(state)

        ref = torch.load(ref_path, map_location="cpu")
        mu_ref = ref["mu_ref"]
        sigma_ref = ref["sigma_ref"]

        return cls(
            model_px=model_px,
            model_fx=FxViT(),
            mu_ref=mu_ref,
            sigma_ref=sigma_ref,
            device=device,
            transform=cls._default_transform(),
            alpha=alpha,
        )

    @torch.no_grad()
    def __call__(self, image: Union[str, Image.Image]) -> Dict[str, Union[float, "np.ndarray"]]:
        """Score a single image.

        Args:
            image: path to an image file, or an already-loaded PIL image.

        Returns:
            Dict with float scores ``"S_F"`` (feature), ``"S_P"`` (patch),
            ``"S_H"`` (hybrid), and ``"anomaly_map"`` as a numpy array.
        """
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")

        s_f, s_p, s_h, anomaly_map = compute_scores(
            img=image,
            mask=None,
            model_px=self.model_px,
            vit_model=self.model_fx,
            mu_ref=self.mu_ref,
            sigma_ref=self.sigma_ref,
            transform=self.transform,
            alpha=self.alpha,
        )

        if torch.is_tensor(anomaly_map):
            anomaly_map = anomaly_map.detach().cpu().numpy()

        return {
            "S_F": float(s_f),
            "S_P": float(s_p),
            "S_H": float(s_h),
            "anomaly_map": anomaly_map,
        }
173
+
vaas/inference/utils.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import json
4
+
5
+
6
def load_px_checkpoint(model, checkpoint_dir):
    """Load the Px-branch weights into *model* in place.

    Reads ``best_model_px.pth`` from *checkpoint_dir* and applies its
    ``"model_state_dict"`` entry to *model*.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    path = os.path.join(checkpoint_dir, "best_model_px.pth")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing checkpoint: {path}")

    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
13
+
14
+
15
def load_ref_stats(checkpoint_dir):
    """Return the reference statistics ``(mu_ref, sigma_ref)``.

    Reads ``ref_stats.pth`` from *checkpoint_dir* and unpacks its
    ``"mu_ref"`` and ``"sigma_ref"`` entries.

    Raises:
        FileNotFoundError: if the stats file does not exist.
    """
    path = os.path.join(checkpoint_dir, "ref_stats.pth")
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing reference stats: {path}")

    stats = torch.load(path, map_location="cpu")
    return stats["mu_ref"], stats["sigma_ref"]