Update handler.py
Browse files- handler.py +0 -6
handler.py
CHANGED
|
@@ -131,15 +131,12 @@ class EndpointHandler:
|
|
| 131 |
self.opts = ScriptOptions
|
| 132 |
repo_id = MODEL_REPO_MAP.get(self.opts.model)
|
| 133 |
|
| 134 |
-
print(f"Loading model '{self.opts.model}' from '{repo_id}'...")
|
| 135 |
self.model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
|
| 136 |
state_dict = timm.models.load_state_dict_from_hf(repo_id)
|
| 137 |
self.model.load_state_dict(state_dict)
|
| 138 |
|
| 139 |
-
print("Loading tag list...")
|
| 140 |
self.labels: LabelData = load_labels_hf(repo_id=repo_id)
|
| 141 |
|
| 142 |
-
print("Creating data transform...")
|
| 143 |
self.transform = create_transform(**resolve_data_config(self.model.pretrained_cfg, model=self.model))
|
| 144 |
|
| 145 |
# move model to GPU, if available
|
|
@@ -173,7 +170,6 @@ class EndpointHandler:
|
|
| 173 |
start_time=time.time()
|
| 174 |
for document in data:
|
| 175 |
image=load_image(document.get('createdImage', 'https://nomorecopyright.com/default.jpg'))
|
| 176 |
-
print("Loading image and preprocessing...")
|
| 177 |
# get image
|
| 178 |
# ensure image is RGB
|
| 179 |
img_input = pil_ensure_rgb(image)
|
|
@@ -187,7 +183,6 @@ class EndpointHandler:
|
|
| 187 |
# move model to GPU, if available
|
| 188 |
if torch_device.type != "cpu":
|
| 189 |
inputs = inputs.to(torch_device)
|
| 190 |
-
print("Running inference...")
|
| 191 |
outputs = self.model.forward(inputs)
|
| 192 |
# apply the final activation function (timm doesn't support doing this internally)
|
| 193 |
outputs = F.sigmoid(outputs)
|
|
@@ -195,7 +190,6 @@ class EndpointHandler:
|
|
| 195 |
if torch_device.type != "cpu":
|
| 196 |
inputs = inputs.to("cpu")
|
| 197 |
outputs = outputs.to("cpu")
|
| 198 |
-
print("Processing results...")
|
| 199 |
caption, taglist, ratings, character, general = get_tags(
|
| 200 |
probs=outputs.squeeze(0),
|
| 201 |
labels=self.labels,
|
|
|
|
| 131 |
self.opts = ScriptOptions
|
| 132 |
repo_id = MODEL_REPO_MAP.get(self.opts.model)
|
| 133 |
|
|
|
|
| 134 |
self.model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
|
| 135 |
state_dict = timm.models.load_state_dict_from_hf(repo_id)
|
| 136 |
self.model.load_state_dict(state_dict)
|
| 137 |
|
|
|
|
| 138 |
self.labels: LabelData = load_labels_hf(repo_id=repo_id)
|
| 139 |
|
|
|
|
| 140 |
self.transform = create_transform(**resolve_data_config(self.model.pretrained_cfg, model=self.model))
|
| 141 |
|
| 142 |
# move model to GPU, if available
|
|
|
|
| 170 |
start_time=time.time()
|
| 171 |
for document in data:
|
| 172 |
image=load_image(document.get('createdImage', 'https://nomorecopyright.com/default.jpg'))
|
|
|
|
| 173 |
# get image
|
| 174 |
# ensure image is RGB
|
| 175 |
img_input = pil_ensure_rgb(image)
|
|
|
|
| 183 |
# move model to GPU, if available
|
| 184 |
if torch_device.type != "cpu":
|
| 185 |
inputs = inputs.to(torch_device)
|
|
|
|
| 186 |
outputs = self.model.forward(inputs)
|
| 187 |
# apply the final activation function (timm doesn't support doing this internally)
|
| 188 |
outputs = F.sigmoid(outputs)
|
|
|
|
| 190 |
if torch_device.type != "cpu":
|
| 191 |
inputs = inputs.to("cpu")
|
| 192 |
outputs = outputs.to("cpu")
|
|
|
|
| 193 |
caption, taglist, ratings, character, general = get_tags(
|
| 194 |
probs=outputs.squeeze(0),
|
| 195 |
labels=self.labels,
|