Spaces:
Sleeping
Sleeping
| import logging | |
| import shutil | |
| import tempfile | |
| import zipfile | |
| try: | |
| from importlib.resources import files | |
| except: | |
| from importlib_resources import files | |
| from pathlib import Path | |
| import requests | |
| from tqdm import tqdm | |
| from typing import Optional | |
| logger = logging.getLogger(__name__) | |
| _MODELS = { | |
| "ctc": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip", | |
| "general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip", | |
| } | |
| def download_and_unzip(url: str, dst: Path): | |
| # TODO make safe and use tempfile lib | |
| if dst.exists(): | |
| print(f"{dst} already downloaded, skipping.") | |
| return | |
| # get the name of the zipfile | |
| zip_base = Path(url.split("/")[-1]) | |
| with tempfile.TemporaryDirectory() as tmp: | |
| tmp = Path(tmp) | |
| zip_file = tmp / zip_base | |
| # Download the zip file | |
| download(url, zip_file) | |
| # Unzip the file | |
| with zipfile.ZipFile(zip_file, "r") as zip_ref: | |
| zip_ref.extractall(tmp) | |
| shutil.move(tmp / zip_base.stem, dst) | |
| def download(url: str, fname: Path): | |
| resp = requests.get(url, stream=True) | |
| total = int(resp.headers.get("content-length", 0)) | |
| # try: | |
| # with (open(str(fname), "wb") as file, | |
| # tqdm( | |
| # desc=str(fname), | |
| # total=total, | |
| # unit="iB", | |
| # unit_scale=True, | |
| # unit_divisor=1024, | |
| # ) as bar,): | |
| # for data in resp.iter_content(chunk_size=1024): | |
| # size = file.write(data) | |
| # bar.update(size) | |
| # except: | |
| with open(str(fname), "wb") as file, tqdm( | |
| desc=str(fname), | |
| total=total, | |
| unit="iB", | |
| unit_scale=True, | |
| unit_divisor=1024, | |
| ) as bar: | |
| for data in resp.iter_content(chunk_size=1024): | |
| size = file.write(data) | |
| bar.update(size) | |
| def download_pretrained(name: str, download_dir: Optional[Path] = None): | |
| # TODO make safe, introduce versioning | |
| if download_dir is None: | |
| download_dir = files("trackastra").joinpath(".models") | |
| else: | |
| download_dir = Path(download_dir) | |
| download_dir.mkdir(exist_ok=True, parents=True) | |
| try: | |
| url = _MODELS[name] | |
| except KeyError: | |
| raise ValueError( | |
| "Pretrained model `name` is not available. Choose from" | |
| f" {list(_MODELS.keys())}" | |
| ) | |
| folder = download_dir / name | |
| download_and_unzip(url=url, dst=folder) | |
| return folder | |