| import os |
| import json |
| import shutil |
| from pathlib import Path |
| from typing import Dict |
|
|
| from PIL import ImageFile |
# Let PIL load image files whose data is truncated instead of raising on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Data directory selected by setup(); remains None until setup() has been called.
LOCAL_DATADIR = None
|
|
| def setup(local_dir='./data/usm-training-data/data'): |
| |
| |
| tmp_datadir = Path('/tmp/data/data') |
| local_test_datadir = Path('./data/usm-test-data-x/data') |
| local_val_datadir = Path(local_dir) |
| |
| os.system('pwd') |
| os.system('ls -lahtr .') |
| |
| if tmp_datadir.exists() and not local_test_datadir.exists(): |
| global LOCAL_DATADIR |
| LOCAL_DATADIR = local_test_datadir |
| |
| print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)") |
| LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True) |
| LOCAL_DATADIR.symlink_to(tmp_datadir) |
| else: |
| LOCAL_DATADIR = local_val_datadir |
| print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)") |
| |
| |
| |
| assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist" |
| return LOCAL_DATADIR |
| |
| |
| |
| |
| import importlib |
| from pathlib import Path |
| import subprocess |
|
|
def download_package(package_name, path_to_save='packages',
                     platform='manylinux1_x86_64', python_version='38'):
    """
    Downloads a package using pip and saves it to a specified directory.

    The files are stored in ``path_to_save/<name>`` where ``<name>`` is the
    bare distribution name, i.e. ``package_name`` without any version
    specifier, extras or environment marker. A failed download is reported
    on stdout and does not raise.

    Parameters:
    package_name (str): The name of the package to download. May be a full requirement such as "numpy==1.24.0".
    path_to_save (str): The path to the directory where the package will be saved.
    platform (str): The value passed to pip's --platform option.
    python_version (str): The value passed to pip's --python-version option.
    """
    import re
    import sys

    requirement = package_name.strip()
    if not requirement:
        # Nothing to hand to pip (e.g. a blank line of a requirements file).
        print('No package name given, nothing to download')
        return

    # Name the folder after the distribution only, so that
    # "numpy==1.24.0" ends up in "<path_to_save>/numpy".
    dist_name = re.split(r'[\s<>=!~;\[@]', requirement, maxsplit=1)[0]
    target_dir = Path(path_to_save) / dist_name

    try:
        subprocess.check_call([sys.executable, "-m", "pip", "download", requirement,
                               "-d", str(target_dir),
                               "--platform", platform,
                               "--python-version", python_version,
                               "--only-binary=:all:"])
    except subprocess.CalledProcessError as e:
        print(f'Failed to download package "{package_name}". Error: {e}')
    else:
        print(f'Package "{package_name}" downloaded successfully')
| |
| |
def install_package_from_local_file(package_name, folder='packages'):
    """
    Installs a package with pip using only local files (pip runs with --no-index).

    pip is pointed at ``folder/package_name`` via --find-links. A failed
    installation is reported on stdout and does not raise.

    Parameters:
    package_name (str): The name of the package to install.
    folder (str): The directory that contains one sub-directory per package.
    """
    import sys

    # Computed outside the try block so the failure message can always use it.
    pth = str(Path(folder) / package_name)
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install",
                               "--no-index",
                               "--find-links", pth,
                               package_name])
    except subprocess.CalledProcessError as e:
        print(f"Failed to install package from {pth}. Error: {e}")
    else:
        print(f"Package installed successfully from {pth}")
| |
| |
def importt(module_name, as_name=None):
    """
    Imports a module and returns it, trying a local install if it is missing.

    If the first import fails with ModuleNotFoundError, the package is
    installed with install_package_from_local_file() and the import is
    attempted once more.

    Parameters:
    module_name (str): The name of the module to import.
    as_name (str): The alias the caller binds the module to. It is only used in the log message; the caller is responsible for the actual binding.

    Returns:
    The imported module, or None if it could not be imported.
    """
    label = module_name if as_name is None else f'{module_name} as {as_name}'

    for attempt in range(2):
        try:
            # as_name is deliberately not passed on: import_module's second
            # argument is the anchor package for relative imports, not an alias.
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            print(f"Failed to import module {module_name}. Error: {e}")
            if attempt == 0:
                install_package_from_local_file(module_name)
                # Let the import system notice the freshly installed files.
                importlib.invalidate_caches()
        else:
            print(f'imported {label}')
            return module

    return None
| |
| |
def prepare_submission():
    """
    Downloads every package listed in requirements.txt into ./packages.

    Does nothing but print the closing reminder when there is no
    requirements.txt in the current working directory. Blank lines and
    comment lines (starting with '#') in the file are skipped.
    """
    requirements_file = Path('requirements.txt')

    if requirements_file.exists():
        print('downloading packages from requirements.txt')
        Path('packages').mkdir(exist_ok=True)
        with requirements_file.open() as f:
            for line in f:
                requirement = line.strip()
                # Blank lines and comments are not requirements; passing them
                # to pip would only produce a spurious failure.
                if not requirement or requirement.startswith('#'):
                    continue
                download_package(requirement)

    print('all packages downloaded. Don\'t forget to include the packages in the submission by adding them with git lfs.')
