Update handler.py
Browse files- handler.py +18 -16
handler.py
CHANGED
|
@@ -140,14 +140,13 @@ class EndpointHandler:
|
|
| 140 |
self.labels: LabelData = load_labels_hf(repo_id=repo_id)
|
| 141 |
|
| 142 |
print("Creating data transform...")
|
| 143 |
-
self.transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
|
| 144 |
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
self.model = self.model.to(torch_device)
|
| 149 |
|
| 150 |
-
uri = os.environ.get("MongoDB", "
|
| 151 |
self.client = MongoClient(uri)
|
| 152 |
|
| 153 |
self.db = self.client['nomorecopyright']
|
|
@@ -184,16 +183,18 @@ class EndpointHandler:
|
|
| 184 |
inputs: Tensor = self.transform(img_input).unsqueeze(0)
|
| 185 |
# NCHW image RGB to BGR
|
| 186 |
inputs = inputs[:, [2, 1, 0]]
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
outputs
|
| 196 |
-
|
|
|
|
|
|
|
| 197 |
print("Processing results...")
|
| 198 |
caption, taglist, ratings, character, general = get_tags(
|
| 199 |
probs=outputs.squeeze(0),
|
|
@@ -203,6 +204,7 @@ class EndpointHandler:
|
|
| 203 |
)
|
| 204 |
|
| 205 |
results={**ratings, **character, **general}
|
|
|
|
| 206 |
print(results)
|
| 207 |
|
| 208 |
saveQuery = {"_id": document.get('_id')}
|
|
|
|
| 140 |
self.labels: LabelData = load_labels_hf(repo_id=repo_id)
|
| 141 |
|
| 142 |
print("Creating data transform...")
|
| 143 |
+
self.transform = create_transform(**resolve_data_config(self.model.pretrained_cfg, model=self.model))
|
| 144 |
|
| 145 |
+
# move model to GPU, if available
|
| 146 |
+
if torch_device.type != "cpu":
|
| 147 |
+
self.model = self.model.to(torch_device)
|
|
|
|
| 148 |
|
| 149 |
+
uri = os.environ.get("MongoDB", "")
|
| 150 |
self.client = MongoClient(uri)
|
| 151 |
|
| 152 |
self.db = self.client['nomorecopyright']
|
|
|
|
| 183 |
inputs: Tensor = self.transform(img_input).unsqueeze(0)
|
| 184 |
# NCHW image RGB to BGR
|
| 185 |
inputs = inputs[:, [2, 1, 0]]
|
| 186 |
+
with torch.inference_mode():
|
| 187 |
+
# move model to GPU, if available
|
| 188 |
+
if torch_device.type != "cpu":
|
| 189 |
+
inputs = inputs.to(torch_device)
|
| 190 |
+
print("Running inference...")
|
| 191 |
+
outputs = self.model.forward(inputs)
|
| 192 |
+
# apply the final activation function (timm doesn't support doing this internally)
|
| 193 |
+
outputs = F.sigmoid(outputs)
|
| 194 |
+
# move inputs, outputs, and model back to to cpu if we were on GPU
|
| 195 |
+
if torch_device.type != "cpu":
|
| 196 |
+
inputs = inputs.to("cpu")
|
| 197 |
+
outputs = outputs.to("cpu")
|
| 198 |
print("Processing results...")
|
| 199 |
caption, taglist, ratings, character, general = get_tags(
|
| 200 |
probs=outputs.squeeze(0),
|
|
|
|
| 204 |
)
|
| 205 |
|
| 206 |
results={**ratings, **character, **general}
|
| 207 |
+
results={key: float(value) for key, value in results.items()}
|
| 208 |
print(results)
|
| 209 |
|
| 210 |
saveQuery = {"_id": document.get('_id')}
|