Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,42 +1,60 @@
|
|
| 1 |
import os
|
| 2 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
| 3 |
import torch
|
| 4 |
import gradio as gr
|
| 5 |
from gradio_image_prompter import ImagePrompter
|
| 6 |
-
from torch.nn import DataParallel
|
| 7 |
from models.counter_infer import build_model
|
| 8 |
from utils.arg_parser import get_argparser
|
| 9 |
from utils.data import resize_and_pad
|
| 10 |
import torchvision.ops as ops
|
| 11 |
from torchvision import transforms as T
|
| 12 |
-
from PIL import Image, ImageDraw
|
| 13 |
import numpy as np
|
| 14 |
|
| 15 |
-
#
|
|
|
|
|
|
|
| 16 |
def load_model():
|
| 17 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 18 |
-
|
|
|
|
| 19 |
args.zero_shot = True
|
| 20 |
-
|
| 21 |
-
model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
model.eval()
|
|
|
|
| 23 |
return model, device
|
| 24 |
|
| 25 |
model, device = load_model()
|
| 26 |
|
| 27 |
-
#
|
|
|
|
|
|
|
| 28 |
def process_image_once(inputs, enable_mask):
|
| 29 |
-
model.
|
| 30 |
|
| 31 |
image = inputs['image']
|
| 32 |
drawn_boxes = inputs['points']
|
|
|
|
| 33 |
image_tensor = torch.tensor(image).to(device)
|
| 34 |
image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
|
| 35 |
-
image_tensor = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image_tensor)
|
| 36 |
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
|
|
|
|
| 40 |
img = img.unsqueeze(0).to(device)
|
| 41 |
bboxes = bboxes.unsqueeze(0).to(device)
|
| 42 |
|
|
@@ -45,69 +63,44 @@ def process_image_once(inputs, enable_mask):
|
|
| 45 |
|
| 46 |
return image, outputs, masks, img, scale, drawn_boxes
|
| 47 |
|
| 48 |
-
#
|
|
|
|
|
|
|
| 49 |
def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
|
| 50 |
idx = 0
|
| 51 |
-
threshold = 1/threshold
|
| 52 |
-
|
| 53 |
-
|
|
|
|
| 54 |
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
pred_boxes = torch.clamp(pred_boxes, 0, 1)
|
| 57 |
|
| 58 |
pred_boxes = (pred_boxes.cpu() / scale * img.shape[-1]).tolist()
|
| 59 |
|
| 60 |
-
image = Image.fromarray(
|
| 61 |
-
|
| 62 |
-
if enable_mask:
|
| 63 |
-
from matplotlib import pyplot as plt
|
| 64 |
-
masks_ = masks[idx][(outputs[idx]['box_v'] > outputs[idx]['box_v'].max() / threshold)[0]]
|
| 65 |
-
N_masks = masks_.shape[0]
|
| 66 |
-
indices = torch.randint(1, N_masks + 1, (1, N_masks), device=masks_.device).view(-1, 1, 1)
|
| 67 |
-
masks = (masks_ * indices).sum(dim=0)
|
| 68 |
-
mask_display = (
|
| 69 |
-
T.Resize((int(img.shape[2] / scale), int(img.shape[3] / scale)), interpolation=T.InterpolationMode.NEAREST)(
|
| 70 |
-
masks.cpu().unsqueeze(0))[0])[:image.size[1], :image.size[0]]
|
| 71 |
-
cmap = plt.cm.tab20
|
| 72 |
-
norm = plt.Normalize(vmin=0, vmax=N_masks)
|
| 73 |
-
del masks
|
| 74 |
-
del masks_
|
| 75 |
-
del outputs
|
| 76 |
-
rgba_image = cmap(norm(mask_display))
|
| 77 |
-
rgba_image[mask_display == 0, -1] = 0
|
| 78 |
-
rgba_image[mask_display != 0, -1] = 0.5
|
| 79 |
-
|
| 80 |
-
overlay = Image.fromarray((rgba_image * 255).astype(np.uint8), mode="RGBA")
|
| 81 |
-
image = image.convert("RGBA")
|
| 82 |
-
image = Image.alpha_composite(image, overlay)
|
| 83 |
-
|
| 84 |
|
| 85 |
draw = ImageDraw.Draw(image)
|
|
|
|
| 86 |
for box in pred_boxes:
|
| 87 |
draw.rectangle([box[0], box[1], box[2], box[3]], outline="orange", width=5)
|
| 88 |
-
# for box in drawn_boxes:
|
| 89 |
-
# draw.rectangle([box[0], box[1], box[3], box[4]], outline="red", width=3)
|
| 90 |
-
|
| 91 |
-
width, height = image.size
|
| 92 |
-
square_size = int(0.05 * width)
|
| 93 |
-
x1, y1 = 10, height - square_size - 10
|
| 94 |
-
x2, y2 = x1 + square_size, y1 + square_size
|
| 95 |
-
|
| 96 |
-
# draw.rectangle([x1, y1, x2, y2], outline="black", fill="black", width=1)
|
| 97 |
-
# font = ImageFont.load_default()
|
| 98 |
-
# txt = str(len(pred_boxes))
|
| 99 |
-
# w = draw.textlength(txt, font=font)
|
| 100 |
-
# text_x = x1 + (square_size - w) / 2
|
| 101 |
-
# text_y = y1 + (square_size - 10) / 2
|
| 102 |
-
# draw.text((text_x, text_y), txt, fill="white", font=font)
|
| 103 |
|
| 104 |
return image, len(pred_boxes)
|
| 105 |
|
| 106 |
-
|
|
|
|
|
|
|
| 107 |
iface = gr.Blocks()
|
| 108 |
|
| 109 |
with iface:
|
| 110 |
-
# Store intermediate states
|
| 111 |
image_input = gr.State()
|
| 112 |
outputs_state = gr.State()
|
| 113 |
masks_state = gr.State()
|
|
@@ -115,45 +108,34 @@ with iface:
|
|
| 115 |
scale_state = gr.State()
|
| 116 |
drawn_boxes_state = gr.State()
|
| 117 |
|
| 118 |
-
# UI Layout: Input Section
|
| 119 |
with gr.Row():
|
| 120 |
image_prompter = ImagePrompter()
|
| 121 |
image_output = gr.Image(type="pil")
|
| 122 |
-
|
| 123 |
|
| 124 |
-
# UI Layout: Output Section
|
| 125 |
with gr.Row():
|
| 126 |
count_output = gr.Number(label="Total Count")
|
| 127 |
-
enable_mask = gr.Checkbox(label="Predict masks", value=True)
|
| 128 |
-
threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01
|
| 129 |
-
|
| 130 |
|
| 131 |
-
# Create the 'Count' button
|
| 132 |
count_button = gr.Button("Count")
|
| 133 |
|
| 134 |
-
# Process image once when "Count" button is pressed
|
| 135 |
def initial_process(inputs, enable_mask, threshold):
|
| 136 |
-
# Perform inference once
|
| 137 |
image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
|
| 138 |
|
| 139 |
-
# Save intermediate states
|
| 140 |
return (
|
| 141 |
-
*post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),
|
| 142 |
-
image, outputs, masks, img, scale, drawn_boxes
|
| 143 |
)
|
| 144 |
|
| 145 |
-
# Update image and count when the threshold slider changes (post-process only)
|
| 146 |
def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
|
| 147 |
return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
|
| 148 |
|
| 149 |
-
# Run initial inference and post-process when "Count" button is clicked
|
| 150 |
count_button.click(
|
| 151 |
initial_process,
|
| 152 |
-
[image_prompter, enable_mask, threshold],
|
| 153 |
-
[image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state]
|
| 154 |
)
|
| 155 |
|
| 156 |
-
# Adjust the output dynamically based on the threshold slider (no re-inference)
|
| 157 |
threshold.change(
|
| 158 |
update_threshold,
|
| 159 |
[threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
|
|
@@ -161,10 +143,9 @@ with iface:
|
|
| 161 |
)
|
| 162 |
|
| 163 |
enable_mask.change(
|
| 164 |
-
|
| 165 |
[threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
|
| 166 |
[image_output, count_output]
|
| 167 |
)
|
| 168 |
|
| 169 |
-
iface.launch(
|
| 170 |
-
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
from gradio_image_prompter import ImagePrompter
|
|
|
|
| 5 |
from models.counter_infer import build_model
|
| 6 |
from utils.arg_parser import get_argparser
|
| 7 |
from utils.data import resize_and_pad
|
| 8 |
import torchvision.ops as ops
|
| 9 |
from torchvision import transforms as T
|
| 10 |
+
from PIL import Image, ImageDraw
|
| 11 |
import numpy as np
|
| 12 |
|
| 13 |
+
# -----------------------
|
| 14 |
+
# LOAD MODEL
|
| 15 |
+
# -----------------------
|
| 16 |
def load_model():
|
| 17 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 18 |
+
|
| 19 |
+
args = get_argparser().parse_args([])
|
| 20 |
args.zero_shot = True
|
| 21 |
+
|
| 22 |
+
model = build_model(args).to(device)
|
| 23 |
+
|
| 24 |
+
# ⚠️ Make sure this file exists in your repo
|
| 25 |
+
checkpoint = torch.load("CNTQG_multitrain_ca44.pth", map_location=device)
|
| 26 |
+
|
| 27 |
+
model.load_state_dict(checkpoint["model"], strict=False)
|
| 28 |
model.eval()
|
| 29 |
+
|
| 30 |
return model, device
|
| 31 |
|
| 32 |
model, device = load_model()
|
| 33 |
|
| 34 |
+
# -----------------------
|
| 35 |
+
# PROCESS IMAGE
|
| 36 |
+
# -----------------------
|
| 37 |
def process_image_once(inputs, enable_mask):
|
| 38 |
+
model.return_masks = enable_mask # ✅ FIXED
|
| 39 |
|
| 40 |
image = inputs['image']
|
| 41 |
drawn_boxes = inputs['points']
|
| 42 |
+
|
| 43 |
image_tensor = torch.tensor(image).to(device)
|
| 44 |
image_tensor = image_tensor.permute(2, 0, 1).float() / 255.0
|
|
|
|
| 45 |
|
| 46 |
+
image_tensor = T.Normalize(
|
| 47 |
+
mean=[0.485, 0.456, 0.406],
|
| 48 |
+
std=[0.229, 0.224, 0.225]
|
| 49 |
+
)(image_tensor)
|
| 50 |
+
|
| 51 |
+
bboxes_tensor = torch.tensor(
|
| 52 |
+
[[box[0], box[1], box[3], box[4]] for box in drawn_boxes],
|
| 53 |
+
dtype=torch.float32
|
| 54 |
+
).to(device)
|
| 55 |
|
| 56 |
img, bboxes, scale = resize_and_pad(image_tensor, bboxes_tensor, size=1024.0)
|
| 57 |
+
|
| 58 |
img = img.unsqueeze(0).to(device)
|
| 59 |
bboxes = bboxes.unsqueeze(0).to(device)
|
| 60 |
|
|
|
|
| 63 |
|
| 64 |
return image, outputs, masks, img, scale, drawn_boxes
|
| 65 |
|
| 66 |
+
# -----------------------
|
| 67 |
+
# POST PROCESS
|
| 68 |
+
# -----------------------
|
| 69 |
def post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold):
|
| 70 |
idx = 0
|
| 71 |
+
threshold = 1 / threshold
|
| 72 |
+
|
| 73 |
+
scores = outputs[idx]['box_v']
|
| 74 |
+
boxes = outputs[idx]['pred_boxes']
|
| 75 |
|
| 76 |
+
keep_mask = scores > scores.max() / threshold
|
| 77 |
+
|
| 78 |
+
keep = ops.nms(
|
| 79 |
+
boxes[keep_mask],
|
| 80 |
+
scores[keep_mask],
|
| 81 |
+
0.5
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
pred_boxes = boxes[keep_mask][keep]
|
| 85 |
pred_boxes = torch.clamp(pred_boxes, 0, 1)
|
| 86 |
|
| 87 |
pred_boxes = (pred_boxes.cpu() / scale * img.shape[-1]).tolist()
|
| 88 |
|
| 89 |
+
image = Image.fromarray(image.astype(np.uint8))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
draw = ImageDraw.Draw(image)
|
| 92 |
+
|
| 93 |
for box in pred_boxes:
|
| 94 |
draw.rectangle([box[0], box[1], box[2], box[3]], outline="orange", width=5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
return image, len(pred_boxes)
|
| 97 |
|
| 98 |
+
# -----------------------
|
| 99 |
+
# GRADIO UI
|
| 100 |
+
# -----------------------
|
| 101 |
iface = gr.Blocks()
|
| 102 |
|
| 103 |
with iface:
|
|
|
|
| 104 |
image_input = gr.State()
|
| 105 |
outputs_state = gr.State()
|
| 106 |
masks_state = gr.State()
|
|
|
|
| 108 |
scale_state = gr.State()
|
| 109 |
drawn_boxes_state = gr.State()
|
| 110 |
|
|
|
|
| 111 |
with gr.Row():
|
| 112 |
image_prompter = ImagePrompter()
|
| 113 |
image_output = gr.Image(type="pil")
|
|
|
|
| 114 |
|
|
|
|
| 115 |
with gr.Row():
|
| 116 |
count_output = gr.Number(label="Total Count")
|
| 117 |
+
enable_mask = gr.Checkbox(label="Predict masks", value=True)
|
| 118 |
+
threshold = gr.Slider(0.05, 0.95, value=0.33, step=0.01)
|
|
|
|
| 119 |
|
|
|
|
| 120 |
count_button = gr.Button("Count")
|
| 121 |
|
|
|
|
| 122 |
def initial_process(inputs, enable_mask, threshold):
|
|
|
|
| 123 |
image, outputs, masks, img, scale, drawn_boxes = process_image_once(inputs, enable_mask)
|
| 124 |
|
|
|
|
| 125 |
return (
|
| 126 |
+
*post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold),
|
| 127 |
+
image, outputs, masks, img, scale, drawn_boxes
|
| 128 |
)
|
| 129 |
|
|
|
|
| 130 |
def update_threshold(threshold, image, outputs, masks, img, scale, drawn_boxes, enable_mask):
|
| 131 |
return post_process(image, outputs, masks, img, scale, drawn_boxes, enable_mask, threshold)
|
| 132 |
|
|
|
|
| 133 |
count_button.click(
|
| 134 |
initial_process,
|
| 135 |
+
[image_prompter, enable_mask, threshold],
|
| 136 |
+
[image_output, count_output, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state]
|
| 137 |
)
|
| 138 |
|
|
|
|
| 139 |
threshold.change(
|
| 140 |
update_threshold,
|
| 141 |
[threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
|
|
|
|
| 143 |
)
|
| 144 |
|
| 145 |
enable_mask.change(
|
| 146 |
+
update_threshold,
|
| 147 |
[threshold, image_input, outputs_state, masks_state, img_state, scale_state, drawn_boxes_state, enable_mask],
|
| 148 |
[image_output, count_output]
|
| 149 |
)
|
| 150 |
|
| 151 |
+
iface.launch()
|
|
|