Spaces: Runtime error
Yahia battach committed
Commit · 016de46
1 Parent(s): 7272ff8
edit app.py
app.py
CHANGED
@@ -129,6 +129,53 @@ def format_name(taxon, common):
     return f"{taxon} ({common})"
 
 
+# @torch.no_grad()
+# def open_domain_classification(img, rank: int, return_all=False):
+#     """
+#     Predicts from the entire tree of life.
+#     If targeting a higher rank than species, then this function predicts among all
+#     species, then sums up species-level probabilities for the given rank.
+#     """
+
+#     logger.info(f"Starting open domain classification for rank: {rank}")
+#     img = preprocess_img(img).to(device)
+#     img_features = model.encode_image(img.unsqueeze(0))
+#     img_features = F.normalize(img_features, dim=-1)
+
+#     logits = (model.logit_scale.exp() * img_features @ txt_emb).squeeze()
+#     probs = F.softmax(logits, dim=0)
+
+#     if rank + 1 == len(ranks):
+#         topk = probs.topk(k)
+#         prediction_dict = {
+#             format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
+#         }
+#         logger.info(f"Top K predictions: {prediction_dict}")
+#         top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
+#         logger.info(f"Top prediction name: {top_prediction_name}")
+#         sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
+#         if return_all:
+#             return prediction_dict, sample_img, taxon_url
+#         return prediction_dict
+
+#     output = collections.defaultdict(float)
+#     for i in torch.nonzero(probs > min_prob).squeeze():
+#         output[" ".join(txt_names[i][0][: rank + 1])] += probs[i]
+
+#     topk_names = heapq.nlargest(k, output, key=output.get)
+#     prediction_dict = {name: output[name] for name in topk_names}
+#     logger.info(f"Top K names for output: {topk_names}")
+#     logger.info(f"Prediction dictionary: {prediction_dict}")
+
+#     top_prediction_name = topk_names[0]
+#     logger.info(f"Top prediction name: {top_prediction_name}")
+#     sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
+#     logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
+
+#     if return_all:
+#         return prediction_dict, sample_img, taxon_url
+#     return prediction_dict
+
 @torch.no_grad()
 def open_domain_classification(img, rank: int, return_all=False):
     """
@@ -136,7 +183,6 @@ def open_domain_classification(img, rank: int, return_all=False):
     If targeting a higher rank than species, then this function predicts among all
     species, then sums up species-level probabilities for the given rank.
     """
-
     logger.info(f"Starting open domain classification for rank: {rank}")
     img = preprocess_img(img).to(device)
     img_features = model.encode_image(img.unsqueeze(0))
@@ -148,15 +194,13 @@ def open_domain_classification(img, rank: int, return_all=False):
     if rank + 1 == len(ranks):
         topk = probs.topk(k)
         prediction_dict = {
-            format_name(*txt_names[i]): prob for i, prob in zip(topk.indices, topk.values)
+            format_name(*txt_names[i]): prob.item() for i, prob in zip(topk.indices, topk.values)
         }
         logger.info(f"Top K predictions: {prediction_dict}")
-        top_prediction_name = format_name(*txt_names[topk.indices[0]]).split("(")[0]
-        logger.info(f"Top prediction name: {top_prediction_name}")
-        sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
+
         if return_all:
-            return prediction_dict, sample_img, taxon_url
-        return prediction_dict
+            return prediction_dict, None, None  # Return dummy None values for unused parts
+        return prediction_dict  # Only return the dictionary for the Label component
 
     output = collections.defaultdict(float)
     for i in torch.nonzero(probs > min_prob).squeeze():
@@ -165,18 +209,11 @@ def open_domain_classification(img, rank: int, return_all=False):
     topk_names = heapq.nlargest(k, output, key=output.get)
     prediction_dict = {name: output[name] for name in topk_names}
    logger.info(f"Top K names for output: {topk_names}")
-    logger.info(f"Prediction dictionary: {prediction_dict}")
-
-    top_prediction_name = topk_names[0]
-    logger.info(f"Top prediction name: {top_prediction_name}")
-    sample_img, taxon_url = get_sample(metadata_df, top_prediction_name, rank)
-    logger.info(f"Sample image and taxon URL: {sample_img}, {taxon_url}")
-
+
     if return_all:
-        return prediction_dict, sample_img, taxon_url
+        return prediction_dict, None, None
     return prediction_dict
 
-
 def change_output(choice):
     return gr.Label(num_top_classes=k, label=ranks[choice], show_label=True, value=None)
 
@@ -310,12 +347,19 @@ if __name__ == "__main__":
         fn=change_output, inputs=rank_dropdown, outputs=[open_domain_output]
     )
 
+    # open_domain_btn.click(
+    #     fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
+    #     inputs=[img_input, rank_dropdown],
+    #     outputs=[open_domain_output],
+    # )
+
     open_domain_btn.click(
-        fn=lambda img, rank: open_domain_classification(img, rank, return_all=True),
+        fn=lambda img, rank: open_domain_classification(img, rank, return_all=False),
         inputs=[img_input, rank_dropdown],
         outputs=[open_domain_output],
     )
 
+
     zero_shot_btn.click(
        fn=zero_shot_classification,
        inputs=[img_input_zs, classes_txt],
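The edit itself is small: the click handler now calls open_domain_classification with return_all=False, and the top-k comprehension converts each probability tensor to a plain float with .item(). As a reading aid only (not part of the commit; the names and numbers below are invented for illustration), here is a minimal runnable sketch of the Gradio contract this restores: an event handler must return exactly one value per component listed in outputs, and gr.Label renders a dict mapping label to float, so sending a (dict, sample_img, taxon_url) tuple into a single Label output is a likely source of the Space's runtime error.

import gradio as gr
import torch

def toy_classify(img, rank):
    # Stand-in for open_domain_classification: three fake class probabilities.
    probs = torch.softmax(torch.randn(3), dim=0)
    labels = ["Aves (birds)", "Insecta (insects)", "Mammalia (mammals)"]
    # .item() turns each 0-d tensor into a plain float, which gr.Label can render
    # (this mirrors the prob.item() change in the diff above).
    return {name: p.item() for name, p in zip(labels, probs)}

with gr.Blocks() as demo:
    img_input = gr.Image()
    rank_dropdown = gr.Dropdown(choices=["Kingdom", "Species"], value="Species")
    label_out = gr.Label(num_top_classes=3)
    classify_btn = gr.Button("Classify")
    # One component in `outputs` -> the handler must return exactly one value.
    # Returning a (dict, image, url) tuple here, as return_all=True did, would
    # not match this single output.
    classify_btn.click(fn=toy_classify, inputs=[img_input, rank_dropdown], outputs=[label_out])

if __name__ == "__main__":
    demo.launch()

An alternative to passing return_all=False would have been to register the extra output components (for example an image and a link) in outputs so that all three returned values have a destination; the commit instead keeps the old handler commented out for reference and returns only the prediction dictionary.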