Spaces:
Runtime error
Runtime error
Commit
·
a776cb5
1
Parent(s):
de33a84
raise error if model id not valid
Browse files
app.py
CHANGED
|
@@ -11,6 +11,7 @@ import os
|
|
| 11 |
import backoff
|
| 12 |
from functools import lru_cache
|
| 13 |
from huggingface_hub import list_models, ModelFilter
|
|
|
|
| 14 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
| 15 |
|
| 16 |
|
|
@@ -64,17 +65,23 @@ def return_random_sample(k=27):
|
|
| 64 |
images = dataset[sample]["image"]
|
| 65 |
return [resize_image(image).convert("RGB") for image in images]
|
| 66 |
|
|
|
|
| 67 |
@lru_cache()
|
| 68 |
def get_valid_hub_image_classification_model_ids():
|
| 69 |
models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
|
| 70 |
return {model.id for model in models}
|
| 71 |
|
|
|
|
| 72 |
def predict_subset(model_id, token):
|
| 73 |
-
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
|
| 74 |
-
headers = {"Authorization": f"Bearer {token}"}
|
| 75 |
valid_model_ids = get_valid_hub_image_classification_model_ids()
|
| 76 |
if model_id not in valid_model_ids:
|
| 77 |
-
gr.Error(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
@backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
|
| 79 |
def _query(url):
|
| 80 |
r = requests.post(API_URL, headers=headers, data=url)
|
|
|
|
| 11 |
import backoff
|
| 12 |
from functools import lru_cache
|
| 13 |
from huggingface_hub import list_models, ModelFilter
|
| 14 |
+
|
| 15 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
| 16 |
|
| 17 |
|
|
|
|
| 65 |
images = dataset[sample]["image"]
|
| 66 |
return [resize_image(image).convert("RGB") for image in images]
|
| 67 |
|
| 68 |
+
|
| 69 |
@lru_cache()
|
| 70 |
def get_valid_hub_image_classification_model_ids():
|
| 71 |
models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
|
| 72 |
return {model.id for model in models}
|
| 73 |
|
| 74 |
+
|
| 75 |
def predict_subset(model_id, token):
|
|
|
|
|
|
|
| 76 |
valid_model_ids = get_valid_hub_image_classification_model_ids()
|
| 77 |
if model_id not in valid_model_ids:
|
| 78 |
+
raise gr.Error(
|
| 79 |
+
f"model_id {model_id} is not a valid image classification model id"
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
|
| 83 |
+
headers = {"Authorization": f"Bearer {token}"}
|
| 84 |
+
|
| 85 |
@backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
|
| 86 |
def _query(url):
|
| 87 |
r = requests.post(API_URL, headers=headers, data=url)
|