| |
|
|
|
|
| |
| import contextlib |
| import tempfile |
| from pathlib import Path |
|
|
@contextlib.contextmanager
def working_directory(path):
    """Temporarily make *path* the current working directory.

    The directory that was current on entry is restored on exit, including
    when the managed body raises.
    """
    origin = os.getcwd()
    os.chdir(path)
    try:
        yield
    finally:
        # Always go back to where we started.
        os.chdir(origin)
| |
@contextlib.contextmanager
def temp_working_directory():
    """Run the managed body inside a fresh temporary directory.

    The directory is created under the current working directory, becomes the
    working directory for the duration of the body, and is removed afterwards.
    """
    with tempfile.TemporaryDirectory(dir='.') as scratch, working_directory(scratch):
        yield
|
|
|
|
| |
def proc(row, split='train'):
    """
    Regroups the entries of a raw dataset row into a Sample.

    Each key is reduced to its stem (the part before the first '.'), then:
    - stems 'ade20k', 'depthcm' and 'gestalt': the values are collected into
      a list under that stem, in iteration order;
    - stems 'wireframe' and 'mesh': the value's own items are merged into the
      top level of the result;
    - stems 'k' and 'r': the value is stored under the upper-cased stem;
    - anything else: the value is stored under the stem.

    Parameters:
    row (dict): The raw row, mapping keys to values.
    split (str): Unused; kept for interface compatibility.

    Returns:
    Sample: The regrouped row.
    """
    list_columns = {'ade20k', 'depthcm', 'gestalt'}
    nested_columns = {'wireframe', 'mesh'}
    # A set, not the string 'kr': `in` on a string is a substring test and
    # would also match 'kr' and the empty stem.
    upper_columns = {'k', 'r'}

    out = {}
    for key, value in row.items():
        colname = key.split('.')[0]
        if colname in list_columns:
            out.setdefault(colname, []).append(value)
        elif colname in nested_columns:
            out.update(value.items())
        elif colname in upper_columns:
            out[colname.upper()] = value
        else:
            out[colname] = value

    return Sample(out)
|
|
|
|
class Sample(Dict):
    """Dictionary whose repr summarises each value instead of printing it."""

    def __repr__(self):
        """Return a dict-style string mapping every key to a summary of its value."""
        return str({key: self._describe(value) for key, value in self.items()})

    @staticmethod
    def _describe(value):
        """Summarise one value: its shape, its first element's type, or its type."""
        if hasattr(value, 'shape'):
            return value.shape
        if isinstance(value, list):
            # An empty list has no first element to inspect.
            return [type(value[0])] if value else []
        return type(value)
|
|
| |
| |
def get_params():
    """
    Returns the competition parameters as a dictionary.

    Reads params.json from the current working directory when it exists and
    falls back to a built-in example configuration otherwise. The selected
    parameters are printed before they are returned.
    """
    config_file = Path('params.json')

    if config_file.exists():
        print('found params.json (this means we are probably in the test env). Using params from file.')
        params = json.loads(config_file.read_text())
    else:
        print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
        params = dict(
            competition_id="usm3d/S23DR",
            competition_type="script",
            metric="custom",
            token="hf_**********************************",
            team_id="local-test-team_id",
            submission_id="local-test-submission_id",
            submission_id_col="__key__",
            submission_cols=[
                "__key__",
                "wf_edges",
                "wf_vertices",
                "edge_semantics",
            ],
            submission_rows=180,
            output_path=".",
            submission_repo="<THE HF MODEL ID of THIS REPO",
            time_limit=7200,
            dataset="usm3d/usm-test-data-x",
            submission_filenames=[
                "submission.parquet",
            ],
        )

    print(params)
    return params
|
|
|
|
|
|
| import webdataset as wds |
| import numpy as np |
|
|
def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
    """
    Builds a dataset over the *.tar.gz shards found under LOCAL_DATADIR.

    Parameters:
    decode (str): The argument passed to the WebDataset decode step. If None, decode() is called without arguments.
    proc (callable): The function mapped over every decoded sample.
    split (str): The sub-directory of LOCAL_DATADIR to search for shards. Use 'all' to search LOCAL_DATADIR itself.
    dataset_type (str): 'webdataset' to return the WebDataset pipeline, 'hf' to return it wrapped in a datasets.IterableDataset.

    Returns:
    The dataset in the requested flavour.

    Raises:
    ValueError: If setup() has not been run yet or dataset_type is not supported.
    """
    if LOCAL_DATADIR is None:
        raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')

    # Fail loudly instead of silently returning None for a misspelled type.
    if dataset_type not in ('webdataset', 'hf'):
        raise ValueError(f"Unknown dataset_type {dataset_type!r}. Expected 'webdataset' or 'hf'.")

    local_dir = Path(LOCAL_DATADIR)
    if split != 'all':
        local_dir = local_dir / split

    paths = [str(p) for p in local_dir.rglob('*.tar.gz')]

    dataset = wds.WebDataset(paths)
    dataset = dataset.decode(decode) if decode is not None else dataset.decode()
    dataset = dataset.map(proc)

    if dataset_type == 'webdataset':
        return dataset

    # dataset_type == 'hf': the wrapping is the same for every split, so no
    # split is special-cased (and none falls through to an implicit None).
    import datasets

    return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
|
|
| |
| |