File size: 4,017 Bytes
ef3931a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f54ecf7
ef3931a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f54ecf7
 
 
 
 
 
 
ef3931a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106

### inference_script.py
import os
import torch
from models.model import Counting_with_SD_features_dino_vit_c3 as Counting
from _utils.load_models import load_stable_diffusion_model
from models.enc_model.loca_args import get_argparser as loca_get_argparser
from models.enc_model.loca import build_model as build_loca_model
from _utils.attn_utils import AttentionStore
from _utils import attn_utils_new as attn_utils
from _utils.misc_helper import *
import numpy as np
import cv2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
from torchvision import transforms as T
from skimage import measure
import warnings
from huggingface_hub import hf_hub_download

warnings.filterwarnings("ignore")

class CountingModule(torch.nn.Module):
    """Counting/segmentation pipeline built on Stable Diffusion features.

    Combines a LOCA encoder, a counting adapter head, and a Stable Diffusion
    backbone whose cross-attention maps are captured via an AttentionStore.
    A learned placeholder token ("<task-prompt>") is injected into the
    tokenizer/text-encoder to condition the diffusion features on the
    counting task.
    """

    def __init__(self, use_box=True):
        super().__init__()
        self.use_box = use_box
        # RunConfig is presumably provided by the `_utils.misc_helper` star
        # import above — TODO confirm.
        self.config = RunConfig()
        self.initialize_model()

    def initialize_model(self):
        """Build all sub-models and register the task placeholder token.

        Raises:
            ValueError: if "<task-prompt>" already exists in the tokenizer
                vocabulary (add_tokens returns the number of tokens added;
                0 means nothing new was added).
        """
        self.loca_model = build_loca_model(loca_get_argparser().parse_args([]))
        self.counting_adapter = Counting(scale_factor=1)
        self.stable = load_stable_diffusion_model(config=self.config)
        self.controller = AttentionStore(max_size=64)
        # Hook the diffusion model so attention maps are recorded during
        # the forward pass.
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)
        self.placeholder_token = "<task-prompt>"
        self.task_token = "repetitive objects"
        token_id = self.stable.tokenizer.add_tokens(self.placeholder_token)
        if token_id == 0:
            raise ValueError("Placeholder token already exists")
        self.placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(self.placeholder_token)
        # Grow the embedding table to cover the newly added token.
        self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
        embed = self.stable.text_encoder.get_input_embeddings().weight.data
        if os.path.exists("pretrained/task_embed.pth"):
            # map_location="cpu" so a checkpoint saved from a CUDA tensor
            # still loads on CPU-only hosts (original call would raise there).
            embed[self.placeholder_token_id] = torch.load(
                "pretrained/task_embed.pth", map_location="cpu"
            )

    def forward(self, data_path, box=None):
        # simplified forward, returns only segmentation map and overlay
        # full forward like your original can be added here
        return {
            "img": np.zeros((512, 512, 3)),
            "pred": np.zeros((512, 512))
        }

def inference(data_path, box=None, visualize=True):
    """Run counting/segmentation inference on a single image.

    Args:
        data_path: path to the input image.
        box: optional iterable of (x0, y0, x1, y1) exemplar boxes; when
            given, boxes are drawn on the visualization.
        visualize: when True, save a side-by-side figure under
            "example_imgs/" and return its path; otherwise return the raw
            predicted mask.

    Returns:
        str path to the saved figure when ``visualize`` is True, else the
        predicted mask (numpy array).
    """
    use_box = box is not None
    model = CountingModule(use_box=use_box)

    # Fetch the pretrained checkpoint from the Hub (cached after first call).
    ckpt_path = hf_hub_download(
        repo_id="Shengxiao0709/cellsegmodel",
        filename="microscopy_matching_seg.pth"
    )

    # strict=False: the checkpoint covers only a subset of the module's
    # parameters (e.g. the frozen diffusion backbone is not stored).
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
    model.eval()
    with torch.no_grad():
        output = model(data_path, box)

    img = output["img"]
    mask = output["pred"]

    if visualize:
        filename = os.path.basename(data_path)
        _, ax = plt.subplots(1, 2, figsize=(12, 6))
        ax[0].imshow(img)
        if use_box:
            for b in box:
                rect = patches.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1],
                                         linewidth=2, edgecolor='r', facecolor='none')
                ax[0].add_patch(rect)
        ax[0].set_title("Input")
        ax[0].axis("off")
        ax[1].imshow(overlay_instances(img, mask, alpha=0.3))
        ax[1].set_title("Segmentation")
        ax[1].axis("off")
        # splitext (not split(".")[0]) so filenames containing dots keep
        # their full stem, e.g. "img.v2.png" -> "img.v2_seg.png".
        stem = os.path.splitext(filename)[0]
        out_path = os.path.join("example_imgs", stem + "_seg.png")
        os.makedirs("example_imgs", exist_ok=True)
        plt.savefig(out_path)
        plt.close()
        return out_path
    return mask

def overlay_instances(img, mask, alpha=0.5):
    """Blend a distinct color onto each labeled instance of ``mask``.

    Args:
        img: HxWx3 image, uint8 (0-255) or float (0-1).
        mask: HxW integer instance-label map; 0 is background.
        alpha: blend weight of the instance color (0 = image only).

    Returns:
        float32 HxWx3 image in [0, 1] with instances tinted.
    """
    # cm.get_cmap(name, lut) was deprecated in matplotlib 3.7 and removed
    # in 3.9; colormaps[...].resampled(...) is the supported equivalent.
    from matplotlib import colormaps
    img = img.astype(np.float32)
    if img.max() > 1.5:  # heuristic: values look like 0-255, normalize
        img = img / 255.0
    overlay = img.copy()
    n_colors = int(np.max(mask)) + 1
    cmap = colormaps["tab20"].resampled(n_colors)
    for inst_id in np.unique(mask):
        if inst_id == 0:
            continue  # skip background
        color = np.array(cmap(int(inst_id))[:3])
        sel = mask == inst_id
        overlay[sel] = (1 - alpha) * overlay[sel] + alpha * color
    return overlay