refoundd commited on
Commit
cdaa34d
·
verified ·
1 Parent(s): 2861775

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +18 -16
handler.py CHANGED
@@ -140,14 +140,13 @@ class EndpointHandler:
140
  self.labels: LabelData = load_labels_hf(repo_id=repo_id)
141
 
142
  print("Creating data transform...")
143
- self.transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
144
 
145
- with torch.inference_mode():
146
- # move model to GPU, if available
147
- if torch_device.type != "cpu":
148
- self.model = self.model.to(torch_device)
149
 
150
- uri = os.environ.get("MongoDB", "mongodb+srv://jamie:qJiuKQpqhXMHGb74@cluster0.i5ujz.mongodb.net/")
151
  self.client = MongoClient(uri)
152
 
153
  self.db = self.client['nomorecopyright']
@@ -184,16 +183,18 @@ class EndpointHandler:
184
  inputs: Tensor = self.transform(img_input).unsqueeze(0)
185
  # NCHW image RGB to BGR
186
  inputs = inputs[:, [2, 1, 0]]
187
- inputs = inputs.to(torch_device)
188
- print("Running inference...")
189
- outputs = self.model.forward(inputs)
190
- # apply the final activation function (timm doesn't support doing this internally)
191
- outputs = F.sigmoid(outputs)
192
- # move inputs, outputs, and model back to to cpu if we were on GPU
193
- if torch_device.type != "cpu":
194
- inputs = inputs.to("cpu")
195
- outputs = outputs.to("cpu")
196
-
 
 
197
  print("Processing results...")
198
  caption, taglist, ratings, character, general = get_tags(
199
  probs=outputs.squeeze(0),
@@ -203,6 +204,7 @@ class EndpointHandler:
203
  )
204
 
205
  results={**ratings, **character, **general}
 
206
  print(results)
207
 
208
  saveQuery = {"_id": document.get('_id')}
 
140
  self.labels: LabelData = load_labels_hf(repo_id=repo_id)
141
 
142
  print("Creating data transform...")
143
+ self.transform = create_transform(**resolve_data_config(self.model.pretrained_cfg, model=self.model))
144
 
145
+ # move model to GPU, if available
146
+ if torch_device.type != "cpu":
147
+ self.model = self.model.to(torch_device)
 
148
 
149
+ uri = os.environ.get("MongoDB", "")
150
  self.client = MongoClient(uri)
151
 
152
  self.db = self.client['nomorecopyright']
 
183
  inputs: Tensor = self.transform(img_input).unsqueeze(0)
184
  # NCHW image RGB to BGR
185
  inputs = inputs[:, [2, 1, 0]]
186
+ with torch.inference_mode():
187
+ # move model to GPU, if available
188
+ if torch_device.type != "cpu":
189
+ inputs = inputs.to(torch_device)
190
+ print("Running inference...")
191
+ outputs = self.model.forward(inputs)
192
+ # apply the final activation function (timm doesn't support doing this internally)
193
+ outputs = F.sigmoid(outputs)
194
+ # move inputs, outputs, and model back to to cpu if we were on GPU
195
+ if torch_device.type != "cpu":
196
+ inputs = inputs.to("cpu")
197
+ outputs = outputs.to("cpu")
198
  print("Processing results...")
199
  caption, taglist, ratings, character, general = get_tags(
200
  probs=outputs.squeeze(0),
 
204
  )
205
 
206
  results={**ratings, **character, **general}
207
+ results={key: float(value) for key, value in results.items()}
208
  print(results)
209
 
210
  saveQuery = {"_id": document.get('_id')}