Spaces:
Runtime error
Runtime error
Commit
·
de33a84
1
Parent(s):
d5321e7
check valid hub id
Browse files
app.py
CHANGED
|
@@ -10,8 +10,7 @@ import pandas as pd
|
|
| 10 |
import os
|
| 11 |
import backoff
|
| 12 |
from functools import lru_cache
|
| 13 |
-
|
| 14 |
-
import os
|
| 15 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
| 16 |
|
| 17 |
|
|
@@ -65,15 +64,20 @@ def return_random_sample(k=27):
|
|
| 65 |
images = dataset[sample]["image"]
|
| 66 |
return [resize_image(image).convert("RGB") for image in images]
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
def predict_subset(model_id, token):
|
| 70 |
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
|
| 71 |
headers = {"Authorization": f"Bearer {token}"}
|
| 72 |
-
|
|
|
|
|
|
|
| 73 |
@backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
|
| 74 |
def _query(url):
|
| 75 |
r = requests.post(API_URL, headers=headers, data=url)
|
| 76 |
-
print(r)
|
| 77 |
return r
|
| 78 |
|
| 79 |
@lru_cache(maxsize=1000)
|
|
|
|
| 10 |
import os
|
| 11 |
import backoff
|
| 12 |
from functools import lru_cache
|
| 13 |
+
from huggingface_hub import list_models, ModelFilter
|
|
|
|
| 14 |
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
|
| 15 |
|
| 16 |
|
|
|
|
| 64 |
images = dataset[sample]["image"]
|
| 65 |
return [resize_image(image).convert("RGB") for image in images]
|
| 66 |
|
| 67 |
+
@lru_cache()
|
| 68 |
+
def get_valid_hub_image_classification_model_ids():
|
| 69 |
+
models = list_models(limit=None, filter=ModelFilter(task="image-classification"))
|
| 70 |
+
return {model.id for model in models}
|
| 71 |
|
| 72 |
def predict_subset(model_id, token):
|
| 73 |
API_URL = f"https://api-inference.huggingface.co/models/{model_id}"
|
| 74 |
headers = {"Authorization": f"Bearer {token}"}
|
| 75 |
+
valid_model_ids = get_valid_hub_image_classification_model_ids()
|
| 76 |
+
if model_id not in valid_model_ids:
|
| 77 |
+
gr.Error(f"model_id {model_id} is not a valid image classification model id")
|
| 78 |
@backoff.on_predicate(backoff.expo, lambda x: x.status_code == 503, max_time=30)
|
| 79 |
def _query(url):
|
| 80 |
r = requests.post(API_URL, headers=headers, data=url)
|
|
|
|
| 81 |
return r
|
| 82 |
|
| 83 |
@lru_cache(maxsize=1000)
|