Spaces:
Build error
Build error
| import os | |
| import sys | |
| import requests | |
| import shutil | |
| import time | |
| from pathlib import Path | |
| from tqdm.auto import tqdm | |
class CVATDataset:
    """Download task datasets from a CVAT server as ZIP archives.

    Each task listed in ``task_ids`` is requested from
    ``{cvat_url}/api/tasks/{task_id}/dataset`` and streamed to a local file
    named ``dataset_<label>.zip``.
    """

    def __init__(self, cvat_url, org, task_ids, headers=None, params=None, names=None, dest_folder=None):
        """
        Store the connection settings used to download datasets from CVAT.

        Args:
            cvat_url (str) : CVAT base URL where the server is loaded.
            org (str) : organization we are working with, e.g.: 'bulow'
            task_ids (list): list with the task IDs inside CVAT.
            headers (dict): HTTP headers sent with every request. When
                omitted, a built-in "Authorization" header is used.
            params (dict): query parameters. When omitted, a COCO 1.0
                local download for ``org`` is requested.
            names (dict): dict where the keys are the task id and values
                the names of the local files.
            dest_folder (str) : destination folder of the zip files.

        Raises:
            AssertionError: if ``names`` is given and lacks an entry for
                one of ``task_ids``.
        """
        self.cvat_url = cvat_url
        self.org = org
        self.task_ids = task_ids
        self.dest_folder = dest_folder
        self.names_dict = names
        if self.names_dict is not None:
            assert all(id_ in self.names_dict for id_ in self.task_ids), \
                "The keys in names do not match the task IDs."

        self.headers = headers
        if self.headers is None:
            # FIXME: avoid hardcoded authorization.
            # SECURITY: this default embeds a credential in source control.
            # It is kept only so existing callers keep working; pass
            # `headers` explicitly and rotate this credential.
            self.headers = {"Authorization": "Basic ZGphbmdvOlMwbHNraW4xMjM0IQ=="}

        self.params = params
        if self.params is None:
            self.params = {
                "format"  : "COCO 1.0",
                "action"  : "download",
                "location": "local",
                "org"     : self.org
            }

    @staticmethod
    def countdown_clock(waiting_time):
        """
        Block for ``waiting_time`` seconds while printing an MM:SS countdown.

        Declared as a static method so that it can be called both as
        ``self.countdown_clock(n)`` and ``CVATDataset.countdown_clock(n)``.

        Args:
            waiting_time (int | float): seconds to wait. Zero or a negative
                value returns immediately after printing a newline.
        """
        t0 = time.monotonic()
        while time.monotonic() - t0 < waiting_time:
            remaining_time = waiting_time - (time.monotonic() - t0)
            mins, secs = divmod(int(remaining_time), 60)
            # "\r" rewinds to the start of the line so the clock updates in place.
            sys.stdout.write("\r")
            sys.stdout.write(f"{mins:02d}:{secs:02d}")
            sys.stdout.flush()
            time.sleep(1)
        sys.stdout.write("\n")

    def _get_dataset(self, endpoint):
        """
        Send a streaming GET request to ``endpoint``.

        Args:
            endpoint (str): full URL of the dataset export endpoint.

        Returns:
            The ``requests`` response. Because ``stream=True`` is used, the
            caller is responsible for consuming or closing it.
        """
        response = requests.get(
            endpoint,
            headers = self.headers,
            params  = self.params,
            stream  = True
        )
        return response

    def _download_task(self, task_id: int, fname: str):
        """
        Downloads dataset linked to a task.

        Polls the export endpoint until it answers 200: waits 10 seconds
        after a 202 and 30 seconds after any other status. There is no
        upper bound on the number of retries. The body is then streamed to
        ``fname`` behind a progress bar.

        Args:
            task_id (int): CVAT task ID.
            fname (str): path of the ZIP file to create.
        """
        endpoint = f"{self.cvat_url}/api/tasks/{task_id}/dataset"
        r = self._get_dataset(endpoint)
        while r.status_code != 200:
            if r.status_code == 202:
                print(f"  Status code {r.status_code}: server processing request")
                waiting_time = 10
            else:
                print(f"  Status code {r.status_code}: connection error")
                waiting_time = 30
            # Streamed responses hold their connection until closed; release
            # it before sleeping so retries do not pile up open connections.
            r.close()
            self.countdown_clock(waiting_time)
            r = self._get_dataset(endpoint)

        print(f"  Status code {r.status_code}: request is ready")

        # Download to a temporary name and rename on success, so that an
        # interrupted transfer never leaves a truncated file under `fname`
        # (download_tasks would otherwise skip it as already downloaded).
        partial_fname = f"{fname}.part"
        with r:
            # The header is absent for e.g. chunked responses; the progress
            # bar then runs without a known total instead of crashing.
            content_length = r.headers.get("Content-Length")
            total_length = None
            if content_length is not None and content_length.isdigit():
                total_length = int(content_length)

            with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
                with open(partial_fname, "wb") as file:
                    shutil.copyfileobj(raw, file)
        os.replace(partial_fname, fname)

    def download_tasks(self):
        """
        Download all the tasks passed as input.

        Files are named ``dataset_<label>.zip``, where the label is the
        entry of ``names`` for the task, or the task ID when no names were
        given. Files that already exist are skipped.
        """
        for task_id in self.task_ids:
            name_label = task_id
            if self.names_dict is not None:
                name_label = self.names_dict[task_id]

            fname = f"dataset_{name_label}.zip"
            if self.dest_folder is not None:
                self.dest_folder = Path(self.dest_folder)
                self.dest_folder.mkdir(exist_ok=True, parents=True)
                fname = (self.dest_folder / fname).resolve().as_posix()

            if os.path.exists(fname):
                print(f"File {fname} already exists.")
                continue

            print(f"\nDownloading task {task_id}, with fname {fname}")
            self._download_task(task_id, fname)

    # TODO: implement unzip function for the tasks