Shengxiao0709 committed on
Commit ef3931a · verified · 1 Parent(s): 138a2b4

Upload 3 files

Files changed (3)
  1. app.py +30 -0
  2. inference_script.py +98 -0
  3. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,30 @@
+ import gradio as gr
+ from inference_script import inference
+ import os
+
+ def segment_cell(img, use_box, box_coords):
+     if use_box and box_coords:
+         try:
+             box = [[float(x) for x in box_coords.strip().split(",")]]
+         except ValueError:
+             raise gr.Error("Invalid box format. Use: xmin,ymin,xmax,ymax")
+     else:
+         box = None
+
+     result_path = inference(img, box=box, visualize=True)
+     return result_path
+
+ iface = gr.Interface(
+     fn=segment_cell,
+     inputs=[
+         gr.Image(type="filepath", label="Upload Microscopy Image"),
+         gr.Checkbox(label="Use Bounding Box Guidance?"),
+         gr.Textbox(label="Bounding Box [xmin,ymin,xmax,ymax]", placeholder="e.g., 724,864,900,966"),
+     ],
+     outputs=gr.Image(type="filepath", label="Segmentation Output"),
+     title="Cell Segmentation with Attention-guided Diffusion",
+     description="Upload a microscopy image. Optionally, input a bounding box to guide the segmentation."
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
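
Note: the bounding-box textbox is free text, so `segment_cell` converts it into the nested-list format that `inference()` expects. A minimal sketch of that parsing (`parse_box` is a hypothetical helper, not part of this upload):

    # Hypothetical helper mirroring segment_cell's parsing: the string
    # "724,864,900,966" becomes [[724.0, 864.0, 900.0, 966.0]],
    # i.e. a list containing one [xmin, ymin, xmax, ymax] box.
    def parse_box(box_coords: str) -> list:
        return [[float(x) for x in box_coords.strip().split(",")]]

    assert parse_box("724,864,900,966") == [[724.0, 864.0, 900.0, 966.0]]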
inference_script.py ADDED
@@ -0,0 +1,98 @@
+
+ ### inference_script.py
+ import os
+ import torch
+ from models.model import Counting_with_SD_features_dino_vit_c3 as Counting
+ from _utils.load_models import load_stable_diffusion_model
+ from models.enc_model.loca_args import get_argparser as loca_get_argparser
+ from models.enc_model.loca import build_model as build_loca_model
+ from _utils.attn_utils import AttentionStore
+ from _utils import attn_utils_new as attn_utils
+ from _utils.misc_helper import *  # provides RunConfig, among others
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ import matplotlib.patches as patches
+ from PIL import Image
+ from torchvision import transforms as T
+ from skimage import measure
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
+ class CountingModule(torch.nn.Module):
+     def __init__(self, use_box=True):
+         super().__init__()
+         self.use_box = use_box
+         self.config = RunConfig()
+         self.initialize_model()
+
+     def initialize_model(self):
+         self.loca_model = build_loca_model(loca_get_argparser().parse_args([]))
+         self.counting_adapter = Counting(scale_factor=1)
+         self.stable = load_stable_diffusion_model(config=self.config)
+         self.controller = AttentionStore(max_size=64)
+         attn_utils.register_attention_control(self.stable, self.controller)
+         attn_utils.register_hier_output(self.stable)
+         self.placeholder_token = "<task-prompt>"
+         self.task_token = "repetitive objects"
+         token_id = self.stable.tokenizer.add_tokens(self.placeholder_token)  # returns number of tokens added
+         if token_id == 0:
+             raise ValueError("Placeholder token already exists")
+         self.placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(self.placeholder_token)
+         self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
+         embed = self.stable.text_encoder.get_input_embeddings().weight.data
+         if os.path.exists("pretrained/task_embed.pth"):
+             embed[self.placeholder_token_id] = torch.load("pretrained/task_embed.pth")
+
+     def forward(self, data_path, box=None):
+         # Simplified forward: returns only the input image and segmentation map.
+         # The full forward pass from the original model can be added here.
+         return {
+             "img": np.zeros((512, 512, 3)),
+             "pred": np.zeros((512, 512), dtype=np.int32)
+         }
+
+ def inference(data_path, box=None, visualize=True):
+     use_box = box is not None
+     model = CountingModule(use_box=use_box)
+     model.load_state_dict(torch.load("pretrained/microscopy_matching_seg.pth", map_location="cpu"), strict=False)
+     model.eval()
+     with torch.no_grad():
+         output = model(data_path, box)
+
+     img = output["img"]
+     mask = output["pred"]
+
+     if visualize:
+         filename = os.path.basename(data_path)
+         fig, ax = plt.subplots(1, 2, figsize=(12, 6))
+         ax[0].imshow(img)
+         if use_box:
+             for b in box:
+                 rect = patches.Rectangle((b[0], b[1]), b[2] - b[0], b[3] - b[1], linewidth=2, edgecolor='r', facecolor='none')
+                 ax[0].add_patch(rect)
+         ax[0].set_title("Input")
+         ax[0].axis("off")
+         ax[1].imshow(overlay_instances(img, mask, alpha=0.3))
+         ax[1].set_title("Segmentation")
+         ax[1].axis("off")
+         out_path = os.path.join("example_imgs", os.path.splitext(filename)[0] + "_seg.png")
+         os.makedirs("example_imgs", exist_ok=True)
+         plt.savefig(out_path)
+         plt.close()
+         return out_path
+     return mask
+
+ def overlay_instances(img, mask, alpha=0.5):
+     from matplotlib import cm
+     img = img.astype(np.float32)
+     if img.max() > 1.5:  # assume 0-255 range; rescale to 0-1
+         img = img / 255.0
+     overlay = img.copy()
+     cmap = cm.get_cmap("tab20", int(np.max(mask)) + 1)
+     for inst_id in np.unique(mask):
+         if inst_id == 0: continue  # 0 is background
+         color = np.array(cmap(inst_id)[:3])
+         overlay[mask == inst_id] = (1 - alpha) * overlay[mask == inst_id] + alpha * color
+     return overlay
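
Note: `inference()` can be smoke-tested without the Gradio UI. A minimal sketch, assuming the checkpoint `pretrained/microscopy_matching_seg.pth` is present; the image path and box values below are placeholders:

    # Hypothetical direct call to inference(); adjust the path to a real image.
    if __name__ == "__main__":
        out_path = inference(
            "example_imgs/sample.png",           # placeholder image path
            box=[[724.0, 864.0, 900.0, 966.0]],  # one [xmin, ymin, xmax, ymax] box
            visualize=True,
        )
        print("Saved side-by-side visualization to:", out_path)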
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ torch
+ pytorch-lightning
+ diffusers
+ transformers
+ accelerate
+ gradio
+ Pillow
+ opencv-python
+ matplotlib
+ scikit-image
+ easydict
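
Note: none of the versions are pinned. A quick import smoke test after `pip install -r requirements.txt` (several distribution names differ from their import names):

    # Import smoke test for the listed dependencies; pytorch-lightning, Pillow,
    # opencv-python, and scikit-image import as pytorch_lightning, PIL, cv2,
    # and skimage respectively.
    import torch, pytorch_lightning, diffusers, transformers, accelerate
    import gradio, PIL, cv2, matplotlib, skimage, easydict
    print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())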