### inference_script.py
# NOTE: the import order below is deliberately kept as in the original file.
# `from _utils.misc_helper import *` may rebind names, and the imports that
# follow it currently take precedence over whatever it exports.
import os
import torch
from models.model import Counting_with_SD_features_dino_vit_c3 as Counting
from _utils.load_models import load_stable_diffusion_model
from models.enc_model.loca_args import get_argparser as loca_get_argparser
from models.enc_model.loca import build_model as build_loca_model
from _utils.attn_utils import AttentionStore
from _utils import attn_utils_new as attn_utils
from _utils.misc_helper import *
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from torchvision import transforms as T
from skimage import measure
import warnings
from huggingface_hub import hf_hub_download

warnings.filterwarnings("ignore")


class CountingModule(torch.nn.Module):
    """Bundle the LOCA model, counting adapter and Stable Diffusion pipeline.

    Args:
        use_box: Stored on the instance as ``self.use_box``; it is not read
            anywhere else in this file.
    """

    def __init__(self, use_box=True):
        super().__init__()
        self.use_box = use_box
        # RunConfig is not defined in this file; presumably it is exported by
        # the `_utils.misc_helper` star import -- TODO confirm.
        self.config = RunConfig()
        self.initialize_model()

    def initialize_model(self):
        """Build the sub-models and register the task placeholder token.

        Raises:
            ValueError: If the tokenizer reports that no token was added,
                i.e. the placeholder token is already in its vocabulary.
        """
        self.loca_model = build_loca_model(loca_get_argparser().parse_args([]))
        self.counting_adapter = Counting(scale_factor=1)
        self.stable = load_stable_diffusion_model(config=self.config)

        self.controller = AttentionStore(max_size=64)
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)

        # NOTE(review): the placeholder token is an empty string. That looks
        # like a "<...>"-style token that was stripped somewhere upstream --
        # confirm the intended value against the training code.
        self.placeholder_token = ""
        self.task_token = "repetitive objects"

        # `add_tokens` returns how many tokens were added, not a token id.
        num_added = self.stable.tokenizer.add_tokens(self.placeholder_token)
        if num_added == 0:
            raise ValueError("Placeholder token already exists")
        self.placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(
            self.placeholder_token
        )

        # Grow the embedding table so the new token id has a row.
        self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
        embed = self.stable.text_encoder.get_input_embeddings().weight.data

        task_embed_path = "pretrained/task_embed.pth"
        if os.path.exists(task_embed_path):
            # map_location="cpu" lets a tensor saved from a CUDA device be
            # loaded on hosts without CUDA; the indexed assignment then copies
            # it onto whatever device `embed` lives on.
            embed[self.placeholder_token_id] = torch.load(
                task_embed_path, map_location="cpu"
            )

    def forward(self, data_path, box=None):
        """Return a placeholder result; ``data_path`` and ``box`` are unused.

        Returns:
            dict: ``"img"`` is a zero array of shape (512, 512, 3) and
            ``"pred"`` is a zero array of shape (512, 512).
        """
        # simplified forward, returns only segmentation map and overlay
        # full forward like your original can be added here
        return {
            "img": np.zeros((512, 512, 3)),
            "pred": np.zeros((512, 512)),
        }


def inference(data_path, box=None, visualize=True):
    """Run the counting model on ``data_path`` and optionally save a figure.

    A fresh ``CountingModule`` is built and the checkpoint is fetched through
    ``hf_hub_download`` on every call.

    Args:
        data_path: Path of the input image. It is passed to the model and its
            base name is used to name the saved figure.
        box: Optional iterable of boxes indexed as ``b[0]..b[3]``; they are
            drawn as rectangles from ``(b[0], b[1])`` with width
            ``b[2] - b[0]`` and height ``b[3] - b[1]``. ``None`` disables
            box drawing.
        visualize: When true, write a side-by-side figure to ``example_imgs/``.

    Returns:
        The path of the saved figure when ``visualize`` is true, otherwise the
        predicted mask (``output["pred"]``).
    """
    use_box = box is not None
    model = CountingModule(use_box=use_box)

    ckpt_path = hf_hub_download(
        repo_id="Shengxiao0709/cellsegmodel",
        filename="microscopy_matching_seg.pth",
    )
    # NOTE(security): torch.load unpickles the downloaded file; only use
    # checkpoints from a repository you trust.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict, strict=False)
    model.eval()

    with torch.no_grad():
        output = model(data_path, box)
    img = output["img"]
    mask = output["pred"]

    if not visualize:
        return mask

    # splitext removes only the final extension, so "a.1.png" and "a.2.png"
    # no longer collapse onto the same output name.
    stem = os.path.splitext(os.path.basename(data_path))[0]
    os.makedirs("example_imgs", exist_ok=True)
    out_path = os.path.join("example_imgs", stem + "_seg.png")

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    try:
        ax[0].imshow(img)
        if use_box:
            for b in box:
                rect = patches.Rectangle(
                    (b[0], b[1]),
                    b[2] - b[0],
                    b[3] - b[1],
                    linewidth=2,
                    edgecolor='r',
                    facecolor='none',
                )
                ax[0].add_patch(rect)
        ax[0].set_title("Input")
        ax[0].axis("off")

        ax[1].imshow(overlay_instances(img, mask, alpha=0.3))
        ax[1].set_title("Segmentation")
        ax[1].axis("off")

        fig.savefig(out_path)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)
    return out_path


def overlay_instances(img, mask, alpha=0.5):
    """Blend a per-instance colour over ``img`` wherever ``mask`` is non-zero.

    Args:
        img: Image array. If its maximum exceeds 1.5 it is divided by 255.
        mask: Label array; 0 is skipped, every other value is one instance.
            Non-integer masks are rounded to the nearest integer label.
        alpha: Weight of the instance colour in the blend.

    Returns:
        A float32 copy of ``img`` with the instance colours blended in.
    """
    img = np.asarray(img).astype(np.float32)
    if img.size and img.max() > 1.5:
        img = img / 255.0
    overlay = img.copy()

    labels = np.asarray(mask)
    if not np.issubdtype(labels.dtype, np.integer):
        # A colormap treats float arguments as fractions in [0, 1], which
        # would map every label >= 1 to the same last colour. Integer labels
        # are used as direct indices into the lookup table instead.
        labels = np.rint(labels.astype(np.float64)).astype(np.int64)

    instance_ids = np.unique(labels)
    instance_ids = instance_ids[instance_ids != 0]
    if instance_ids.size == 0:
        return overlay

    # plt.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib
    # releases no longer provide.
    n_colors = max(int(labels.max()) + 1, 1)
    cmap = plt.get_cmap("tab20", n_colors)

    for inst_id in instance_ids:
        selected = labels == inst_id
        color = np.array(cmap(int(inst_id))[:3])
        overlay[selected] = (1 - alpha) * overlay[selected] + alpha * color
    return overlay