fffffgggg54 commited on
Commit
9755f16
·
verified ·
1 Parent(s): 1f44214

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -81,7 +81,6 @@ class Predictor:
81
  self.text_emb_model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
82
  self.text_emb_model = self.text_emb_model.eval()
83
 
84
- @torch.inference_mode()
85
  def embed_text(self, input_strings):
86
  with torch.no_grad():
87
  # Tokenize the input texts
@@ -118,7 +117,7 @@ class Predictor:
118
  image_features = self.cls_model[0].forward_features(image.unsqueeze(0))
119
  outputs = self.cls_model[0].head(image_features, q = query).sigmoid().float()
120
 
121
- general_tag_list = list(zip(self.tag_names, outputs[0].tolist()))
122
  general_tag_list.sort(key=lambda y: y[1], reverse=True)
123
  general_tag_preds_dict = {}
124
  for tag, prob in general_tag_list[:50]:
@@ -138,7 +137,7 @@ class Predictor:
138
  image,
139
  description,
140
  ):
141
- return self.predict(image, self.embed_text(description), ["embedding"])["embedding"]
142
 
143
 
144
  def main():
 
81
  self.text_emb_model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
82
  self.text_emb_model = self.text_emb_model.eval()
83
 
 
84
  def embed_text(self, input_strings):
85
  with torch.no_grad():
86
  # Tokenize the input texts
 
117
  image_features = self.cls_model[0].forward_features(image.unsqueeze(0))
118
  outputs = self.cls_model[0].head(image_features, q = query).sigmoid().float()
119
 
120
+ general_tag_list = list(zip(tag_names, outputs[0].tolist()))
121
  general_tag_list.sort(key=lambda y: y[1], reverse=True)
122
  general_tag_preds_dict = {}
123
  for tag, prob in general_tag_list[:50]:
 
137
  image,
138
  description,
139
  ):
140
+ return self.predict(image, self.embed_text([description]), ["embedding"])["embedding"]
141
 
142
 
143
  def main():