# NOTE(review): this file was recovered from a Hugging Face Spaces page;
# the page's "Spaces: Runtime error" banner text has been removed.
| ### inference_script.py | |
| import os | |
| import torch | |
| from models.model import Counting_with_SD_features_dino_vit_c3 as Counting | |
| from _utils.load_models import load_stable_diffusion_model | |
| from models.enc_model.loca_args import get_argparser as loca_get_argparser | |
| from models.enc_model.loca import build_model as build_loca_model | |
| from _utils.attn_utils import AttentionStore | |
| from _utils import attn_utils_new as attn_utils | |
| from _utils.misc_helper import * | |
| import numpy as np | |
| import cv2 | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as patches | |
| from PIL import Image | |
| from torchvision import transforms as T | |
| from skimage import measure | |
| import warnings | |
| from huggingface_hub import hf_hub_download | |
| warnings.filterwarnings("ignore") | |
class CountingModule(torch.nn.Module):
    """Object-counting / segmentation pipeline built on Stable Diffusion features.

    Composes a LOCA encoder, a counting adapter head, and a Stable Diffusion
    backbone with attention hooks, and registers a learnable task-prompt token
    in the tokenizer/text-encoder vocabulary.
    """

    def __init__(self, use_box=True):
        """
        Parameters
        ----------
        use_box : bool
            Whether inference is guided by exemplar boxes.
        """
        super().__init__()
        self.use_box = use_box
        self.config = RunConfig()
        self.initialize_model()

    def initialize_model(self):
        """Build sub-models, hook attention stores, and install the task token."""
        self.loca_model = build_loca_model(loca_get_argparser().parse_args([]))
        self.counting_adapter = Counting(scale_factor=1)
        self.stable = load_stable_diffusion_model(config=self.config)
        self.controller = AttentionStore(max_size=64)
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)
        self.placeholder_token = "<task-prompt>"
        self.task_token = "repetitive objects"
        # add_tokens returns the NUMBER of tokens actually added (not an id);
        # 0 means the placeholder string was already in the vocabulary.
        num_added = self.stable.tokenizer.add_tokens(self.placeholder_token)
        if num_added == 0:
            raise ValueError("Placeholder token already exists")
        self.placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(
            self.placeholder_token
        )
        # Grow the embedding matrix to cover the newly added token.
        self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
        embed = self.stable.text_encoder.get_input_embeddings().weight.data
        if os.path.exists("pretrained/task_embed.pth"):
            # map_location="cpu" so an embedding saved on GPU still loads
            # on CPU-only hosts; it is moved to the model's device implicitly
            # by the in-place assignment.
            embed[self.placeholder_token_id] = torch.load(
                "pretrained/task_embed.pth", map_location="cpu"
            )

    def forward(self, data_path, box=None):
        """Simplified forward: returns zero-filled image and segmentation map.

        The full forward pass (diffusion feature extraction + counting head)
        is intentionally stubbed out here; this returns fixed 512x512 zeros.
        """
        return {
            "img": np.zeros((512, 512, 3)),
            "pred": np.zeros((512, 512)),
        }
def inference(data_path, box=None, visualize=True):
    """Run counting/segmentation inference on one image.

    Parameters
    ----------
    data_path : str
        Path to the input image.
    box : sequence of (x0, y0, x1, y1) or None
        Optional exemplar boxes; when given, the model runs box-guided.
    visualize : bool
        When True, save a side-by-side input/segmentation figure under
        ``example_imgs/`` and return its path; otherwise return the mask.

    Returns
    -------
    str or numpy.ndarray
        The saved figure path when ``visualize`` is True, else the raw mask.
    """
    use_box = box is not None
    model = CountingModule(use_box=use_box)
    # Fetch pretrained weights from the HF Hub (cached after the first call).
    ckpt_path = hf_hub_download(
        repo_id="Shengxiao0709/cellsegmodel",
        filename="microscopy_matching_seg.pth",
    )
    # strict=False: checkpoint covers only a subset of the module's params.
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
    model.eval()
    with torch.no_grad():
        output = model(data_path, box)
    img = output["img"]
    mask = output["pred"]
    if visualize:
        filename = os.path.basename(data_path)
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))
        ax[0].imshow(img)
        if use_box:
            for b in box:
                rect = patches.Rectangle(
                    (b[0], b[1]), b[2] - b[0], b[3] - b[1],
                    linewidth=2, edgecolor='r', facecolor='none',
                )
                ax[0].add_patch(rect)
        ax[0].set_title("Input")
        ax[0].axis("off")
        ax[1].imshow(overlay_instances(img, mask, alpha=0.3))
        ax[1].set_title("Segmentation")
        ax[1].axis("off")
        # splitext (not split(".")) so dotted filenames like "img.v2.png"
        # keep their full stem instead of being truncated at the first dot.
        out_path = os.path.join(
            "example_imgs", os.path.splitext(filename)[0] + "_seg.png"
        )
        os.makedirs("example_imgs", exist_ok=True)
        plt.savefig(out_path)
        plt.close()
        return out_path
    return mask
def overlay_instances(img, mask, alpha=0.5):
    """Alpha-blend a distinct color onto the image for each instance in a label mask.

    Parameters
    ----------
    img : numpy.ndarray
        HxWx3 image, either 0-255 or 0-1 range (detected heuristically).
    mask : numpy.ndarray
        HxW instance-label map; 0 is background, positive ids are instances.
    alpha : float
        Blend weight of the instance color (0 = invisible, 1 = opaque).

    Returns
    -------
    numpy.ndarray
        Float image in [0, 1] with instances tinted.
    """
    from matplotlib import cm
    img = img.astype(np.float32)
    if img.max() > 1.5:
        # Heuristic: values above 1.5 imply a 0-255 image; normalize to 0-1.
        img = img / 255.0
    overlay = img.copy()
    # int() cast: np.max on a float mask returns a float, but the colormap
    # LUT size must be an integer.
    n_colors = int(np.max(mask)) + 1
    # NOTE(review): cm.get_cmap(name, lut) is deprecated since matplotlib 3.7
    # (removed in 3.9); migrate to matplotlib.colormaps["tab20"].resampled(n)
    # when the pinned matplotlib version allows.
    cmap = cm.get_cmap("tab20", n_colors)
    for inst_id in np.unique(mask):
        if inst_id == 0:
            continue
        # Cast to int: calling a colormap with a float treats the value as a
        # 0-1 fraction instead of a LUT index, so float-typed masks (the
        # default np.zeros dtype) would get wrong/clipped colors.
        color = np.array(cmap(int(inst_id))[:3])
        sel = mask == inst_id
        overlay[sel] = (1 - alpha) * overlay[sel] + alpha * color
    return overlay