Update app.py
app.py CHANGED
@@ -154,8 +154,6 @@ allowed_tags = list(tags.keys())
 for idx, tag in enumerate(allowed_tags):
     allowed_tags[idx] = tag.replace("_", " ")
 
-
-
 @spaces.GPU(duration=5)
 def run_classifier(image: Image.Image, threshold):
     img = image.convert('RGBA')
@@ -186,9 +184,6 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
 
     gradients = {}
     activations = {}
-    cam = None
-    target_tag_index = None
-
 
     def hook_forward(module, input, output):
         activations['value'] = output
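For context: `cam_inference` stores the features of `model.norm` with a forward hook and its incoming gradient with a full backward hook (registered in the next hunk). A minimal, self-contained sketch of that hook pattern, not taken from the commit, with `nn.LayerNorm` standing in for `model.norm`:

```python
# Sketch only: capture a module's output and its gradient via hooks,
# the same pattern cam_inference uses on model.norm.
import torch
import torch.nn as nn

activations, gradients = {}, {}

def hook_forward(module, inputs, output):
    activations['value'] = output           # saved on every forward pass

def hook_backward(module, grad_input, grad_output):
    gradients['value'] = grad_output[0]     # gradient w.r.t. the module's output

layer = nn.LayerNorm(8)                     # stand-in for model.norm
h_fwd = layer.register_forward_hook(hook_forward)
h_bwd = layer.register_full_backward_hook(hook_backward)

x = torch.randn(1, 4, 8, requires_grad=True)
layer(x).sum().backward()                   # populates both dicts

h_fwd.remove()
h_bwd.remove()
print(activations['value'].shape, gradients['value'].shape)
```

Removing the handles after the backward pass, as the commit does, keeps the hooks from firing on later inferences.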
@@ -200,29 +195,24 @@ def cam_inference(img, threshold, alpha, evt: gr.SelectData):
     handle_forward = model.norm.register_forward_hook(hook_forward)
     handle_backward = model.norm.register_full_backward_hook(hook_backward)
 
-    probits = model(tensor)[0]
+    probits = model(tensor)[0]
 
     model.zero_grad()
-
-    target_score.backward(retain_graph=True)
-
-    grads = gradients.get('value')
-    acts = activations.get('value')
+    probits[target_tag_index].backward(retain_graph=True)
 
-
-
-
-
-
-
-
-
-
+    with torch.no_grad():
+        patch_grads = gradients.get('value')
+        patch_acts = activations.get('value')
+
+        weights = torch.mean(patch_grads, dim=1).squeeze(0)
+
+        cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)
+        cam_1d = torch.relu(cam_1d)
+
+        cam = cam_1d.reshape(27, 27).detach().cpu().numpy()
 
     handle_forward.remove()
     handle_backward.remove()
-    gradients = {}
-    activations = {}
 
     return create_cam_visualization_pil(img, cam, alpha=alpha, vis_threshold=threshold), cam
 
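The added block is a Grad-CAM-style computation: the gradient of the selected tag's score with respect to the captured patch features is averaged over the patch dimension to give per-channel weights, a weighted sum over channels scores each patch, ReLU keeps only positive evidence, and the scores are reshaped to the 27×27 patch grid. A standalone sketch of that step on dummy tensors (the `(1, 729, 1024)` shapes and the function name are assumptions, not read from the model):

```python
# Sketch of the Grad-CAM step added above, run on dummy tensors.
import torch

def gradcam_from_hooks(patch_acts: torch.Tensor, patch_grads: torch.Tensor, grid: int = 27):
    # patch_acts, patch_grads: (1, num_patches, embed_dim), as captured by the hooks
    weights = torch.mean(patch_grads, dim=1).squeeze(0)               # (embed_dim,) channel weights
    cam_1d = torch.einsum('pe,e->p', patch_acts.squeeze(0), weights)  # (num_patches,) per-patch score
    cam_1d = torch.relu(cam_1d)                                       # keep positive evidence only
    return cam_1d.reshape(grid, grid).detach().cpu().numpy()          # (grid, grid) heat map

acts = torch.randn(1, 27 * 27, 1024)
grads = torch.randn(1, 27 * 27, 1024)
print(gradcam_from_hooks(acts, grads).shape)  # (27, 27)
```

The `reshape(27, 27)` in the commit implies 729 patch tokens coming out of `model.norm`; the grid size is taken from the commit, not derived here.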
@@ -245,26 +235,30 @@ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
     size = max(w, h)
 
     # Normalize CAM to [0, 1]
-
+    cam -= cam.min()
+    cam /= cam.max()
 
     # Create heatmap using matplotlib colormap
     colormap = cm.get_cmap('inferno')
-
-    cam_alpha = (cam_norm >= vis_threshold).astype(np.float32) * alpha # Alpha mask
+    cam_rgb = colormap(cam)[:, :, :3] # RGB
 
-
+    # Create alpha channel
+    cam_alpha = (cam >= vis_threshold).astype(np.float32) * alpha # Alpha mask
+    cam_rgba = np.dstack((cam_rgb, cam_alpha)) # Shape: (H, W, 4)
 
     # Resize CAM to match image
-
+    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
+    cam_pil = cam_pil.resize((216,216), resample=Image.Resampling.NEAREST)
 
-
+    # Model uses padded image as input, this matches attention map to input image aspect ratio
+    cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)
+    cam_pil = transforms.CenterCrop((h, w))(cam_pil)
 
     # Composite over original
     composite = Image.alpha_composite(image_pil, cam_image)
 
     return composite
 
-
 with gr.Blocks(css=".output-class { display: none; }") as demo:
     gr.Markdown("""
     ## Joint Tagger Project: JTP-PILOT² Demo **BETA**
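The rewritten `create_cam_visualization_pil` normalizes the CAM to [0, 1], maps it through the inferno colormap, thresholds it into an alpha mask, and resizes the 27×27 map back to the padded square before center-cropping to the source aspect ratio. A self-contained sketch of the same pipeline; the function name, the epsilon guard on the division, and the manual crop in place of `transforms.CenterCrop` are mine, and the commit's intermediate nearest-neighbor upsample to 216×216 is omitted:

```python
# Sketch: normalize -> colormap -> alpha mask -> resize -> center crop -> composite.
import numpy as np
from PIL import Image
from matplotlib import colormaps

def overlay_cam(image_pil: Image.Image, cam: np.ndarray, alpha=0.6, vis_threshold=0.2):
    image_pil = image_pil.convert("RGBA")
    w, h = image_pil.size
    size = max(w, h)                                  # model sees a square, padded input

    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)                    # normalize to [0, 1]

    cmap = colormaps["inferno"]                       # same map the commit gets via cm.get_cmap('inferno')
    cam_rgb = cmap(cam)[:, :, :3]                     # (H, W, 3) colors
    cam_alpha = (cam >= vis_threshold).astype(np.float32) * alpha
    cam_rgba = np.dstack((cam_rgb, cam_alpha))        # (H, W, 4)

    cam_pil = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
    cam_pil = cam_pil.resize((size, size), resample=Image.Resampling.BICUBIC)
    left, top = (size - w) // 2, (size - h) // 2
    cam_pil = cam_pil.crop((left, top, left + w, top + h))  # center crop back to (w, h)
    return Image.alpha_composite(image_pil, cam_pil)

demo_img = overlay_cam(Image.new("RGBA", (640, 480), "white"), np.random.rand(27, 27))
print(demo_img.size)  # (640, 480)
```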
@@ -280,10 +274,20 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
     sorted_tag_score_state = gr.State(value={}) # stash a copy of the input image
     cam_state = gr.State()
     with gr.Row():
-
+        custom_css = """
+        .inferno-slider input[type=range] {
+            background: linear-gradient(to right,
+                #000004, #1b0c41, #4a0c6b, #781c6d,
+                #a52c60, #cf4446, #ed6925, #fb9b06,
+                #f7d13d, #fcffa4
+            ) !important;
+            background-size: 100% 100% !important;
+        }
+        """
+        with gr.Column(css=custom_css):
             image_input = gr.Image(label="Source", sources=['upload'], type='pil', height=512, show_label=False)
             threshold_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.20, label="Tag Threshold")
-            cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.
+            cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40, label="CAM Threshold", elem_classes="inferno-slider")
             alpha_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.60, label="CAM Alpha")
         with gr.Column():
             tag_string = gr.Textbox(label="Tag String")
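The new CSS paints the CAM-threshold slider track with the inferno gradient so the control matches the heat map's color scale. A minimal sketch of wiring that class to a slider, assuming the CSS is passed at the `gr.Blocks` level (the usual place Gradio accepts custom CSS) rather than to `gr.Column` as the commit does:

```python
# Sketch only: inferno-gradient track for the CAM threshold slider.
# Assumption: css is supplied to gr.Blocks and targeted via elem_classes.
import gradio as gr

inferno_css = """
.inferno-slider input[type=range] {
    background: linear-gradient(to right,
        #000004, #1b0c41, #4a0c6b, #781c6d,
        #a52c60, #cf4446, #ed6925, #fb9b06,
        #f7d13d, #fcffa4
    ) !important;
    background-size: 100% 100% !important;
}
"""

with gr.Blocks(css=inferno_css) as demo:
    cam_slider = gr.Slider(minimum=0.00, maximum=1.00, step=0.01, value=0.40,
                           label="CAM Threshold", elem_classes="inferno-slider")

if __name__ == "__main__":
    demo.launch()
```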