import os
import pprint
from typing import Any, Dict, List, Literal, Optional, Union
import argparse

from huggingface_hub import hf_hub_download
import pyrallis
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
from PIL import Image
import numpy as np
import tifffile
import skimage.io as io

from config import RunConfig
from _utils import attn_utils_new as attn_utils
from _utils.attn_utils import AttentionStore
from _utils.misc_helper import *
from torch.autograd import Variable
import itertools
from accelerate import Accelerator
import torch.nn.functional as F
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
import cv2

import warnings

warnings.filterwarnings("ignore", category=UserWarning)

import pytorch_lightning as pl

from _utils.load_models import load_stable_diffusion_model
from models.model import Counting_with_SD_features_track as Counting
from models.enc_model.loca_args import get_argparser as loca_get_argparser
from models.enc_model.loca import build_model as build_loca_model
import time
from _utils.seg_eval import *
from models.tra_post_model.trackastra.model import Trackastra, TrackingTransformer
from models.tra_post_model.trackastra.utils import (
    blockwise_causal_norm,
    blockwise_sum,
    normalize,
)
from models.tra_post_model.trackastra.data import (
    build_windows_sd,
    get_features,
    load_tiff_timeseries,
)
from models.tra_post_model.trackastra.tracking import (
    TrackGraph,
    build_graph,
    track_greedy,
    graph_to_ctc,
)
from _utils.track_args import parse_train_args as get_track_args
import torchvision.transforms as T
from pathlib import Path
import dask.array as da
from scipy.sparse import SparseEfficiencyWarning, csr_array
import tracemalloc
import gc

# from memory_profiler import profile
from _utils.load_track_data import load_track_images

SCALE = 1


def get_instance_boxes(mask):
    # Convert to int64 if needed
    if mask.dtype != torch.long:
        mask = mask.to(torch.long)

    boxes = []
    instance_ids = torch.unique(mask)
    instance_ids = instance_ids[instance_ids != 0]  # skip background

    for inst_id in instance_ids:
        inst_mask = mask == inst_id
        y_indices, x_indices = torch.where(inst_mask)
        if len(x_indices) == 0 or len(y_indices) == 0:
            continue
        x_min = torch.min(x_indices).item()
        x_max = torch.max(x_indices).item()
        y_min = torch.min(y_indices).item()
        y_max = torch.max(y_indices).item()
        boxes.append([x_min, y_min, x_max, y_max])

    boxes = torch.tensor(boxes, dtype=torch.float32)
    return boxes
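
def _demo_get_instance_boxes():
    # Minimal usage sketch (not called anywhere in the pipeline): `get_instance_boxes`
    # expects an integer-labeled instance mask and returns one [x_min, y_min, x_max, y_max]
    # row per instance id, with background (id 0) skipped. The toy mask below is hypothetical.
    toy_mask = torch.zeros(8, 8, dtype=torch.long)
    toy_mask[1:3, 1:4] = 1  # instance 1 -> box [1, 1, 3, 2]
    toy_mask[5:8, 5:7] = 2  # instance 2 -> box [5, 5, 6, 7]
    print(get_instance_boxes(toy_mask))
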
class TrackingModule(pl.LightningModule):
    def __init__(self, use_box=False):
        super().__init__()
        self.use_box = use_box
        self.config = RunConfig()  # config for stable diffusion
        self.initialize_model()

    def initialize_model(self):
        # load loca model
        loca_args = loca_get_argparser().parse_args()
        self.loca_model = build_loca_model(loca_args)
        # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
        # weights = {k.replace("module", ""): v for k, v in weights.items()}
        # self.loca_model.load_state_dict(weights, strict=False)
        # del weights

        self.counting_adapter = Counting(scale_factor=SCALE)
        # if os.path.isfile(self.args.adapter_weight):
        #     adapter_weight = torch.load(self.args.adapter_weight, map_location=torch.device('cpu'))
        #     self.counting_adapter.load_state_dict(adapter_weight, strict=False)

        ### load stable diffusion and its controller
        self.stable = load_stable_diffusion_model(config=self.config)
        self.noise_scheduler = self.stable.scheduler
        self.controller = AttentionStore(max_size=64)
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)

        ##### initialize token_emb #####
        placeholder_token = ""
        self.task_token = "repetitive objects"

        # Add the placeholder token to the tokenizer
        num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )

        try:
            # task_embed_from_pretrain = torch.load("pretrained/task_embed.pth")
            task_embed_path = hf_hub_download(
                repo_id="phoebe777777/111",
                filename="task_embed.pth",
                token=None,
                force_download=False,
            )
            # hf_hub_download returns a local file path; load the pretrained task embedding from it
            task_embed_from_pretrain = torch.load(task_embed_path, map_location="cpu")
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = task_embed_from_pretrain
        except Exception:
            # Get token ids for our placeholder and initializer token.
            # This block will complain if the initializer string is not a single token.
            # Convert the initializer_token, placeholder_token to ids
            initializer_token = "track"
            token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")
            initializer_token_id = token_ids[0]
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

        # others
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id

        # tracking model
        # fpath = Path("models/tra_post_model/trackastra/.models/general_2d/model.pt")
        fpath = Path("_utils/config.yaml")
        args_ = get_track_args()
        model = TrackingTransformer.from_cfg(
            cfg_path=fpath,
            args=args_,
        )
        # model = TrackingTransformer.from_folder(
        #     Path(*fpath.parts[:-1]),
        #     args=args_,
        #     checkpoint_path=Path(*fpath.parts[-1:]),
        # )
        self.track_model = model
        self.track_args = args_

    def move_to_device(self, device):
        self.stable.to(device)
        # if self.loca_model is not None and self.counting_adapter is not None:
        #     self.loca_model.to(device)
        #     self.counting_adapter.to(device)
        self.counting_adapter.to(device)
        # self.dino.to(device)
        self.loca_model.to(device)
        self.track_model.to(device)
        self.to(device)

    def on_train_start(self) -> None:
        device = self.device
        dtype = self.dtype
        self.stable.to(device, dtype)

    def on_validation_start(self) -> None:
        device = self.device
        dtype = self.dtype
        self.stable.to(device, dtype)
    def forward(self, data):
        input_image_stable = data["image_stable"]
        boxes = data["boxes"]
        input_image = data["img_enc"]
        mask = data["mask"]

        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        input_ids_ = self.stable.tokenizer(
            self.placeholder_token,  # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)

        time1 = time.time()
        input_image = input_image.to(self.device)
        boxes = boxes.to(self.device)
        loca_out = self.loca_model.forward_before_reg(input_image, boxes)
        loca_feature_bf_regression = loca_out["feature_bf_regression"]
        # time2 = time.time()

        task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
        adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes)  # shape [1, 768]
        if task_loc_idx.shape[0] == 0:
            encoder_hidden_states[0, 2, :] = adapted_emb.squeeze()  # placed right after the task prompt
        else:
            encoder_hidden_states[:, task_loc_idx[0, 1] + 1, :] = adapted_emb.squeeze()  # placed right after the task prompt

        # Predict the noise residual
        noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
        time3 = time.time()
        noise_pred = noise_pred.sample
        attention_store = self.controller.attention_store
        # print(time2 - time1, time3 - time2)

        attention_maps = []
        exemplar_attention_maps = []
        cross_self_task_attn_maps = []
        cross_self_exe_attn_maps = []

        # only use 64x64 self-attention
        self_attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt for i in range(bsz)],  # does this need to change?
            attention_store=self.controller,
            res=64,
            from_where=("up", "down"),
            is_cross=False,
            select=0,
        )
        self_attn_aggregate32 = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt for i in range(bsz)],  # does this need to change?
            attention_store=self.controller,
            res=32,
            from_where=("up", "down"),
            is_cross=False,
            select=0,
        )
        self_attn_aggregate16 = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt for i in range(bsz)],  # does this need to change?
            attention_store=self.controller,
            res=16,
            from_where=("up", "down"),
            is_cross=False,
            select=0,
        )

        # cross attention
        for res in [32, 16]:
            attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                prompts=[self.config.prompt for i in range(bsz)],  # does this need to change?
                attention_store=self.controller,
                res=res,
                from_where=("up", "down"),
                is_cross=True,
                select=0,
            )
            task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
            attention_maps.append(task_attn_)
            exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)  # take the exemplar's attention
            exemplar_attention_maps.append(exemplar_attns)

        scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
        attns = torch.cat(
            [F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))]
        )
        task_attn_64 = torch.mean(attns, dim=0, keepdim=True)

        scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))]
        attns = torch.cat(
            [F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))]
        )
        exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True)

        cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
        cross_self_exe_attn = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64)
        cross_self_task_attn_maps.append(cross_self_task_attn)
        cross_self_exe_attn_maps.append(cross_self_exe_attn)

        task_attn_64 = (task_attn_64 - task_attn_64.min()) / (task_attn_64.max() - task_attn_64.min() + 1e-6)
        cross_self_task_attn = (cross_self_task_attn - cross_self_task_attn.min()) / (cross_self_task_attn.max() - cross_self_task_attn.min() + 1e-6)
        exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)
        cross_self_exe_attn = (cross_self_exe_attn - cross_self_exe_attn.min()) / (cross_self_exe_attn.max() - cross_self_exe_attn.min() + 1e-6)

        attn_stack = [task_attn_64 / 2, cross_self_task_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        attn_stack = torch.cat(attn_stack, dim=1)

        attn_after_new_regressor, loss = self.counting_adapter.regressor(
            input_image, attn_stack, feature_list, mask.cpu().numpy(), training=False
        )  # just use our own regressor

        return {
            "attn_after_new_regressor": attn_after_new_regressor,
            "task_attn_64": task_attn_64,
            "cross_self_task_attn": cross_self_task_attn,
            "exemplar_attn_64": exemplar_attn_64,
            "cross_self_exe_attn": cross_self_exe_attn,
            "noise_pred": noise_pred,
            "noise": noise,
            "self_attn_aggregate": self_attn_aggregate,
            "self_attn_aggregate32": self_attn_aggregate32,
            "self_attn_aggregate16": self_attn_aggregate16,
            "loss": loss,
        }

    def forward_sd(self, input_image_stable, input_image, boxes, height, width, mask=None):
        input_image_stable = input_image_stable.to(self.device)
        # density = data["density"]
        if boxes is not None:
            boxes = boxes.to(self.device)
        input_image = input_image.to(self.device)
        if mask is not None:
            mask = mask.to(self.device)
        else:
            mask = torch.zeros((input_image.shape[0], 1, input_image.shape[2], input_image.shape[3])).to(self.device)

        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        input_ids_ = self.stable.tokenizer(
            self.placeholder_token + " " + self.task_token,  # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)

        if boxes is not None and not self.training:
            if self.adapt_emb is None:
                loca_out_ = self.loca_model.forward_before_reg(input_image, boxes)
                loca_feature_bf_regression_ = loca_out_["feature_bf_regression"]
                adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression_, boxes)  # shape [1, 768]
            else:
                adapted_emb = self.adapt_emb.to(self.device)
            task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
            if task_loc_idx.shape[0] == 0:
                encoder_hidden_states[0, 5, :] = adapted_emb.squeeze()  # placed right after the task prompt
            else:
                encoder_hidden_states[:, task_loc_idx[0, 1] + 4, :] = adapted_emb.squeeze()  # placed right after the task prompt

        # Predict the noise residual
        noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
        noise_pred = noise_pred.sample
        attention_store = self.controller.attention_store

        attention_maps = []
        exemplar_attention_maps = []
        exemplar_attention_maps1 = []
        exemplar_attention_maps2 = []
        exemplar_attention_maps3 = []
        exemplar_attention_maps4 = []
        cross_self_task_attn_maps = []
        cross_self_exe_attn_maps = []
        # only use 64x64 self-attention
        self_attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt for i in range(bsz)],  # does this need to change?
            attention_store=self.controller,
            res=64,
            from_where=("up", "down"),
            is_cross=False,
            select=0,
        )

        # cross attention
        for res in [32, 16]:
            attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                prompts=[self.config.prompt for i in range(bsz)],  # does this need to change?
                attention_store=self.controller,
                res=res,
                from_where=("up", "down"),
                is_cross=True,
                select=0,
            )
            task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
            attention_maps.append(task_attn_)

            # if self.boxes is not None and not self.training:
            exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)  # take the exemplar's attention
            exemplar_attention_maps1.append(exemplar_attns1)
            exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0)  # take the exemplar's attention
            exemplar_attention_maps2.append(exemplar_attns2)
            exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0)  # take the exemplar's attention
            exemplar_attention_maps3.append(exemplar_attns3)
            exemplar_attns4 = attn_aggregate[:, :, 5].unsqueeze(0).unsqueeze(0)  # take the exemplar's attention
            exemplar_attention_maps4.append(exemplar_attns4)

        scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
        attns = torch.cat(
            [F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))]
        )
        task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
        cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
        cross_self_task_attn_maps.append(cross_self_task_attn)

        # if not self.training:
        scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))]
        attns = torch.cat(
            [F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))]
        )
        exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True)

        scale_factors = [(64 // exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))]
        attns = torch.cat(
            [F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))]
        )
        exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True)

        scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))]
        attns = torch.cat(
            [F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))]
        )
        exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True)

        if boxes is not None:
            scale_factors = [(64 // exemplar_attention_maps4[i].shape[-1]) for i in range(len(exemplar_attention_maps4))]
            attns = torch.cat(
                [F.interpolate(exemplar_attention_maps4[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps4))]
            )
            exemplar_attn_64_4 = torch.mean(attns, dim=0, keepdim=True)

        exes = []
        cross_exes = []
        cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1)
        cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2)
        cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3)

        # average
        exemplar_attn_64_1 = (exemplar_attn_64_1 - exemplar_attn_64_1.min()) / (exemplar_attn_64_1.max() - exemplar_attn_64_1.min() + 1e-6)
        exemplar_attn_64_2 = (exemplar_attn_64_2 - exemplar_attn_64_2.min()) / (exemplar_attn_64_2.max() - exemplar_attn_64_2.min() + 1e-6)
        exemplar_attn_64_3 = (exemplar_attn_64_3 - exemplar_attn_64_3.min()) / (exemplar_attn_64_3.max() - exemplar_attn_64_3.min() + 1e-6)
        cross_self_exe_attn1 = (cross_self_exe_attn1 - cross_self_exe_attn1.min()) / (cross_self_exe_attn1.max() - cross_self_exe_attn1.min() + 1e-6)
        cross_self_exe_attn2 = (cross_self_exe_attn2 - cross_self_exe_attn2.min()) / (cross_self_exe_attn2.max() - cross_self_exe_attn2.min() + 1e-6)
        cross_self_exe_attn3 = (cross_self_exe_attn3 - cross_self_exe_attn3.min()) / (cross_self_exe_attn3.max() - cross_self_exe_attn3.min() + 1e-6)

        exes = [exemplar_attn_64_1, exemplar_attn_64_2, exemplar_attn_64_3]
        cross_exes = [cross_self_exe_attn1, cross_self_exe_attn2, cross_self_exe_attn3]

        if boxes is not None:
            cross_self_exe_attn4 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_4)
            exemplar_attn_64_4 = (exemplar_attn_64_4 - exemplar_attn_64_4.min()) / (exemplar_attn_64_4.max() - exemplar_attn_64_4.min() + 1e-6)
            cross_self_exe_attn4 = (cross_self_exe_attn4 - cross_self_exe_attn4.min()) / (cross_self_exe_attn4.max() - cross_self_exe_attn4.min() + 1e-6)
            exes.append(exemplar_attn_64_4)
            cross_exes.append(cross_self_exe_attn4)

        exemplar_attn_64 = sum(exes) / len(exes)
        cross_self_exe_attn = sum(cross_exes) / len(cross_exes)

        if self.use_box:
            attn_stack = [task_attn_64 / 2, cross_self_task_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        else:
            attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        attn_stack = torch.cat(attn_stack, dim=1)

        attn_after_new_regressor, loss, _ = self.counting_adapter.regressor.forward_seg(
            input_image, attn_stack, feature_list, mask.cpu().numpy(), self.training
        )

        if not self.training:
            pred_mask = attn_after_new_regressor.detach().cpu()
            pred_boxes = get_instance_boxes(pred_mask.squeeze())
            self.boxes = pred_boxes.unsqueeze(0)
            if pred_boxes.shape[0] == 0:
                print("No instances detected in the predicted mask.")
                self.adapt_emb = adapted_emb.detach().cpu()  # reuse emb
            else:
                pred_boxes = pred_boxes.unsqueeze(0).to(self.device)
                loca_out_ = self.loca_model.forward_before_reg(input_image, pred_boxes)
                loca_feature_bf_regression_ = loca_out_["feature_bf_regression"]
                adapted_emb_ = self.counting_adapter.adapter(loca_feature_bf_regression_, pred_boxes)  # shape [1, 768]
                self.adapt_emb = adapted_emb_.detach().cpu()

        # resize to original image size
        mask_np = attn_after_new_regressor.squeeze().detach().cpu().numpy()
        mask_resized = cv2.resize(mask_np, (width, height), interpolation=cv2.INTER_NEAREST)
        return mask_resized

    def forward_boxes(self, input_image_stable, boxes, input_image):
        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        input_ids_ = self.stable.tokenizer(
            self.placeholder_token,  # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)

        time1 = time.time()
        input_image = input_image.to(self.device)
        boxes = boxes.to(self.device)
        loca_out = self.loca_model.forward_before_reg(input_image, boxes)
        loca_feature_bf_regression = loca_out["feature_bf_regression"]
        # time2 = time.time()
        task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
        adapted_emb = self.counting_adapter.adapter.forward_boxes(loca_feature_bf_regression, boxes)  # shape [n_instance, 768]
        n_instance = adapted_emb.shape[0]
        n_forward = int(np.ceil(n_instance / 74))  # in total 75 prompts including 1 task prompt and 74 object prompts?

        task_cross_attention = []
        instances_cross_attention = []
        for n in range(n_forward):
            len_ = min(74, n_instance - n * 74)
            # placed right after the task prompt
            encoder_hidden_states[:, (task_loc_idx[0, 1] + 1):(task_loc_idx[0, 1] + 1 + len_), :] = adapted_emb[n * 74 : n * 74 + len_].squeeze()
            # encoder_hidden_states: [bsz, 77, 768]; position 1 holds the task prompt embedding,
            # positions from 2 onward can hold object prompt embeddings,
            # and the last position should keep the original embedding

            # Predict the noise residual
            noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
            noise_pred = noise_pred.sample

            attention_maps = []
            exemplar_attention_maps = []

            # cross attention
            for res in [32, 16]:
                attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                    prompts=[self.config.prompt for i in range(bsz)],  # does this need to change?
                    attention_store=self.controller,
                    res=res,
                    from_where=("up", "down"),
                    is_cross=True,
                    select=0,
                )
                task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
                attention_maps.append(task_attn_)
                try:
                    exemplar_attns = attn_aggregate[:, :, (task_loc_idx[0, 1] + 1):(task_loc_idx[0, 1] + 1 + len_)].unsqueeze(0)  # take the exemplars' attention
                except Exception:
                    print(n_instance, len_)
                exemplar_attns = torch.permute(exemplar_attns, (0, 3, 1, 2))  # [1, len_, res, res]
                exemplar_attention_maps.append(exemplar_attns)

            scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
            attns = torch.cat(
                [F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))]
            )
            task_attn_64 = torch.mean(attns, dim=0, keepdim=True)

            try:
                scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))]
                attns = torch.cat(
                    [F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))]
                )
            except Exception:
                print("exemplar_attention_maps shape mismatch, n_instance: {}, len_: {}".format(n_instance, len_))
                print(exemplar_attention_maps[0].shape)
                print(exemplar_attention_maps[1].shape)
                print(exemplar_attention_maps[2].shape)
            exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True)

            task_attn_64 = (task_attn_64 - task_attn_64.min()) / (task_attn_64.max() - task_attn_64.min() + 1e-6)
            exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)

            task_cross_attention.append(task_attn_64)
            instances_cross_attention.append(exemplar_attn_64)

        task_cross_attention = torch.cat(task_cross_attention, dim=0)  # [n_forward, 1, 64, 64]
        task_cross_attention = torch.mean(task_cross_attention, dim=0, keepdim=True)  # [1, 1, 64, 64]
        instances_cross_attention = torch.cat(instances_cross_attention, dim=1)  # [1, n_instance, 64, 64]
        assert instances_cross_attention.shape[1] == n_instance, "instances_cross_attention shape mismatch"

        attn_stack = [task_cross_attention / 2, instances_cross_attention]
        attn_stack = torch.cat(attn_stack, dim=1)

        del exemplar_attention_maps, attention_maps, attns, task_attn_64, exemplar_attn_64, latents
        del input_ids_, input_ids, attention_mask, encoder_hidden_states, timesteps, noisy_latents
        del loca_out, loca_feature_bf_regression, adapted_emb
        torch.cuda.empty_cache()
"noise_pred":noise_pred, "noise":noise, "attn_stack": attn_stack, "feature_list": feature_list, } def common_step(self, batch): mask = batch["mask_t"].to(torch.float32).to(self.device) if mask.dim() == 3: mask = mask.unsqueeze(0) image_stable = batch["image_stable"] boxes = batch["boxes"] input_image = batch["img_enc"] input_image = input_image.to(self.device) image_stable = image_stable.to(self.device) keep_boxes = None if image_stable.dim() == 4: image_stable = image_stable.unsqueeze(0) if input_image.dim() == 4: input_image = input_image.unsqueeze(0) # segmentation part n_frames = mask.shape[1] masks_pred = [] for i in range(n_frames): mask_ = mask[:, i, :, :].unsqueeze(0) # [1, 1, H, W] mask_ = F.interpolate(mask_.float(), size=(512, 512), mode='nearest') # [1, 1, 512, 512] mask_ = mask_.to(torch.int64).squeeze(0).detach().to(self.device) # [1, 512, 512] masks_pred.append(mask_) del mask_ # if True: attns_emb = [] for i in range(n_frames): image_stable_prev = image_stable[:, max(0, i-1), :, :, :] image_stable_after = image_stable[:, min(n_frames-1, i+1), :, :, :] input_image_curr = input_image[:, i, :, :, :] mask_ = masks_pred[i].detach() unique_labels = torch.unique(mask_) # tensor([0, 1, 2, ...]) boxes_all = [] for label in unique_labels: if label.item() == 0: continue binary_mask = (mask_[0] == label).to(torch.uint8) # [H, W] # 找非零点坐标 y_coords, x_coords = torch.nonzero(binary_mask, as_tuple=True) if len(x_coords) == 0 or len(y_coords) == 0: continue x_min = torch.min(x_coords) y_min = torch.min(y_coords) x_max = torch.max(x_coords) y_max = torch.max(y_coords) boxes_all.append([x_min.item(), y_min.item(), x_max.item(), y_max.item()]) boxes_all_t = torch.tensor(boxes_all, dtype=torch.float32).unsqueeze(0) output_prev = self.forward_boxes(image_stable_prev, boxes_all_t, input_image_curr) attn_prev = output_prev["exemplar_attn_64"] feature_list_prev = output_prev["feature_list"] output_after = self.forward_boxes(image_stable_after, boxes_all_t, input_image_curr) # attn_stack = output["attn_stack"] attn_after = output_after["exemplar_attn_64"] # [1, n_instance, 64, 64] feature_list_after = output_after["feature_list"] # [1, n_channels, res, res] attn_prev = torch.permute(attn_prev, (1, 0, 2, 3)) # [n_instance, 1, 64, 64] attn_after = torch.permute(attn_after, (1, 0, 2, 3)) attn_emb = self.counting_adapter.regressor(attn_prev, feature_list_prev, attn_after, feature_list_after) attns_emb.append(attn_emb.detach()) attns_emb = torch.cat(attns_emb, dim=1) # [1, n_instance, 4] # tracking part feats = batch["features_t"] coords = batch["coords_t"] with torch.no_grad(): A_pred = self.track_model(coords, feats, attn_feat=attns_emb).detach() del masks_pred, feats, coords, batch gc.collect() torch.cuda.empty_cache() torch.cuda.ipc_collect() return A_pred # @profile def _predict_batch(self, batch): feats = batch["features_t"].to(self.device) coords = batch["coords_t"].to(self.device) timepoints = batch["timepoints_t"].to(self.device) # Hack that assumes that all parameters of a model are on the same device device = next(self.track_model.parameters()).device feats = feats.unsqueeze(0).to(device) timepoints = timepoints.unsqueeze(0).to(device) coords = coords.unsqueeze(0).to(device) # Concat timepoints to coordinates coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2) batch["coords_t"] = coords batch["features_t"] = feats with torch.no_grad(): A = self.common_step(batch) torch.cuda.empty_cache() gc.collect() A = self.track_model.normalize_output(A, timepoints, coords) # # Spatially 
        # Spatially far entries should not influence the causal normalization
        # dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
        # invalid = dist > model.config["spatial_pos_cutoff"]
        # A[invalid] = -torch.inf

        A = A.squeeze(0).detach().cpu().numpy()
        del feats, coords, timepoints, batch
        return A

    # @profile
    def predict_windows(
        self,
        windows: List[dict],
        features: list,
        model,
        imgs_enc: Optional[np.ndarray] = None,
        imgs_stable: Optional[np.ndarray] = None,
        intra_window_weight: float = 0,
        delta_t: int = 1,
        edge_threshold: float = 0.05,
        spatial_dim: int = 3,
        progbar_class=tqdm,
    ) -> dict:
        # first get all objects/coords
        time_labels_to_id = dict()
        node_properties = list()
        max_id = np.sum([len(f.labels) for f in features])

        all_timepoints = np.concatenate([f.timepoints for f in features])
        all_labels = np.concatenate([f.labels for f in features])
        all_coords = np.concatenate([f.coords for f in features])
        all_coords = all_coords[:, -spatial_dim:]

        for i, (t, la, c) in enumerate(zip(all_timepoints, all_labels, all_coords)):
            time_labels_to_id[(t, la)] = i
            node_properties.append(
                dict(
                    id=i,
                    coords=tuple(c),
                    time=t,
                    # index=ix,
                    label=la,
                )
            )

        # create assoc matrix between ids
        sp_weights, sp_accum = (
            csr_array((max_id, max_id), dtype=np.float32),
            csr_array((max_id, max_id), dtype=np.float32),
        )

        tracemalloc.start()
        for t in progbar_class(
            range(len(windows)),
            desc="Computing associations",
        ):
            # This assumes that the samples in the dataset are ordered by time and start at 0.
            batch = windows[t]
            timepoints = batch["timepoints"]
            labels = batch["labels"]
            A = self._predict_batch(batch)

            dt = timepoints[None, :] - timepoints[:, None]
            time_mask = np.logical_and(dt <= delta_t, dt > 0)
            A[~time_mask] = 0

            ii, jj = np.where(A >= edge_threshold)
            if len(ii) == 0:
                continue
            labels_ii = labels[ii]
            labels_jj = labels[jj]
            ts_ii = timepoints[ii]
            ts_jj = timepoints[jj]
            nodes_ii = np.array(
                tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_ii, labels_ii))
            )
            nodes_jj = np.array(
                tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_jj, labels_jj))
            )

            # weight middle parts higher
            t_middle = t + (model.config["window"] - 1) / 2
            ddt = timepoints[:, None] - t_middle * np.ones_like(dt)
            window_weight = np.exp(-intra_window_weight * ddt**2)  # default is 1
            # window_weight = np.exp(4*A)  # smooth max

            sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj]
            sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj]

            del batch, A, ii, jj, labels_ii, labels_jj, ts_ii, ts_jj, nodes_ii, nodes_jj, dt, time_mask
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

        sp_weights_coo = sp_weights.tocoo()
        sp_accum_coo = sp_accum.tocoo()
        assert np.allclose(sp_weights_coo.col, sp_accum_coo.col) and np.allclose(
            sp_weights_coo.row, sp_accum_coo.row
        )

        # Normalize weights by the number of times they were written from different sliding window positions
        weights = tuple(
            ((i, j), v / a)
            for i, j, v, a in zip(
                sp_weights_coo.row,
                sp_weights_coo.col,
                sp_weights_coo.data,
                sp_accum_coo.data,
            )
        )

        results = dict()
        results["nodes"] = node_properties
        results["weights"] = weights
        return results

    def _predict(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        imgs_enc: Optional[np.ndarray] = None,
        imgs_stable: Optional[np.ndarray] = None,
        boxes: Optional[np.ndarray] = None,
        edge_threshold: float = 0.05,
        n_workers: int = 0,
        normalize_imgs: bool = True,
        progbar_class=tqdm,
    ):
        print("Predicting weights for candidate graph")
        if normalize_imgs:
            if isinstance(imgs, da.Array):
                imgs = imgs.map_blocks(normalize)
            else:
                imgs = normalize(imgs)
        self.eval()

        features = get_features(
            detections=masks,
            imgs=imgs,
            ndim=self.track_model.config["coord_dim"],
            n_workers=n_workers,
            progbar_class=progbar_class,
        )
        print("Building windows")
        windows = build_windows_sd(
            features,
            imgs_enc=imgs_enc,
            imgs_stable=imgs_stable,
            boxes=boxes,
            imgs=imgs,
            masks=masks,
            window_size=self.track_model.config["window"],
            progbar_class=progbar_class,
        )

        print("Predicting windows")
        with torch.no_grad():
            predictions = self.predict_windows(
                windows=windows,
                features=features,
                imgs_enc=imgs_enc,
                imgs_stable=imgs_stable,
                model=self.track_model,
                edge_threshold=edge_threshold,
                spatial_dim=masks.ndim - 1,
                progbar_class=progbar_class,
            )
        return predictions

    def _track_from_predictions(
        self,
        predictions,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        use_distance: bool = False,
        max_distance: int = 256,
        max_neighbors: int = 10,
        delta_t: int = 1,
        **kwargs,
    ):
        print("Running greedy tracker")
        nodes = predictions["nodes"]
        weights = predictions["weights"]
        candidate_graph = build_graph(
            nodes=nodes,
            weights=weights,
            use_distance=use_distance,
            max_distance=max_distance,
            max_neighbors=max_neighbors,
            delta_t=delta_t,
        )
        if mode == "greedy":
            return track_greedy(candidate_graph)
        elif mode == "greedy_nodiv":
            return track_greedy(candidate_graph, allow_divisions=False)
        elif mode == "ilp":
            from models.tra_post_model.trackastra.tracking.ilp import track_ilp

            return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
        else:
            raise ValueError(f"Tracking mode {mode} does not exist.")

    def track(
        self,
        file_dir: str,
        boxes: Optional[torch.Tensor] = None,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        progbar_class=tqdm,
        n_workers: int = 0,
        dataname: Optional[str] = None,
        **kwargs,
    ) -> TrackGraph:
        """Track objects across time frames.

        This method segments each frame, then links the segmented objects across
        time frames using the specified tracking mode. No hyperparameters need to
        be chosen beyond the tracking mode.

        Args:
            file_dir: Directory containing the input image sequence of shape (T,(Z),Y,X).
            boxes: Optional exemplar boxes for the first frame (only used when `use_box` is set).
            mode: Tracking mode:
                - "greedy_nodiv": Fast greedy linking without division
                - "greedy": Fast greedy linking with division
                - "ilp": Integer Linear Programming based linking (more accurate but slower)
            normalize_imgs: Whether to normalize the images.
            progbar_class: Progress bar class to use.
            n_workers: Number of worker processes for feature extraction.
            dataname: Optional dataset name (used for output paths).
            **kwargs: Additional arguments passed to the tracking algorithm.

        Returns:
            TrackGraph containing the tracking results, and the predicted masks.
        """
""" self.eval() imgs, imgs_raw, images_stable, tra_imgs, imgs_01, height, width = load_track_images(file_dir) # tra_imgs = torch.from_numpy(imgs_).float().to(self.device) imgs_stable = torch.from_numpy(images_stable).float().to(self.device) imgs_enc = torch.from_numpy(imgs).float().to(self.device) """get segmentation masks first""" self.boxes = None self.adapt_emb = None masks = [] for i, (input_image, input_image_stable) in tqdm(enumerate(zip(imgs_enc, imgs_stable))): input_image = input_image.unsqueeze(0) input_image_stable = input_image_stable.unsqueeze(0) if i == 0: if self.use_box and boxes is not None: self.boxes = boxes.to(self.device) else: self.boxes = None with torch.no_grad(): mask = self.forward_sd(input_image_stable, input_image, self.boxes, height=height, width=width) masks.append(mask) masks = np.stack(masks, axis=0) # (T, H, W) # ------------------------- if not masks.shape == tra_imgs.shape: raise RuntimeError( f"Img shape {tra_imgs.shape} and mask shape {masks.shape} do not match." ) if not tra_imgs.ndim == self.track_model.config["coord_dim"] + 1: raise RuntimeError( f"images should be a sequence of {self.track_model.config['coord_dim']}D images" ) predictions = self._predict( tra_imgs, masks, imgs_enc=imgs_enc, imgs_stable=imgs_stable, boxes=boxes, normalize_imgs=normalize_imgs, progbar_class=progbar_class, n_workers=n_workers, ) track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs) # ctc_tracks, masks_tracked = graph_to_ctc( # track_graph, # masks, # outdir=f"tracked/{dataname}", # ) return track_graph, masks def inference(data_path, box=None): if box is not None: use_box = True else: use_box = False model = TrackingModule(use_box=use_box) load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_tra.pth"), strict=True) model.move_to_device(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')) track_graph, masks = model.track(file_dir=data_path, dataname="inference_sequence") if not os.path.exists(f"tracked_ours_seg_pred3/"): os.makedirs(f"tracked_ours_seg_pred3/") ctc_tracks, masks_tracked = graph_to_ctc( track_graph, masks, outdir=f"tracked_ours_seg_pred3/", ) if __name__ == "__main__": inference(data_path="example_imgs/2D+Time/Fluo-N2DL-HeLa/train/Fluo-N2DL-HeLa/02")