phoebehxf
init
aff3c6f
import logging
import shutil
import tempfile
import zipfile
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import requests
from tqdm import tqdm

try:
    from importlib.resources import files
except ImportError:  # Python < 3.9: fall back to the importlib_resources backport
    from importlib_resources import files
# Logger for this module, named after the module path.
logger = logging.getLogger(__name__)

# Registry of downloadable pretrained models: model name -> URL of the zip
# archive attached to the v0.3.0 release of the trackastra-models repository.
_MODELS = dict(
    ctc="https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip",
    general_2d="https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip",
)
def download_and_unzip(url: str, dst: Path):
    """Download the zip archive at ``url`` and place its top-level folder at ``dst``.

    The archive is expected to contain a top-level entry named after the zip
    file's stem (e.g. ``ctc.zip`` -> ``ctc/``). The archive is downloaded and
    extracted inside a temporary directory, and only that entry is then moved
    to ``dst``. If ``dst`` already exists, nothing is downloaded.

    Args:
        url: URL of the zip archive. Any query string or fragment is ignored
            when deriving the archive's file name.
        dst: Destination path for the extracted folder (``str`` or ``Path``).
            Missing parent directories are created.

    Raises:
        FileNotFoundError: If the archive lacks the expected top-level entry.
    """
    dst = Path(dst)
    if dst.exists():
        print(f"{dst} already downloaded, skipping.")
        return

    # Use only the URL path to name the archive, so e.g. "?raw=1" does not
    # end up in the file name (and therefore not in the expected folder name).
    zip_base = Path(urlparse(url).path.split("/")[-1])

    with tempfile.TemporaryDirectory() as tmp_name:
        tmp = Path(tmp_name)
        zip_file = tmp / zip_base
        download(url, zip_file)

        # Extract into a dedicated subdirectory so archive members can never
        # collide with (or overwrite) the downloaded zip file itself.
        extract_dir = tmp / "extracted"
        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        src = extract_dir / zip_base.stem
        if not src.exists():
            raise FileNotFoundError(
                f"Archive {zip_base} does not contain the expected top-level "
                f"entry '{zip_base.stem}'."
            )

        # shutil.move does not create missing parent directories.
        dst.parent.mkdir(parents=True, exist_ok=True)
        shutil.move(str(src), str(dst))
def download(url: str, fname: Path, timeout: Optional[float] = 60.0):
    """Stream the resource at ``url`` into the file ``fname``, showing a progress bar.

    Args:
        url: URL to download.
        fname: Destination file; it is created or overwritten.
        timeout: Seconds to wait for the server to connect or to send more
            data before giving up (forwarded to ``requests.get``). ``None``
            waits indefinitely.

    Raises:
        requests.HTTPError: If the server responds with an error status
            (e.g. 404), instead of silently saving the error page.
        requests.RequestException: On connection failures or timeouts.

    If writing fails or is interrupted, the partially written file is removed.
    """
    fname = Path(fname)
    with requests.get(url, stream=True, timeout=timeout) as resp:
        # Fail early on 4xx/5xx: otherwise the error page would be written to
        # ``fname`` and only surface later as a corrupt archive.
        resp.raise_for_status()
        # 0 when the server does not send a content-length header.
        total = int(resp.headers.get("content-length", 0))
        try:
            with open(fname, "wb") as file, tqdm(
                desc=str(fname),
                total=total,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in resp.iter_content(chunk_size=1024):
                    size = file.write(data)
                    bar.update(size)
        except BaseException:
            # Also covers KeyboardInterrupt: never leave a truncated file that
            # looks like a finished download. The error is re-raised.
            fname.unlink(missing_ok=True)
            raise
def download_pretrained(name: str, download_dir: Optional[Path] = None):
    """Download a pretrained model (unless already present) and return its folder.

    Args:
        name: Key of the model in ``_MODELS`` (e.g. ``"ctc"`` or ``"general_2d"``).
        download_dir: Directory in which models are stored. Defaults to the
            ``.models`` folder inside the installed ``trackastra`` package.
            It is created if missing.

    Returns:
        The path ``download_dir / name`` holding the extracted model.

    Raises:
        ValueError: If ``name`` is not a known pretrained model.
    """
    # TODO make safe, introduce versioning
    # Validate the name first so an unknown model does not create directories.
    try:
        url = _MODELS[name]
    except KeyError:
        raise ValueError(
            f"Pretrained model `{name}` is not available. Choose from"
            f" {list(_MODELS.keys())}"
        ) from None

    if download_dir is None:
        # NOTE(review): assumes trackastra is installed as a regular directory,
        # where files() yields a pathlib.Path supporting .mkdir — confirm for
        # zip or namespace-package installs.
        download_dir = files("trackastra").joinpath(".models")
    else:
        download_dir = Path(download_dir)
    download_dir.mkdir(exist_ok=True, parents=True)

    folder = download_dir / name
    download_and_unzip(url=url, dst=folder)
    return folder