Update app.py
Browse files
app.py
CHANGED
|
@@ -114,8 +114,8 @@ class Predictor:
|
|
| 114 |
def predict(
|
| 115 |
self,
|
| 116 |
image,
|
| 117 |
-
query
|
| 118 |
-
tag_names
|
| 119 |
):
|
| 120 |
image = self.prepare_image(image)
|
| 121 |
|
|
@@ -130,12 +130,19 @@ class Predictor:
|
|
| 130 |
|
| 131 |
return general_tag_preds_dict
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 133 |
def predict_new_tag(
|
| 134 |
self,
|
| 135 |
image,
|
| 136 |
query,
|
| 137 |
):
|
| 138 |
-
return self.predict(image, query
|
| 139 |
|
| 140 |
|
| 141 |
def main():
|
|
@@ -189,7 +196,7 @@ def main():
|
|
| 189 |
)
|
| 190 |
|
| 191 |
submit.click(
|
| 192 |
-
predictor.
|
| 193 |
inputs=[
|
| 194 |
image,
|
| 195 |
],
|
|
|
|
| 114 |
def predict(
|
| 115 |
self,
|
| 116 |
image,
|
| 117 |
+
query,
|
| 118 |
+
tag_names,
|
| 119 |
):
|
| 120 |
image = self.prepare_image(image)
|
| 121 |
|
|
|
|
| 130 |
|
| 131 |
return general_tag_preds_dict
|
| 132 |
|
| 133 |
+
def predict_seen_tags(
|
| 134 |
+
self,
|
| 135 |
+
image,
|
| 136 |
+
):
|
| 137 |
+
return self.predict(image, self.class_embed, self.tag_names)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
def predict_new_tag(
|
| 141 |
self,
|
| 142 |
image,
|
| 143 |
query,
|
| 144 |
):
|
| 145 |
+
return self.predict(image, query, ["embedding"])["embedding"]
|
| 146 |
|
| 147 |
|
| 148 |
def main():
|
|
|
|
| 196 |
)
|
| 197 |
|
| 198 |
submit.click(
|
| 199 |
+
predictor.predict_seen_tags,
|
| 200 |
inputs=[
|
| 201 |
image,
|
| 202 |
],
|