Spaces: Running on CPU Upgrade
| import asyncio | |
| import io | |
| import json | |
| import os | |
| import httpx | |
| from huggingface_hub import HfApi, HfFileSystem, ModelCard, hf_hub_url | |
| from huggingface_hub.utils import build_hf_headers | |
| import src.constants as constants | |
class Client:
    """Wrapper around one ``httpx.AsyncClient`` that follows redirects.

    ``get`` is the public entry point: it returns the response on success and
    ``None`` on any ``httpx.HTTPError``. A read timeout is retried with
    exponential back-off via ``retry``.
    """

    def __init__(self):
        self.client = httpx.AsyncClient(follow_redirects=True)

    async def _get(self, url, headers=None, params=None):
        """Issue one GET request; raise ``httpx.HTTPStatusError`` on a 4xx/5xx status."""
        r = await self.client.get(url, headers=headers, params=params)
        r.raise_for_status()
        return r

    async def get(self, url, headers=None, params=None):
        """GET *url* and return the response, or ``None`` on any HTTP error.

        A ``httpx.ReadTimeout`` on the first attempt hands over to ``retry``.
        Errors raised while retrying are handled like first-attempt errors, so
        callers only ever have to check for ``None``.
        """
        try:
            return await self._get(url, headers=headers, params=params)
        except httpx.ReadTimeout:
            pass  # fall through to the retry loop below
        except httpx.HTTPError:
            return None
        # The retry must run outside the except handler above. An exception
        # raised inside a handler is not caught by its sibling ``except``
        # clauses, so e.g. a 404 or connection error during a retry would
        # otherwise escape get().
        try:
            return await self.retry(self._get, url, headers=headers, params=params)
        except httpx.HTTPError:
            return None

    async def retry(self, func, url, max_retries=4, max_wait_time=8, wait_time=1, **kwargs):
        """Re-attempt ``await func(url, **kwargs)`` after read timeouts.

        Sleeps *wait_time* seconds before every attempt and doubles it after
        each ``httpx.ReadTimeout``. Gives up after *max_retries* attempts, or
        once the doubled wait exceeds *max_wait_time*. In either case it
        prints a message and returns ``None``. Any other exception raised by
        *func* propagates to the caller.
        """
        for _ in range(max_retries):
            await asyncio.sleep(wait_time)
            try:
                return await func(url, **kwargs)
            except httpx.ReadTimeout:
                wait_time = wait_time * 2
                if wait_time > max_wait_time:
                    break
        # Reached whether the attempt budget or the wait budget ran out, so
        # giving up is never silent.
        print("HTTP Timeout: max retries exceeded with url:", url)
        return None
# Module-level singletons used by the helpers below: the Hub API client
# (restart_space), the async HTTP wrapper (the loaders and list_models) and
# the Hub filesystem (glob). They are constructed left to right.
api, client, fs = HfApi(), Client(), HfFileSystem()
def glob(path):
    """Expand the glob pattern *path* with the module-level ``HfFileSystem``.

    Returns whatever ``fs.glob`` produces, i.e. the matching paths.
    """
    return fs.glob(path)
async def load_json_file(path):
    """Download the file at Hub *path* (unauthenticated) and decode it as JSON.

    Returns the decoded object, or ``None`` when ``client.get`` reports a
    failed request.
    """
    hub_url = to_url(path)
    response = await client.get(hub_url)
    return None if response is None else response.json()
async def load_jsonlines_file(path):
    """Download a JSON Lines file from the Hub and parse it into records.

    The request carries the headers returned by
    ``huggingface_hub.utils.build_hf_headers()``.

    Returns:
        A list with one decoded object per non-blank line, or ``None`` when
        ``client.get`` reports a failed request.
    """
    url = to_url(path)
    r = await client.get(url, headers=build_hf_headers())
    if r is None:
        return None
    # Lines are split via StringIO, i.e. on "\n" only. str.splitlines() would
    # also break on U+2028 and similar characters, which may legally appear
    # unescaped inside JSON strings. Whitespace-only lines (e.g. a file ending
    # in a blank line) are skipped because json.loads("\n") raises.
    return [json.loads(line) for line in io.StringIO(r.text) if line.strip()]
def to_url(path):
    """Convert an HfFileSystem-style *path* into a Hub file download URL.

    Accepted forms (*filename* may itself contain ``/``):
        ``{org}/{repo}/{filename}``             -> model repo
        ``models/{org}/{repo}/{filename}``      -> model repo
        ``datasets/{org}/{repo}/{filename}``    -> dataset repo
        ``spaces/{org}/{repo}/{filename}``      -> space repo

    Raises:
        ValueError: if fewer than three components remain after the optional
            repo-type prefix.
    """
    # Only these prefixes mark a repo type. Any other first component is the
    # org of a model repo. A fixed split("/", 3) would mistake that org for a
    # prefix whenever a model file lives in a sub-directory, e.g.
    # "org/model/sub/f.json".
    prefix, _, remainder = path.partition("/")
    if prefix in ("datasets", "spaces", "models"):
        repo_type = prefix[:-1]  # "datasets" -> "dataset", etc.
        path = remainder
    else:
        repo_type = None  # unprefixed paths are model repos
    org_name, repo_name, filename = path.split("/", 2)
    return hf_hub_url(repo_id=f"{org_name}/{repo_name}", filename=filename, repo_type=repo_type)
async def load_model_card(model_id):
    """Fetch ``README.md`` of model repo *model_id* and wrap it in a ModelCard.

    Metadata parsing problems are tolerated (``ignore_metadata_errors=True``).
    Returns ``None`` when ``client.get`` reports a failed request.
    """
    readme_url = to_url(f"{model_id}/README.md")
    readme = await client.get(readme_url)
    if readme is not None:
        return ModelCard(readme.text, ignore_metadata_errors=True)
    return None
async def list_models(filtering=None):
    """Query the Hub ``/models`` listing endpoint.

    Args:
        filtering: Sent as the ``filter`` query parameter when truthy;
            omitted otherwise.

    Returns:
        The decoded JSON body, or ``None`` when ``client.get`` reports a
        failed request.
    """
    query = {"filter": filtering} if filtering else {}
    response = await client.get(f"{constants.HF_API_URL}/models", params=query)
    if response is not None:
        return response.json()
    return None
def restart_space():
    """Restart the Space named by the ``SPACE_ID`` environment variable.

    Does nothing when ``SPACE_ID`` is unset or empty. Otherwise calls
    ``api.restart_space`` with the token from ``HF_TOKEN``.
    """
    space_id = os.getenv("SPACE_ID")
    if not space_id:
        return
    api.restart_space(repo_id=space_id, token=os.getenv("HF_TOKEN"))