| import logging |
| import random |
| from copy import deepcopy |
| from typing import Dict, Optional |
|
|
| import torch |
| import torchvision.transforms.v2.functional as TVF |
| from einops import rearrange |
| from torch.utils.data import DataLoader |
|
|
| logger = logging.getLogger("__main__") |
|
|
|
|
| def get_embeddings(data_loader, model, device, subsample_tokens: Optional[float] = None): |
| embeddings = [] |
| labels = [] |
| if subsample_tokens: |
| print(f"Subsampling tokens with ratio {subsample_tokens}") |
|
|
| model = model.eval() |
| with torch.no_grad(): |
| for batch in data_loader: |
| batch_labels = batch.pop("target") |
| if "s1" in batch: |
| batch["s1"] = batch["s1"].to(device).to(torch.bfloat16) |
| if "s2" in batch: |
| batch["s2"] = batch["s2"].to(device).to(torch.bfloat16) |
| if "months" in batch: |
| batch["months"] = batch["months"].to(device).long() |
|
|
| with torch.cuda.amp.autocast(dtype=torch.bfloat16): |
| batch_embeddings = model(**batch) |
|
|
| if subsample_tokens is not None: |
| if len(batch_embeddings.shape) < 3: |
| raise ValueError("subsample tokens only works for segmentation tasks") |
| num_tokens_per_instance = batch_embeddings.shape[1] |
| num_instances_to_keep = int(num_tokens_per_instance * subsample_tokens) |
| sampled_indices = torch.randperm(num_tokens_per_instance)[:num_instances_to_keep] |
| batch_embeddings = batch_embeddings[:, sampled_indices] |
|
|
| tokens_per_dim = int(num_tokens_per_instance**0.5) |
| pixels_per_token_dim = int(batch_labels.shape[1] / tokens_per_dim) |
|
|
| batch_labels_per_token = rearrange( |
| batch_labels, |
| "b (t_h p_h) (t_w p_w) -> b (t_h t_w) (p_h p_w)", |
| t_h=tokens_per_dim, |
| t_w=tokens_per_dim, |
| p_h=pixels_per_token_dim, |
| p_w=pixels_per_token_dim, |
| ) |
| batch_labels = batch_labels_per_token[:, sampled_indices] |
|
|
| embeddings.append(batch_embeddings.to(torch.bfloat16).cpu()) |
| labels.append(batch_labels) |
|
|
| return torch.cat(embeddings, dim=0), torch.cat(labels, dim=0) |
|
|
|
|
class DownstreamAugs(object):
    """Randomly apply one of eight dihedral transformations.

    For now there are no parameters beyond the on/off switch: one of the
    eight transformations (identity, three rotations, two flips, two
    flip+rotate combinations) is chosen at random and applied to the image
    and, for segmentation tasks, to the target map as well.
    """

    def __init__(self, enabled: bool):
        # NOTE: list order matters for reproducibility under a fixed seed,
        # since `apply` draws from it with random.choice.
        self.enabled = enabled
        self.transformations = [
            self.no_transform,
            self.rotate_90,
            self.rotate_180,
            self.rotate_270,
            self.hflip,
            self.vflip,
            self.hflip_rotate_90,
            self.vflip_rotate_90,
        ]

    def no_transform(self, x):
        """Identity: return the input unchanged."""
        return x

    def rotate_90(self, x):
        """Rotate by 90 degrees."""
        return TVF.rotate(x, angle=90)

    def rotate_180(self, x):
        """Rotate by 180 degrees."""
        return TVF.rotate(x, angle=180)

    def rotate_270(self, x):
        """Rotate by 270 degrees."""
        return TVF.rotate(x, angle=270)

    def hflip(self, x):
        """Flip horizontally."""
        return TVF.hflip(x)

    def vflip(self, x):
        """Flip vertically."""
        return TVF.vflip(x)

    def hflip_rotate_90(self, x):
        """Rotate by 90 degrees, then flip horizontally."""
        return TVF.hflip(TVF.rotate(x, angle=90))

    def vflip_rotate_90(self, x):
        """Rotate by 90 degrees, then flip vertically."""
        return TVF.vflip(TVF.rotate(x, angle=90))

    def apply(self, image, target, task_type):
        """Apply one randomly chosen transformation to `image` and, when
        `task_type` is "seg", the same transformation to `target`.

        `image` is HWC and `target` is HW; both come back in those layouts.
        When augmentation is disabled the inputs are returned untouched.
        """
        assert task_type in ["cls", "seg"]

        if not self.enabled:
            return image, target

        transform = random.choice(self.transformations)

        # Torchvision ops expect channel-first tensors, so hop through CHW.
        image = rearrange(transform(rearrange(image, "h w c -> c h w")), "c h w -> h w c")

        if task_type != "cls":
            # Spatial dims of target and image must line up for segmentation.
            assert target.shape[-1] == image.shape[-1]
            assert target.shape[-2] == image.shape[-2]
            target = rearrange(transform(rearrange(target, "h w -> 1 h w")), "1 h w -> h w")

        return image, target
|
|
|
|
def get_loaders(
    benchmark,
    config,
    model_name,
    batch_size,
    num_workers,
    eval_type,
    train_partition: Optional[str] = None,
    valtest_partition: Optional[str] = None,
    norm_ops: Optional[Dict] = None,
):
    """Build train/valid/test DataLoaders for a benchmark dataset.

    Args:
        benchmark: dict with "class" (dataset class) and "kwargs" (base
            constructor kwargs for it).
        config: config dict; config["models"][model_name] supplies the
            normalization operation when `norm_ops` is not given.
        model_name: key into config["models"].
        batch_size: forwarded to every DataLoader.
        num_workers: forwarded to every DataLoader.
        eval_type: "FT" enables train-time augmentations; anything else
            disables them.
        train_partition: optional dataset partition for the train split.
        valtest_partition: optional partition for valid/test; only valid when
            `train_partition` is also given (falls back to "default" then).
        norm_ops: explicit normalization operation overriding the config lookup.

    Returns:
        Dict with "train", "valid" and "test" DataLoaders.

    Raises:
        ValueError: if `valtest_partition` is given without `train_partition`.
    """
    use_train_augs = eval_type == "FT"

    dataclass_kwargs = deepcopy(benchmark["kwargs"])
    dataclass_kwargs["norm_operation"] = (
        config["models"][model_name] if norm_ops is None else norm_ops
    )

    train_kwargs = deepcopy(dataclass_kwargs)
    valtest_kwargs = deepcopy(dataclass_kwargs)
    if train_partition is not None:
        train_kwargs["partition"] = train_partition
        if valtest_partition is None:
            valtest_partition = "default"
        valtest_kwargs["partition"] = valtest_partition
    elif valtest_partition is not None:
        # Previously a truthiness check, which silently accepted "" — compare
        # against None so any explicit valtest partition without a train
        # partition is rejected.
        raise ValueError("valtest_partition is set but train_partition is None")

    def _loader(split: str, kwargs: Dict, augs_enabled: bool) -> DataLoader:
        # One loader per split; shuffle=False even for train — presumably
        # intentional (e.g. embedding extraction), TODO confirm.
        return DataLoader(
            benchmark["class"](
                **kwargs,
                split=split,
                augmentation=DownstreamAugs(augs_enabled),
            ),
            batch_size=batch_size,
            shuffle=False,
            num_workers=num_workers,
        )

    return {
        "train": _loader("train", train_kwargs, use_train_augs),
        "valid": _loader("valid", valtest_kwargs, False),
        "test": _loader("test", valtest_kwargs, False),
    }
|
|