import os
import pprint
from typing import Any, Dict, List, Literal, Optional, Tuple, Union
import argparse
from huggingface_hub import hf_hub_download
import pyrallis
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
from PIL import Image
import numpy as np
import tifffile
import skimage.io as io
from config import RunConfig
from _utils import attn_utils_new as attn_utils
from _utils.attn_utils import AttentionStore
from _utils.misc_helper import *
from torch.autograd import Variable
import itertools
from accelerate import Accelerator
import torch.nn.functional as F
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2
import warnings

warnings.filterwarnings("ignore", category=UserWarning)

import pytorch_lightning as pl
from _utils.load_models import load_stable_diffusion_model
from models.model import Counting_with_SD_features_track as Counting
from models.enc_model.loca_args import get_argparser as loca_get_argparser
from models.enc_model.loca import build_model as build_loca_model
import time
from _utils.seg_eval import *
from models.tra_post_model.trackastra.model import Trackastra, TrackingTransformer
from models.tra_post_model.trackastra.utils import (
    blockwise_causal_norm,
    blockwise_sum,
    normalize,
)
from models.tra_post_model.trackastra.data import build_windows_sd, get_features, load_tiff_timeseries
from models.tra_post_model.trackastra.tracking import TrackGraph, build_graph, track_greedy, graph_to_ctc
from _utils.track_args import parse_train_args as get_track_args
import torchvision.transforms as T
from pathlib import Path
import dask.array as da
from scipy.sparse import SparseEfficiencyWarning, csr_array
import tracemalloc
import gc

# from memory_profiler import profile
from _utils.load_track_data import load_track_images

SCALE = 1


def get_instance_boxes(mask):
    # Convert to int64 if needed
    if mask.dtype != torch.long:
        mask = mask.to(torch.long)
    boxes = []
    instance_ids = torch.unique(mask)
    instance_ids = instance_ids[instance_ids != 0]  # skip background
    for inst_id in instance_ids:
        inst_mask = mask == inst_id
        y_indices, x_indices = torch.where(inst_mask)
        if len(x_indices) == 0 or len(y_indices) == 0:
            continue
        x_min = torch.min(x_indices).item()
        x_max = torch.max(x_indices).item()
        y_min = torch.min(y_indices).item()
        y_max = torch.max(y_indices).item()
        boxes.append([x_min, y_min, x_max, y_max])
    boxes = torch.tensor(boxes, dtype=torch.float32)
    return boxes
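
# Usage sketch for get_instance_boxes (hypothetical mask, for illustration only):
#
#   mask = torch.zeros(512, 512, dtype=torch.long)
#   mask[10:20, 30:40] = 1
#   get_instance_boxes(mask)  # tensor([[30., 10., 39., 19.]])
#
# Each row is [x_min, y_min, x_max, y_max] for one foreground instance id.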


class TrackingModule(pl.LightningModule):
    def __init__(self, use_box=False):
        super().__init__()
        self.use_box = use_box
        self.config = RunConfig()  # config for stable diffusion
        self.initialize_model()

    def initialize_model(self):
        # load loca model
        loca_args = loca_get_argparser().parse_args()
        self.loca_model = build_loca_model(loca_args)
        # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
        # weights = {k.replace("module", ""): v for k, v in weights.items()}
        # self.loca_model.load_state_dict(weights, strict=False)
        # del weights
        self.counting_adapter = Counting(scale_factor=SCALE)
        # if os.path.isfile(self.args.adapter_weight):
        #     adapter_weight = torch.load(self.args.adapter_weight, map_location=torch.device('cpu'))
        #     self.counting_adapter.load_state_dict(adapter_weight, strict=False)

        ### load stable diffusion and its controller
        self.stable = load_stable_diffusion_model(config=self.config)
        self.noise_scheduler = self.stable.scheduler
        self.controller = AttentionStore(max_size=64)
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)

        ##### initialize token_emb #####
        placeholder_token = "<task-prompt>"
        self.task_token = "repetitive objects"
        # Add the placeholder token to the tokenizer
        num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        try:
            # task_embed_from_pretrain = torch.load("pretrained/task_embed.pth")
            # hf_hub_download returns a local file path, so load the tensor from it.
            task_embed_path = hf_hub_download(
                repo_id="phoebe777777/111",
                filename="task_embed.pth",
                token=None,
                force_download=False,
            )
            task_embed_from_pretrain = torch.load(task_embed_path, map_location="cpu")
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = task_embed_from_pretrain
        except Exception:
            # Fall back: initialize the placeholder embedding from an existing token.
            # This raises if the initializer string is not a single token.
            # Convert the initializer_token and placeholder_token to ids
            initializer_token = "track"
            token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")
            initializer_token_id = token_ids[0]
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
        # others
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id

        # tracking model
        # fpath = Path("models/tra_post_model/trackastra/.models/general_2d/model.pt")
        fpath = Path("_utils/config.yaml")
        args_ = get_track_args()
        model = TrackingTransformer.from_cfg(
            cfg_path=fpath,
            args=args_,
        )
        # model = TrackingTransformer.from_folder(
        #     Path(*fpath.parts[:-1]),
        #     args=args_,
        #     checkpoint_path=Path(*fpath.parts[-1:]),
        # )
        self.track_model = model
        self.track_args = args_

    def move_to_device(self, device):
        self.stable.to(device)
        # if self.loca_model is not None and self.counting_adapter is not None:
        #     self.loca_model.to(device)
        #     self.counting_adapter.to(device)
        self.counting_adapter.to(device)
        # self.dino.to(device)
        self.loca_model.to(device)
        self.track_model.to(device)
        self.to(device)
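
    # Note (assumption): self.stable is a diffusers pipeline, not an nn.Module
    # child of this LightningModule, so a plain `module.to(device)` would not
    # move its weights; move_to_device() therefore moves each component explicitly.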

    def on_train_start(self) -> None:
        device = self.device
        dtype = self.dtype
        self.stable.to(device, dtype)

    def on_validation_start(self) -> None:
        device = self.device
        dtype = self.dtype
        self.stable.to(device, dtype)

    def forward(self, data):
        input_image_stable = data["image_stable"]
        boxes = data["boxes"]
        input_image = data["img_enc"]
        mask = data["mask"]
        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        input_ids_ = self.stable.tokenizer(
            self.placeholder_token,
            # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        time1 = time.time()
        input_image = input_image.to(self.device)
        boxes = boxes.to(self.device)
        loca_out = self.loca_model.forward_before_reg(input_image, boxes)
        loca_feature_bf_regression = loca_out["feature_bf_regression"]
        # time2 = time.time()
        task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
        adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes)  # shape [1, 768]
        if task_loc_idx.shape[0] == 0:
            encoder_hidden_states[0, 2, :] = adapted_emb.squeeze()  # place right after the task prompt token
        else:
            encoder_hidden_states[:, task_loc_idx[0, 1] + 1, :] = adapted_emb.squeeze()  # place right after the task prompt token
        # Predict the noise residual
        noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
        time3 = time.time()
        noise_pred = noise_pred.sample
        attention_store = self.controller.attention_store
        # print(time2 - time1, time3 - time2)
        attention_maps = []
        exemplar_attention_maps = []
        cross_self_task_attn_maps = []
        cross_self_exe_attn_maps = []
        # only use 64x64 self-attention
        self_attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt for i in range(bsz)],  # TODO: does this need to change?
            attention_store=self.controller,
            res=64,
            from_where=("up", "down"),
            is_cross=False,
            select=0,
        )
        self_attn_aggregate32 = attn_utils.aggregate_attention(  # [res, res, 1024]
            prompts=[self.config.prompt for i in range(bsz)],
            attention_store=self.controller,
            res=32,
            from_where=("up", "down"),
            is_cross=False,
            select=0,
        )
        self_attn_aggregate16 = attn_utils.aggregate_attention(  # [res, res, 256]
            prompts=[self.config.prompt for i in range(bsz)],
            attention_store=self.controller,
            res=16,
            from_where=("up", "down"),
            is_cross=False,
            select=0,
        )
        # cross attention
        for res in [32, 16]:
            attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                prompts=[self.config.prompt for i in range(bsz)],
                attention_store=self.controller,
                res=res,
                from_where=("up", "down"),
                is_cross=True,
                select=0,
            )
            task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
            attention_maps.append(task_attn_)
            exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)  # take the exemplar's attention
            exemplar_attention_maps.append(exemplar_attns)
        scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
        attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))])
        task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
        scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))]
        attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))])
        exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True)
        cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
        cross_self_exe_attn = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64)
        cross_self_task_attn_maps.append(cross_self_task_attn)
        cross_self_exe_attn_maps.append(cross_self_exe_attn)
        # min-max rescale each map to [0, 1]
        task_attn_64 = (task_attn_64 - task_attn_64.min()) / (task_attn_64.max() - task_attn_64.min() + 1e-6)
        cross_self_task_attn = (cross_self_task_attn - cross_self_task_attn.min()) / (cross_self_task_attn.max() - cross_self_task_attn.min() + 1e-6)
        exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)
        cross_self_exe_attn = (cross_self_exe_attn - cross_self_exe_attn.min()) / (cross_self_exe_attn.max() - cross_self_exe_attn.min() + 1e-6)
        attn_stack = [task_attn_64 / 2, cross_self_task_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        attn_stack = torch.cat(attn_stack, dim=1)
        attn_after_new_regressor, loss = self.counting_adapter.regressor(input_image, attn_stack, feature_list, mask.cpu().numpy(), training=False)  # use our own regressor directly
        return {
            "attn_after_new_regressor": attn_after_new_regressor,
            "task_attn_64": task_attn_64,
            "cross_self_task_attn": cross_self_task_attn,
            "exemplar_attn_64": exemplar_attn_64,
            "cross_self_exe_attn": cross_self_exe_attn,
            "noise_pred": noise_pred,
            "noise": noise,
            "self_attn_aggregate": self_attn_aggregate,
            "self_attn_aggregate32": self_attn_aggregate32,
            "self_attn_aggregate16": self_attn_aggregate16,
            "loss": loss,
        }
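
    # Minimal usage sketch for forward() (hypothetical shapes; batch keys as read above):
    #
    #   data = {
    #       "image_stable": torch.randn(1, 3, 512, 512),  # VAE/UNet input
    #       "img_enc": torch.randn(1, 3, 512, 512),       # LOCA encoder input
    #       "boxes": torch.rand(1, 3, 4),                 # exemplar boxes
    #       "mask": torch.zeros(1, 1, 512, 512),          # instance mask
    #   }
    #   out = module(data)
    #   out["attn_after_new_regressor"]  # regressed segmentation/attention map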

    def forward_sd(self, input_image_stable, input_image, boxes, height, width, mask=None):
        input_image_stable = input_image_stable.to(self.device)
        # density = data["density"]
        if boxes is not None:
            boxes = boxes.to(self.device)
        input_image = input_image.to(self.device)
        if mask is not None:
            mask = mask.to(self.device)
        else:
            mask = torch.zeros((input_image.shape[0], 1, input_image.shape[2], input_image.shape[3])).to(self.device)
        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        input_ids_ = self.stable.tokenizer(
            self.placeholder_token + " " + self.task_token,
            # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        adapted_emb = None
        if boxes is not None and not self.training:
            if self.adapt_emb is None:
                loca_out_ = self.loca_model.forward_before_reg(input_image, boxes)
                loca_feature_bf_regression_ = loca_out_["feature_bf_regression"]
                adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression_, boxes)  # shape [1, 768]
            else:
                adapted_emb = self.adapt_emb.to(self.device)
            task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
            if task_loc_idx.shape[0] == 0:
                encoder_hidden_states[0, 5, :] = adapted_emb.squeeze()  # place right after the task prompt tokens
            else:
                encoder_hidden_states[:, task_loc_idx[0, 1] + 4, :] = adapted_emb.squeeze()  # place right after the task prompt tokens
        # Predict the noise residual
        noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
        noise_pred = noise_pred.sample
        attention_store = self.controller.attention_store
        attention_maps = []
        exemplar_attention_maps = []
        exemplar_attention_maps1 = []
        exemplar_attention_maps2 = []
        exemplar_attention_maps3 = []
        exemplar_attention_maps4 = []
        cross_self_task_attn_maps = []
        cross_self_exe_attn_maps = []
        # only use 64x64 self-attention
        self_attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt for i in range(bsz)],
            attention_store=self.controller,
            res=64,
            from_where=("up", "down"),
            is_cross=False,
            select=0,
        )
        # cross attention
        for res in [32, 16]:
            attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                prompts=[self.config.prompt for i in range(bsz)],
                attention_store=self.controller,
                res=res,
                from_where=("up", "down"),
                is_cross=True,
                select=0,
            )
            task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
            attention_maps.append(task_attn_)
            # if self.boxes is not None and not self.training:
            exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)  # take the exemplar's attention
            exemplar_attention_maps1.append(exemplar_attns1)
            exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0)
            exemplar_attention_maps2.append(exemplar_attns2)
            exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0)
            exemplar_attention_maps3.append(exemplar_attns3)
            exemplar_attns4 = attn_aggregate[:, :, 5].unsqueeze(0).unsqueeze(0)
            exemplar_attention_maps4.append(exemplar_attns4)
        scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
        attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))])
        task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
        cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
        cross_self_task_attn_maps.append(cross_self_task_attn)
        # if not self.training:
        scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))]
        attns = torch.cat([F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))])
        exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True)
        scale_factors = [(64 // exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))]
        attns = torch.cat([F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))])
        exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True)
        scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))]
        attns = torch.cat([F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))])
        exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True)
        if boxes is not None:
            scale_factors = [(64 // exemplar_attention_maps4[i].shape[-1]) for i in range(len(exemplar_attention_maps4))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps4[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps4))])
            exemplar_attn_64_4 = torch.mean(attns, dim=0, keepdim=True)
        exes = []
        cross_exes = []
        cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1)
        cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2)
        cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3)
        # average: min-max rescale each map to [0, 1] before averaging
        exemplar_attn_64_1 = (exemplar_attn_64_1 - exemplar_attn_64_1.min()) / (exemplar_attn_64_1.max() - exemplar_attn_64_1.min() + 1e-6)
        exemplar_attn_64_2 = (exemplar_attn_64_2 - exemplar_attn_64_2.min()) / (exemplar_attn_64_2.max() - exemplar_attn_64_2.min() + 1e-6)
        exemplar_attn_64_3 = (exemplar_attn_64_3 - exemplar_attn_64_3.min()) / (exemplar_attn_64_3.max() - exemplar_attn_64_3.min() + 1e-6)
        cross_self_exe_attn1 = (cross_self_exe_attn1 - cross_self_exe_attn1.min()) / (cross_self_exe_attn1.max() - cross_self_exe_attn1.min() + 1e-6)
        cross_self_exe_attn2 = (cross_self_exe_attn2 - cross_self_exe_attn2.min()) / (cross_self_exe_attn2.max() - cross_self_exe_attn2.min() + 1e-6)
        cross_self_exe_attn3 = (cross_self_exe_attn3 - cross_self_exe_attn3.min()) / (cross_self_exe_attn3.max() - cross_self_exe_attn3.min() + 1e-6)
        exes = [exemplar_attn_64_1, exemplar_attn_64_2, exemplar_attn_64_3]
        cross_exes = [cross_self_exe_attn1, cross_self_exe_attn2, cross_self_exe_attn3]
        if boxes is not None:
            cross_self_exe_attn4 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_4)
            exemplar_attn_64_4 = (exemplar_attn_64_4 - exemplar_attn_64_4.min()) / (exemplar_attn_64_4.max() - exemplar_attn_64_4.min() + 1e-6)
            cross_self_exe_attn4 = (cross_self_exe_attn4 - cross_self_exe_attn4.min()) / (cross_self_exe_attn4.max() - cross_self_exe_attn4.min() + 1e-6)
            exes.append(exemplar_attn_64_4)
            cross_exes.append(cross_self_exe_attn4)
        exemplar_attn_64 = sum(exes) / len(exes)
        cross_self_exe_attn = sum(cross_exes) / len(cross_exes)
        if self.use_box:
            attn_stack = [task_attn_64 / 2, cross_self_task_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        else:
            attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        attn_stack = torch.cat(attn_stack, dim=1)
        attn_after_new_regressor, loss, _ = self.counting_adapter.regressor.forward_seg(input_image, attn_stack, feature_list, mask.cpu().numpy(), self.training)
        if not self.training:
            pred_mask = attn_after_new_regressor.detach().cpu()
            pred_boxes = get_instance_boxes(pred_mask.squeeze())
            self.boxes = pred_boxes.unsqueeze(0)
            if pred_boxes.shape[0] == 0:
                print("No instances detected in the predicted mask.")
                if adapted_emb is not None:
                    self.adapt_emb = adapted_emb.detach().cpu()  # reuse the previous embedding
            else:
                pred_boxes = pred_boxes.unsqueeze(0).to(self.device)
                loca_out_ = self.loca_model.forward_before_reg(input_image, pred_boxes)
                loca_feature_bf_regression_ = loca_out_["feature_bf_regression"]
                adapted_emb_ = self.counting_adapter.adapter(loca_feature_bf_regression_, pred_boxes)  # shape [1, 768]
                self.adapt_emb = adapted_emb_.detach().cpu()
        # resize to original image size
        mask_np = attn_after_new_regressor.squeeze().detach().cpu().numpy()
        mask_resized = cv2.resize(mask_np, (width, height), interpolation=cv2.INTER_NEAREST)
        return mask_resized
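
    # The repeated rescaling above follows one pattern; a helper like this
    # (hypothetical, not part of the original code) would express it directly:
    #
    #   def _minmax(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    #       """Rescale a tensor to [0, 1], with a small epsilon for stability."""
    #       return (x - x.min()) / (x.max() - x.min() + eps)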

    def forward_boxes(self, input_image_stable, boxes, input_image):
        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        input_ids_ = self.stable.tokenizer(
            self.placeholder_token,
            # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)
        time1 = time.time()
        input_image = input_image.to(self.device)
        boxes = boxes.to(self.device)
        loca_out = self.loca_model.forward_before_reg(input_image, boxes)
        loca_feature_bf_regression = loca_out["feature_bf_regression"]
        # time2 = time.time()
        task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
        adapted_emb = self.counting_adapter.adapter.forward_boxes(loca_feature_bf_regression, boxes)  # shape [n_instance, 768]
        n_instance = adapted_emb.shape[0]
        # 77-token context leaves 75 usable prompt slots: 1 task prompt + 74 object prompts
        n_forward = int(np.ceil(n_instance / 74))
        task_cross_attention = []
        instances_cross_attention = []
        for n in range(n_forward):
            len_ = min(74, n_instance - n * 74)
            encoder_hidden_states[:, (task_loc_idx[0, 1] + 1):(task_loc_idx[0, 1] + 1 + len_), :] = adapted_emb[n * 74:n * 74 + len_].squeeze()  # place right after the task prompt
            # encoder_hidden_states: [bsz, 77, 768]; position 1 holds the task prompt
            # embedding, positions from 2 on can hold object prompt embeddings, and
            # the last position should keep its original embedding
            # Predict the noise residual
            noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
            noise_pred = noise_pred.sample
            attention_maps = []
            exemplar_attention_maps = []
            # cross attention
            for res in [32, 16]:
                attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                    prompts=[self.config.prompt for i in range(bsz)],
                    attention_store=self.controller,
                    res=res,
                    from_where=("up", "down"),
                    is_cross=True,
                    select=0,
                )
                task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
                attention_maps.append(task_attn_)
                try:
                    exemplar_attns = attn_aggregate[:, :, (task_loc_idx[0, 1] + 1):(task_loc_idx[0, 1] + 1 + len_)].unsqueeze(0)  # take the exemplars' attention
                except Exception:
                    print(n_instance, len_)
                    raise
                exemplar_attns = torch.permute(exemplar_attns, (0, 3, 1, 2))  # [1, len_, res, res]
                exemplar_attention_maps.append(exemplar_attns)
            scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
            attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))])
            task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
            try:
                scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))]
                attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))])
            except Exception:
                print("exemplar_attention_maps shape mismatch, n_instance: {}, len_: {}".format(n_instance, len_))
                for m in exemplar_attention_maps:
                    print(m.shape)
                raise
            exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True)
            task_attn_64 = (task_attn_64 - task_attn_64.min()) / (task_attn_64.max() - task_attn_64.min() + 1e-6)
            exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)
            task_cross_attention.append(task_attn_64)
            instances_cross_attention.append(exemplar_attn_64)
        task_cross_attention = torch.cat(task_cross_attention, dim=0)  # [n_forward, 1, 64, 64]
        task_cross_attention = torch.mean(task_cross_attention, dim=0, keepdim=True)  # [1, 1, 64, 64]
        instances_cross_attention = torch.cat(instances_cross_attention, dim=1)  # [1, n_instance, 64, 64]
        assert instances_cross_attention.shape[1] == n_instance, "instances_cross_attention shape mismatch"
        attn_stack = [task_cross_attention / 2, instances_cross_attention]
        attn_stack = torch.cat(attn_stack, dim=1)
        del exemplar_attention_maps, attention_maps, attns, task_attn_64, exemplar_attn_64, latents
        del input_ids_, input_ids, attention_mask, encoder_hidden_states, timesteps, noisy_latents
        del loca_out, loca_feature_bf_regression, adapted_emb
        torch.cuda.empty_cache()
        return {
            "task_attn_64": task_cross_attention,
            "exemplar_attn_64": instances_cross_attention,
            "noise_pred": noise_pred,
            "noise": noise,
            "attn_stack": attn_stack,
            "feature_list": feature_list,
        }
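
    # Chunking arithmetic (assuming CLIP's 77-token context): with BOS, the task
    # prompt token, and EOS reserved, at most 74 instance embeddings fit per UNet
    # forward, hence n_forward = ceil(n_instance / 74).
    # E.g. n_instance = 150 -> 3 chunks of sizes 74, 74, 2.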

    def common_step(self, batch):
        mask = batch["mask_t"].to(torch.float32).to(self.device)
        if mask.dim() == 3:
            mask = mask.unsqueeze(0)
        image_stable = batch["image_stable"]
        boxes = batch["boxes"]
        input_image = batch["img_enc"]
        input_image = input_image.to(self.device)
        image_stable = image_stable.to(self.device)
        keep_boxes = None
        if image_stable.dim() == 4:
            image_stable = image_stable.unsqueeze(0)
        if input_image.dim() == 4:
            input_image = input_image.unsqueeze(0)
        # segmentation part
        n_frames = mask.shape[1]
        masks_pred = []
        for i in range(n_frames):
            mask_ = mask[:, i, :, :].unsqueeze(0)  # [1, 1, H, W]
            mask_ = F.interpolate(mask_.float(), size=(512, 512), mode='nearest')  # [1, 1, 512, 512]
            mask_ = mask_.to(torch.int64).squeeze(0).detach().to(self.device)  # [1, 512, 512]
            masks_pred.append(mask_)
            del mask_
        # if True:
        attns_emb = []
        for i in range(n_frames):
            image_stable_prev = image_stable[:, max(0, i - 1), :, :, :]
            image_stable_after = image_stable[:, min(n_frames - 1, i + 1), :, :, :]
            input_image_curr = input_image[:, i, :, :, :]
            mask_ = masks_pred[i].detach()
            unique_labels = torch.unique(mask_)  # tensor([0, 1, 2, ...])
            boxes_all = []
            for label in unique_labels:
                if label.item() == 0:
                    continue
                binary_mask = (mask_[0] == label).to(torch.uint8)  # [H, W]
                # find coordinates of nonzero points
                y_coords, x_coords = torch.nonzero(binary_mask, as_tuple=True)
                if len(x_coords) == 0 or len(y_coords) == 0:
                    continue
                x_min = torch.min(x_coords)
                y_min = torch.min(y_coords)
                x_max = torch.max(x_coords)
                y_max = torch.max(y_coords)
                boxes_all.append([x_min.item(), y_min.item(), x_max.item(), y_max.item()])
            boxes_all_t = torch.tensor(boxes_all, dtype=torch.float32).unsqueeze(0)
            output_prev = self.forward_boxes(image_stable_prev, boxes_all_t, input_image_curr)
            attn_prev = output_prev["exemplar_attn_64"]
            feature_list_prev = output_prev["feature_list"]
            output_after = self.forward_boxes(image_stable_after, boxes_all_t, input_image_curr)
            # attn_stack = output["attn_stack"]
            attn_after = output_after["exemplar_attn_64"]  # [1, n_instance, 64, 64]
            feature_list_after = output_after["feature_list"]  # [1, n_channels, res, res]
            attn_prev = torch.permute(attn_prev, (1, 0, 2, 3))  # [n_instance, 1, 64, 64]
            attn_after = torch.permute(attn_after, (1, 0, 2, 3))
            attn_emb = self.counting_adapter.regressor(attn_prev, feature_list_prev, attn_after, feature_list_after)
            attns_emb.append(attn_emb.detach())
        attns_emb = torch.cat(attns_emb, dim=1)  # [1, n_instance, 4]
        # tracking part
        feats = batch["features_t"]
        coords = batch["coords_t"]
        with torch.no_grad():
            A_pred = self.track_model(coords, feats, attn_feat=attns_emb).detach()
        del masks_pred, feats, coords, batch
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        return A_pred

    # @profile
    def _predict_batch(self, batch):
        feats = batch["features_t"].to(self.device)
        coords = batch["coords_t"].to(self.device)
        timepoints = batch["timepoints_t"].to(self.device)
        # Hack that assumes that all parameters of a model are on the same device
        device = next(self.track_model.parameters()).device
        feats = feats.unsqueeze(0).to(device)
        timepoints = timepoints.unsqueeze(0).to(device)
        coords = coords.unsqueeze(0).to(device)
        # Concat timepoints to coordinates
        coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)
        batch["coords_t"] = coords
        batch["features_t"] = feats
        with torch.no_grad():
            A = self.common_step(batch)
        torch.cuda.empty_cache()
        gc.collect()
        A = self.track_model.normalize_output(A, timepoints, coords)
        # # Spatially far entries should not influence the causal normalization
        # dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
        # invalid = dist > model.config["spatial_pos_cutoff"]
        # A[invalid] = -torch.inf
        A = A.squeeze(0).detach().cpu().numpy()
        del feats, coords, timepoints, batch
        return A
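
    # The returned A is a dense [n_detections, n_detections] numpy array of
    # association scores for one sliding window; predict_windows() below masks
    # it to forward-in-time pairs (0 < dt <= delta_t) and thresholds it at
    # edge_threshold before accumulating edges into the sparse weight matrices.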

    # @profile
    def predict_windows(
        self,
        windows: List[dict],
        features: list,
        model,
        imgs_enc: Optional[np.ndarray] = None,
        imgs_stable: Optional[np.ndarray] = None,
        intra_window_weight: float = 0,
        delta_t: int = 1,
        edge_threshold: float = 0.05,
        spatial_dim: int = 3,
        progbar_class=tqdm,
    ) -> dict:
        # first get all objects/coords
        time_labels_to_id = dict()
        node_properties = list()
        max_id = np.sum([len(f.labels) for f in features])
        all_timepoints = np.concatenate([f.timepoints for f in features])
        all_labels = np.concatenate([f.labels for f in features])
        all_coords = np.concatenate([f.coords for f in features])
        all_coords = all_coords[:, -spatial_dim:]
        for i, (t, la, c) in enumerate(zip(all_timepoints, all_labels, all_coords)):
            time_labels_to_id[(t, la)] = i
            node_properties.append(
                dict(
                    id=i,
                    coords=tuple(c),
                    time=t,
                    # index=ix,
                    label=la,
                )
            )
        # create assoc matrix between ids
        sp_weights, sp_accum = (
            csr_array((max_id, max_id), dtype=np.float32),
            csr_array((max_id, max_id), dtype=np.float32),
        )
        tracemalloc.start()
        for t in progbar_class(
            range(len(windows)),
            desc="Computing associations",
        ):
            # This assumes that the samples in the dataset are ordered by time and start at 0.
            batch = windows[t]
            timepoints = batch["timepoints"]
            labels = batch["labels"]
            A = self._predict_batch(batch)
            dt = timepoints[None, :] - timepoints[:, None]
            time_mask = np.logical_and(dt <= delta_t, dt > 0)
            A[~time_mask] = 0
            ii, jj = np.where(A >= edge_threshold)
            if len(ii) == 0:
                continue
            labels_ii = labels[ii]
            labels_jj = labels[jj]
            ts_ii = timepoints[ii]
            ts_jj = timepoints[jj]
            nodes_ii = np.array(
                tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_ii, labels_ii))
            )
            nodes_jj = np.array(
                tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_jj, labels_jj))
            )
            # weight middle parts higher
            t_middle = t + (model.config["window"] - 1) / 2
            ddt = timepoints[:, None] - t_middle * np.ones_like(dt)
            window_weight = np.exp(-intra_window_weight * ddt**2)  # default is 1
            # window_weight = np.exp(4 * A)  # smooth max
            sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj]
            sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj]
            del batch, A, ii, jj, labels_ii, labels_jj, ts_ii, ts_jj, nodes_ii, nodes_jj, dt, time_mask
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        sp_weights_coo = sp_weights.tocoo()
        sp_accum_coo = sp_accum.tocoo()
        assert np.allclose(sp_weights_coo.col, sp_accum_coo.col) and np.allclose(
            sp_weights_coo.row, sp_accum_coo.row
        )
        # Normalize weights by the number of times they were written from different sliding window positions
        weights = tuple(
            ((i, j), v / a)
            for i, j, v, a in zip(
                sp_weights_coo.row,
                sp_weights_coo.col,
                sp_weights_coo.data,
                sp_accum_coo.data,
            )
        )
        results = dict()
        results["nodes"] = node_properties
        results["weights"] = weights
        return results
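
    # Sliding-window averaging, as implemented above: an edge (i, j) observed in
    # k overlapping windows with scores A_1..A_k and window weights w_1..w_k gets
    #
    #   w(i, j) = sum_k(w_k * A_k) / sum_k(w_k)
    #
    # With the default intra_window_weight = 0, every w_k = 1 and this reduces
    # to a plain average over the windows that saw the edge.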

    def _predict(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        imgs_enc: Optional[np.ndarray] = None,
        imgs_stable: Optional[np.ndarray] = None,
        boxes: Optional[np.ndarray] = None,
        edge_threshold: float = 0.05,
        n_workers: int = 0,
        normalize_imgs: bool = True,
        progbar_class=tqdm,
    ):
        print("Predicting weights for candidate graph")
        if normalize_imgs:
            if isinstance(imgs, da.Array):
                imgs = imgs.map_blocks(normalize)
            else:
                imgs = normalize(imgs)
        self.eval()
        features = get_features(
            detections=masks,
            imgs=imgs,
            ndim=self.track_model.config["coord_dim"],
            n_workers=n_workers,
            progbar_class=progbar_class,
        )
        print("Building windows")
        windows = build_windows_sd(
            features,
            imgs_enc=imgs_enc,
            imgs_stable=imgs_stable,
            boxes=boxes,
            imgs=imgs,
            masks=masks,
            window_size=self.track_model.config["window"],
            progbar_class=progbar_class,
        )
        print("Predicting windows")
        with torch.no_grad():
            predictions = self.predict_windows(
                windows=windows,
                features=features,
                imgs_enc=imgs_enc,
                imgs_stable=imgs_stable,
                model=self.track_model,
                edge_threshold=edge_threshold,
                spatial_dim=masks.ndim - 1,
                progbar_class=progbar_class,
            )
        return predictions

    def _track_from_predictions(
        self,
        predictions,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        use_distance: bool = False,
        max_distance: int = 256,
        max_neighbors: int = 10,
        delta_t: int = 1,
        **kwargs,
    ):
        print(f"Running tracker (mode={mode})")
        nodes = predictions["nodes"]
        weights = predictions["weights"]
        candidate_graph = build_graph(
            nodes=nodes,
            weights=weights,
            use_distance=use_distance,
            max_distance=max_distance,
            max_neighbors=max_neighbors,
            delta_t=delta_t,
        )
        if mode == "greedy":
            return track_greedy(candidate_graph)
        elif mode == "greedy_nodiv":
            return track_greedy(candidate_graph, allow_divisions=False)
        elif mode == "ilp":
            from models.tra_post_model.trackastra.tracking.ilp import track_ilp

            return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
        else:
            raise ValueError(f"Tracking mode {mode} does not exist.")

    def track(
        self,
        file_dir: str,
        boxes: Optional[torch.Tensor] = None,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        progbar_class=tqdm,
        n_workers: int = 0,
        dataname: Optional[str] = None,
        **kwargs,
    ) -> Tuple[TrackGraph, np.ndarray]:
        """Track objects across time frames.

        This method segments each frame and then links the segmented objects
        across time frames using the specified tracking mode. No hyperparameters
        need to be chosen beyond the tracking mode.

        Args:
            file_dir: Directory containing the input image sequence of shape (T,(Z),Y,X).
            boxes: Optional exemplar boxes for the first frame (used when use_box is set).
            mode: Tracking mode:
                - "greedy_nodiv": Fast greedy linking without division
                - "greedy": Fast greedy linking with division
                - "ilp": Integer Linear Programming based linking (more accurate but slower)
            normalize_imgs: Whether to normalize the images.
            progbar_class: Progress bar class to use.
            n_workers: Number of worker processes for feature extraction.
            dataname: Optional dataset name, used for output directories.
            **kwargs: Additional arguments passed to the tracking algorithm.

        Returns:
            TrackGraph containing the tracking results, and the predicted masks.
        """
        self.eval()
        imgs, imgs_raw, images_stable, tra_imgs, imgs_01, height, width = load_track_images(file_dir)
        # tra_imgs = torch.from_numpy(imgs_).float().to(self.device)
        imgs_stable = torch.from_numpy(images_stable).float().to(self.device)
        imgs_enc = torch.from_numpy(imgs).float().to(self.device)
        # get segmentation masks first
        self.boxes = None
        self.adapt_emb = None
        masks = []
        for i, (input_image, input_image_stable) in tqdm(enumerate(zip(imgs_enc, imgs_stable))):
            input_image = input_image.unsqueeze(0)
            input_image_stable = input_image_stable.unsqueeze(0)
            if i == 0:
                if self.use_box and boxes is not None:
                    self.boxes = boxes.to(self.device)
                else:
                    self.boxes = None
            with torch.no_grad():
                mask = self.forward_sd(input_image_stable, input_image, self.boxes, height=height, width=width)
            masks.append(mask)
        masks = np.stack(masks, axis=0)  # (T, H, W)
        # -------------------------
        if not masks.shape == tra_imgs.shape:
            raise RuntimeError(
                f"Img shape {tra_imgs.shape} and mask shape {masks.shape} do not match."
            )
        if not tra_imgs.ndim == self.track_model.config["coord_dim"] + 1:
            raise RuntimeError(
                f"images should be a sequence of {self.track_model.config['coord_dim']}D images"
            )
        predictions = self._predict(
            tra_imgs,
            masks,
            imgs_enc=imgs_enc,
            imgs_stable=imgs_stable,
            boxes=boxes,
            normalize_imgs=normalize_imgs,
            progbar_class=progbar_class,
            n_workers=n_workers,
        )
        track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs)
        # ctc_tracks, masks_tracked = graph_to_ctc(
        #     track_graph,
        #     masks,
        #     outdir=f"tracked/{dataname}",
        # )
        return track_graph, masks


def inference(data_path, box=None):
    use_box = box is not None
    model = TrackingModule(use_box=use_box)
    load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_tra.pth"), strict=True)
    model.move_to_device(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
    # Pass the exemplar boxes through so track() can seed the first frame with them.
    track_graph, masks = model.track(file_dir=data_path, boxes=box, dataname="inference_sequence")
    os.makedirs("tracked_ours_seg_pred3/", exist_ok=True)
    ctc_tracks, masks_tracked = graph_to_ctc(
        track_graph,
        masks,
        outdir="tracked_ours_seg_pred3/",
    )
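
# Example with user-provided exemplar boxes (hypothetical values; shape follows
# the [1, n_boxes, 4] convention used by forward_sd):
#
#   box = torch.tensor([[[30., 10., 60., 40.]]])
#   inference("example_imgs/2D+Time/Fluo-N2DL-HeLa/train/Fluo-N2DL-HeLa/02", box=box)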


if __name__ == "__main__":
    inference(data_path="example_imgs/2D+Time/Fluo-N2DL-HeLa/train/Fluo-N2DL-HeLa/02")