import argparse
import base64
import os

os.system("pip uninstall -y hf-gradio")
os.system("pip uninstall -y spaces")
os.system("pip uninstall -y mcp")
os.system("pip install -r requirement.txt")

import spaces
from io import BytesIO

import gradio as gr
import torch
import torch.nn as nn
from diffusers import DDIMScheduler
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
from ReferenceNet import ReferenceNet
from inpainting_pipeline import StableDiffusionInpaintPipeline
from PIL import Image
from huggingface_hub import snapshot_download
import cv2
import numpy as np
import math


class LinearResampler(nn.Module):
    def __init__(self, input_dim=1024, output_dim=1024):
        super().__init__()
        self.projector = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.projector(x)
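

# Illustrative shape check (not called by the app): the projector maps per-token CLIP
# hidden states to the UNet cross-attention width. The 1280/257/768 sizes below are
# assumptions matching CLIP ViT-H penultimate hidden states and SD 1.5's
# cross_attention_dim, mirroring how init_proj() configures this module further down.
def _example_linear_resampler_shapes():
    proj = LinearResampler(input_dim=1280, output_dim=768)
    dummy_clip_tokens = torch.randn(1, 257, 1280)   # [batch, tokens, hidden]
    out = proj(dummy_clip_tokens)
    assert out.shape == (1, 257, 768)               # each token is projected independently
    return out.shape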


# ===================== Attention Capture =====================
class CaptureAttnProcessor(nn.Module):
    def __init__(self):
        super().__init__()
        self.captured_attn_map = None
        self.cnt = 0

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None):
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
        query = attn.to_q(hidden_states)
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        # Capture every third call; when classifier-free guidance doubles the batch,
        # keep only the second (conditional) half of the head-flattened batch.
        if self.cnt % 3 == 0:
            if attention_probs.shape[0] > 8:
                self.captured_attn_map = attention_probs[8:, :, :].detach()
            else:
                self.captured_attn_map = attention_probs.detach()
        self.cnt += 1
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states


def visualize_attention_map(hooks_dict, inpainting_mask=None, ref_image=None):
    """
    Visualize the attention map and return a PIL Image (without saving to disk).
    Uses cv2 COLORMAP_JET heatmap overlaid on the reference image.
    """
    valid_hooks = {k: v for k, v in hooks_dict.items() if v.captured_attn_map is not None}
    num_layers = len(valid_hooks)
    if num_layers == 0:
        print("No attention maps captured.")
        return None
    sorted_keys = sorted(valid_hooks.keys())
    layer_name = sorted_keys[0]
    proc = valid_hooks[layer_name]
    attention_map = proc.captured_attn_map
    attn_avg = attention_map.mean(dim=0).cpu().detach()
    split_idx = attn_avg.shape[0] // 2
    height = width = int(math.sqrt(split_idx))
    attn_cross = attn_avg[:split_idx, split_idx:]
    if inpainting_mask is not None:
        mask_resized = inpainting_mask.resize((width, height), resample=Image.NEAREST)
        mask_array = np.array(mask_resized).astype(np.float32)
        if mask_array.max() > 1.0:
            mask_array = mask_array / 255.0
        mask_binary = mask_array > 0.5
        mask_flat = mask_binary.flatten()
        num_mask_pixels = int(mask_flat.sum())
        if num_mask_pixels > 0:
            attn_mask_region = attn_cross[mask_flat, :]
            map_data = attn_mask_region.mean(dim=0).reshape(height, width).numpy()
            print(f"Using {num_mask_pixels}/{len(mask_flat)} query positions from mask region")
        else:
            print("Warning: mask region is empty after binarization, falling back to full attention map")
            map_data = attn_cross.mean(dim=0).reshape(height, width).numpy()
    else:
        map_data = attn_cross.mean(dim=0).reshape(height, width).numpy()
    # Normalize to 0–255
    map_data = (map_data - map_data.min()) / (map_data.max() - map_data.min() + 1e-8)
    heatmap = (map_data * 255).astype(np.uint8)
    # JET pseudo-color mapping
    heatmap_img = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap_img = cv2.cvtColor(heatmap_img, cv2.COLOR_BGR2RGB)
    target_size = (512, 512)
    heatmap_img = cv2.resize(heatmap_img, target_size)
    # Overlay onto the reference image
    if ref_image is not None:
        org_img = np.array(ref_image.resize(target_size, resample=Image.BICUBIC))
        org_img = cv2.resize(org_img, target_size)
        attn_vis = cv2.addWeighted(org_img, 0.3, heatmap_img, 0.7, 0)
    else:
        attn_vis = heatmap_img
    return Image.fromarray(attn_vis)
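

# Synthetic sanity check (illustrative only, no model required): visualize_attention_map()
# above splits the captured map in half along both axes, which presumes the hooked
# self-attention runs over background and reference tokens concatenated along the
# sequence axis; the top-right (2N, 2N) quadrant is then the background -> reference
# attention that gets rendered as a heatmap.
def _example_cross_block_split(grid=16, heads=8):
    n = grid * grid
    fake_map = torch.rand(heads, 2 * n, 2 * n)       # [heads, queries, keys]
    attn_avg = fake_map.mean(dim=0)
    split_idx = attn_avg.shape[0] // 2
    attn_cross = attn_avg[:split_idx, split_idx:]    # background queries -> reference keys
    heat = attn_cross.mean(dim=0).reshape(grid, grid)
    assert heat.shape == (grid, grid)                # one value per reference latent position
    return heat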


class AttentionVisualizer:
    def __init__(self, pipe):
        self.pipe = pipe
        self.hooks = {}

    def register_specific_layer(self):
        unet = self.pipe.unet
        target_block_idx = 3
        target_attn_idx = 2
        try:
            block = unet.up_blocks[target_block_idx]
            if hasattr(block, "attentions") and len(block.attentions) > target_attn_idx:
                attn_module = block.attentions[target_attn_idx]
                for k, transformer in enumerate(attn_module.transformer_blocks):
                    target_attn = transformer.attn1
                    layer_name = f"up_blocks.{target_block_idx}.attentions.{target_attn_idx}"
                    hook_proc = CaptureAttnProcessor()
                    target_attn.set_processor(hook_proc)
                    self.hooks[layer_name] = hook_proc
                    print(f"Successfully registered hook: {layer_name}")
            else:
                print(f"Error: Layer up_blocks.{target_block_idx}.attentions.{target_attn_idx} does not exist.")
        except IndexError:
            print(f"Error: up_blocks index {target_block_idx} out of range.")

    def reset(self):
        """Reset the capture state of all hooks. Should be called before each generation."""
        for proc in self.hooks.values():
            proc.captured_attn_map = None
            proc.cnt = 0

    def visualize(self, inpainting_mask=None, ref_image=None):
        """Return a PIL Image visualization of the attention map."""
        return visualize_attention_map(
            self.hooks, inpainting_mask=inpainting_mask, ref_image=ref_image
        )


# ===================== Model Wrapper =====================
class ReferencenetInpainting:
    def __init__(self, sd_pipe, referencenet, image_encoder_path, checkpoint_path, device):
        self.device = device
        self.image_encoder_path = image_encoder_path
        self.checkpoint_path = checkpoint_path
        self.referencenet = referencenet.to(self.device)
        self.pipe = sd_pipe.to(self.device)
        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
            self.device, dtype=torch.float16
        )
        self.clip_image_processor = CLIPImageProcessor()
        self.image_proj_model = self.init_proj()
        self.load_unet_and_image_proj_and_referencenet()

    def init_proj(self):
        image_proj_model = LinearResampler(
            input_dim=1280,
            output_dim=self.pipe.unet.config.cross_attention_dim,
        ).to(self.device, dtype=torch.float16)
        return image_proj_model

    def load_unet_and_image_proj_and_referencenet(self):
        state_dict = torch.load(self.checkpoint_path, map_location="cpu")
        self.pipe.unet.load_state_dict(state_dict["unet"], strict=False)
        self.referencenet.load_state_dict(state_dict["referencenet"], strict=False)
        self.image_proj_model.load_state_dict(state_dict["image_proj"])

    def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self.device, dtype=torch.float16)
        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self.image_proj_model(clip_image_embeds).to(dtype=torch.float16)
        uncond_clip_image_embeds = self.image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds

    def generate(
        self,
        pil_ref_image=None,
        pil_background_image=None,
        pil_mask_image=None,
        num_samples=1,
        seed=None,
        guidance_scale=7.5,
        num_inference_steps=30,
        **kwargs,
    ):
        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image=pil_ref_image)
        bs_embed, seq_len, _ = image_prompt_embeds.shape
        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
        images = self.pipe(
            image=pil_background_image,
            mask_image=pil_mask_image,
            prompt_embeds=image_prompt_embeds,
            negative_prompt_embeds=uncond_image_prompt_embeds,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
            referencenet=self.referencenet,
            ref_image=pil_ref_image,
            clip_image_embed=torch.cat([uncond_image_prompt_embeds, image_prompt_embeds], dim=0),
            **kwargs,
        ).images
        return images
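

# Illustrative usage sketch (not executed by the app; the live wiring is in run_local()
# further below). `model` and `visualizer` stand for the ReferencenetInpainting and
# AttentionVisualizer instances constructed in the setup section that follows.
def _example_generate_with_attention(model, visualizer, ref_pil, bg_pil, mask_pil):
    visualizer.reset()                               # drop maps captured by a previous run
    images = model.generate(
        pil_ref_image=ref_pil,
        pil_background_image=bg_pil,
        pil_mask_image=mask_pil,
        num_samples=1,
        guidance_scale=7.5,
        num_inference_steps=25,
        seed=42,
    )
    # The hooked processor has now stored an attention map; overlay it on the reference.
    attn_overlay = visualizer.visualize(inpainting_mask=mask_pil, ref_image=ref_pil)
    return images[0], attn_overlay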


# ===================== Model Setup =====================
parser = argparse.ArgumentParser(description="Gradio Demo")

allow_sd_text_encoder_patterns = ["text_encoder/config.json", "text_encoder/pytorch_model.bin"]
allow_tokenizer_patterns = ["tokenizer/*"]
allow_scheduler_patterns = ["scheduler/*"]
allow_vae_patterns = ["vae/config.json", "vae/diffusion_pytorch_model.bin"]
allow_unet_patterns = ["unet/config.json", "unet/diffusion_pytorch_model.bin"]
allow_sd_patterns = (
    allow_sd_text_encoder_patterns
    + allow_tokenizer_patterns
    + allow_scheduler_patterns
    + allow_vae_patterns
    + allow_unet_patterns
    + ["model_index.json"]
)

sd_model_path = snapshot_download("stable-diffusion-v1-5/stable-diffusion-inpainting", allow_patterns=allow_sd_patterns)
ref_model_path = snapshot_download("stable-diffusion-v1-5/stable-diffusion-v1-5", allow_patterns=allow_sd_patterns)
image_encoder_path = snapshot_download("laion/CLIP-ViT-H-14-laion2B-s32B-b79K", allow_patterns=["config.json", "pytorch_model.bin"])
checkpoint_path = os.path.join(snapshot_download("LiXiY/ReferenceAnomaly"), "reference_anomaly_checkponint.bin")

device = "cuda" if torch.cuda.is_available() else "cpu"
args = parser.parse_args()

noise_scheduler = DDIMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    sd_model_path,
    torch_dtype=torch.float16,
    scheduler=noise_scheduler,
    feature_extractor=None,
    safety_checker=None,
)
referencenet = ReferenceNet.from_pretrained(
    ref_model_path, subfolder="unet", feature_extractor=None, safety_checker=None
).to(dtype=torch.float16)
reference_anomaly_model = ReferencenetInpainting(pipe, referencenet, image_encoder_path, checkpoint_path, device)

# ===================== Register Attention Hook =====================
attention_visualizer = AttentionVisualizer(reference_anomaly_model.pipe)
attention_visualizer.register_specific_layer()

# ===================== Example Data =====================
CANVAS_W, CANVAS_H = 512, 512
EXAMPLES = [
    ("validation_images/background_image_1.png", "validation_images/ref_image_1.png", "validation_images/inpainting_mask_1.png"),
    ("validation_images/background_image_2.png", "validation_images/ref_image_2.png", "validation_images/inpainting_mask_2.png"),
    ("validation_images/background_image_3.png", "validation_images/ref_image_3.png", "validation_images/inpainting_mask_3.png"),
    ("validation_images/background_image_4.png", "validation_images/ref_image_4.png", "validation_images/inpainting_mask_4.png"),
]


# ===================== Thumbnail HTML Generation =====================
def img_to_b64(path, size=(120, 120)):
    img = Image.open(path).convert("RGB").resize(size, Image.LANCZOS)
    buf = BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()


def build_examples_html():
    row_pairs = [EXAMPLES[0:2], EXAMPLES[2:4]]
    html = '<div class="ex-grid">'
    for row_idx, row_examples in enumerate(row_pairs):
        html += '<div class="ex-grid-row">'
        for col_idx, (bg, ref, mask) in enumerate(row_examples):
            i = row_idx * 2 + col_idx
            bg_b64 = img_to_b64(bg)
            ref_b64 = img_to_b64(ref)
            mask_b64 = img_to_b64(mask)
            html += f'''
            <div class="ex-row" onclick="(function(){{var el=document.getElementById('ex_btn_{i}');if(!el)return;var btn=el.querySelector('button')||el;btn.dispatchEvent(new MouseEvent('click',{{bubbles:true,cancelable:true}}));}})()">
                <div class="ex-label">Example {i + 1}</div>
                <div class="ex-thumbs">
                    <div class="ex-thumb-wrap">
                        <img src="data:image/png;base64,{bg_b64}" class="ex-thumb" draggable="false"/>
                        <span class="ex-thumb-sublabel">Background</span>
                    </div>
                    <div class="ex-thumb-wrap">
                        <img src="data:image/png;base64,{mask_b64}" class="ex-thumb" draggable="false"/>
                        <span class="ex-thumb-sublabel">Mask</span>
                    </div>
                    <div class="ex-thumb-wrap">
                        <img src="data:image/png;base64,{ref_b64}" class="ex-thumb" draggable="false"/>
                        <span class="ex-thumb-sublabel">Reference</span>
                    </div>
                </div>
            </div>
            '''
        html += '</div>'
    html += '</div>'
    return html


# ===================== Utility Functions =====================
def fit_image_to_canvas(img, canvas_w=CANVAS_W, canvas_h=CANVAS_H):
    img_rgba = img.convert("RGBA")
    img_rgba.thumbnail((canvas_w, canvas_h), Image.LANCZOS)
    canvas = Image.new("RGBA", (canvas_w, canvas_h), (0, 0, 0, 0))
    offset_x = (canvas_w - img_rgba.width) // 2
    offset_y = (canvas_h - img_rgba.height) // 2
    canvas.paste(img_rgba, (offset_x, offset_y))
    return canvas
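

# Quick self-contained check (illustrative only, not called by the app): a non-square
# image is centered on the fixed canvas with its aspect ratio preserved and transparent
# padding around it.
def _example_fit_to_canvas():
    tall = Image.new("RGB", (200, 400), (255, 0, 0))
    fitted = fit_image_to_canvas(tall)
    assert fitted.size == (CANVAS_W, CANVAS_H)       # always the full canvas size
    assert fitted.mode == "RGBA"
    assert fitted.getpixel((0, 0))[3] == 0           # corners stay transparent padding
    return fitted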


def extract_mask_from_layers(layers, target_size):
    mask = Image.new("L", target_size, 0)
    for layer in layers:
        if layer is not None:
            layer_rgba = layer.convert("RGBA").resize(target_size)
            alpha = layer_rgba.split()[3]
            alpha_binary = alpha.point(lambda x: 255 if x > 0 else 0)
            mask = Image.composite(Image.new("L", target_size, 255), mask, alpha_binary)
    return mask
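

# Self-contained sketch of the mask extraction (illustrative only): any pixel painted on
# any editor layer (non-zero alpha) becomes white in the binary mask handed to the
# inpainting pipeline; untouched pixels stay black.
def _example_extract_mask():
    from PIL import ImageDraw
    layer = Image.new("RGBA", (64, 64), (0, 0, 0, 0))                     # empty brush layer
    ImageDraw.Draw(layer).rectangle([16, 16, 47, 47], fill=(255, 255, 255, 255))  # "brush stroke"
    mask = extract_mask_from_layers([layer], (64, 64))
    assert mask.getpixel((32, 32)) == 255 and mask.getpixel((0, 0)) == 0
    return mask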


def load_example(idx):
    bg_path, ref_path, mask_path = EXAMPLES[idx]
    bg = Image.open(bg_path).convert("RGB").resize((CANVAS_W, CANVAS_H))
    ref_img = Image.open(ref_path).convert("RGB").resize((CANVAS_W, CANVAS_H))
    mask = Image.open(mask_path).convert("L").resize((CANVAS_W, CANVAS_H))
    transparent = Image.new("RGBA", (CANVAS_W, CANVAS_H), (0, 0, 0, 0))
    white_solid = Image.new("RGBA", (CANVAS_W, CANVAS_H), (255, 255, 255, 255))
    mask_layer = Image.composite(white_solid, transparent, mask)
    composite = Image.alpha_composite(bg.convert("RGBA"), mask_layer)
    editor_val = {
        "background": bg,
        "layers": [mask_layer],
        "composite": composite,
    }
    return editor_val, ref_img


def load_ex1():
    return load_example(0)


def load_ex2():
    return load_example(1)


def load_ex3():
    return load_example(2)


def load_ex4():
    return load_example(3)


# ===================== Generation Function (also returns attention map) =====================
def run_local(base, ref):
    if base is None or ref is None:
        return None, None, gr.update(visible=False)
    target_size = (CANVAS_W, CANVAS_H)
    pil_ref = ref.convert("RGB").resize(target_size)
    if not isinstance(base, dict):
        return None, None, gr.update(visible=False)
    bg_pil = base.get("background")
    layers = base.get("layers", [])
    if bg_pil is None:
        return None, None, gr.update(visible=False)
    pil_bg = bg_pil.convert("RGB").resize(target_size)
    pil_mask = extract_mask_from_layers(layers, target_size)
    if pil_mask.getextrema() == (0, 0):
        error_html = """
        <div class="error-overlay" style="
            position: fixed; top: 0; left: 0; width: 100%; height: 100%;
            background: rgba(0,0,0,0.5); display: flex; justify-content: center;
            align-items: center; z-index: 9999;
        ">
            <div style="
                background: white; padding: 30px; border-radius: 10px;
                text-align: center; font-size: 18px; box-shadow: 0 0 15px rgba(0,0,0,0.3);
            ">
                <p style="color: red; margin-bottom: 20px;">
                    ⚠️ Please draw the anomaly region (mask) on the background image first, or click an example!
                </p>
                <button onclick="this.closest('.error-overlay').remove()"
                        style="padding: 8px 20px; cursor: pointer; border: none;
                               background: #eee; border-radius: 5px;">
                    OK
                </button>
            </div>
        </div>
        """
        return None, None, gr.update(value=error_html, visible=True)
    # Reset attention capture before generation
    attention_visualizer.reset()
    generated_images = reference_anomaly_model.generate(
        pil_ref_image=pil_ref,
        pil_background_image=pil_bg,
        pil_mask_image=pil_mask,
        num_samples=1,
        guidance_scale=7.5,
        num_inference_steps=25,
        seed=42,
    )
    result_img = generated_images[0].resize(target_size)
    # Generate attention map visualization
    attn_img = attention_visualizer.visualize(
        inpainting_mask=pil_mask,
        ref_image=pil_ref,
    )
    return result_img, attn_img, gr.update(visible=False)


# ===================== Combined Client JS (resize + force English) =====================
# KEY FIX: merge both JS functions into a SINGLE function body instead of
# concatenating two separate function expressions.
COMBINED_JS = """
function() {
    /* ===== Client-Side Instant Resize ===== */
    var MAX_W = """ + str(CANVAS_W) + """;
    var MAX_H = """ + str(CANVAS_H) + """;

    function resizeInBrowser(file) {
        return new Promise(function(resolve) {
            var reader = new FileReader();
            reader.onload = function(e) {
                var img = new Image();
                img.onload = function() {
                    if (img.width <= MAX_W && img.height <= MAX_H) {
                        resolve(null);
                        return;
                    }
                    var ratio = Math.min(MAX_W / img.width, MAX_H / img.height);
                    var c = document.createElement('canvas');
                    c.width = Math.round(img.width * ratio);
                    c.height = Math.round(img.height * ratio);
                    c.getContext('2d').drawImage(img, 0, 0, c.width, c.height);
                    c.toBlob(function(blob) {
                        resolve(blob ? new File([blob], file.name, {type: 'image/png'}) : null);
                    }, 'image/png');
                };
                img.onerror = function() { resolve(null); };
                img.src = e.target.result;
            };
            reader.onerror = function() { resolve(null); };
            reader.readAsDataURL(file);
        });
    }

    function hookInput(inp) {
        if (inp._resizeHooked) return;
        inp._resizeHooked = true;
        inp._skipResize = false;
        inp.addEventListener('change', function(e) {
            if (inp._skipResize) { inp._skipResize = false; return; }
            var file = inp.files && inp.files[0];
            if (!file || !file.type || file.type.indexOf('image/') !== 0) return;
            e.stopImmediatePropagation();
            e.stopPropagation();
            resizeInBrowser(file).then(function(resized) {
                if (resized) {
                    var dt = new DataTransfer();
                    dt.items.add(resized);
                    inp.files = dt.files;
                }
                inp._skipResize = true;
                inp.dispatchEvent(new Event('change', {bubbles: true}));
            });
        }, true);
    }

    function scan() {
        var inputs = document.querySelectorAll('.input-row input[type="file"]');
        for (var i = 0; i < inputs.length; i++) hookInput(inputs[i]);
    }
    new MutationObserver(scan).observe(document.body, {childList: true, subtree: true});
    scan();

    /* ===== Force English UI Labels ===== */
    var zh2en = {
        '将图像拖放到此处或点击上传': 'Drag image here or click to upload',
        '拖放文件到这里': 'Drag file here',
        '点击上传': 'Click to upload',
        '或点击上传': 'or click to upload',
        '上传图片': 'Upload image',
        '粘贴图片或URL': 'Paste image or URL',
        '清空': 'Clear',
        '编辑': 'Edit',
        '撤销': 'Undo',
        '重做': 'Redo',
        '缩放': 'Zoom',
        '画笔': 'Brush',
        '橡皮擦': 'Eraser',
        '清除图层': 'Clear layers',
        '图像编辑器': 'Image Editor',
        '生成': 'Generate',
        '正在运行...': 'Running...',
        '提交': 'Submit',
    };

    function translateNode(node) {
        if (node.nodeType === Node.TEXT_NODE) {
            var text = node.textContent;
            for (var zh in zh2en) {
                if (text.indexOf(zh) !== -1) {
                    text = text.split(zh).join(zh2en[zh]);
                }
            }
            if (text !== node.textContent) node.textContent = text;
        }
    }

    function walkAndTranslate(root) {
        var walker = document.createTreeWalker(root, NodeFilter.SHOW_TEXT, null, false);
        var node;
        while (node = walker.nextNode()) translateNode(node);
    }

    function translateAttributes(root) {
        root.querySelectorAll('[placeholder]').forEach(function(el) {
            var ph = el.getAttribute('placeholder');
            for (var zh in zh2en) {
                if (ph.indexOf(zh) !== -1) ph = ph.split(zh).join(zh2en[zh]);
            }
            el.setAttribute('placeholder', ph);
        });
        root.querySelectorAll('[title]').forEach(function(el) {
            var t = el.getAttribute('title');
            for (var zh in zh2en) {
                if (t.indexOf(zh) !== -1) t = t.split(zh).join(zh2en[zh]);
            }
            el.setAttribute('title', t);
        });
    }

    function runTranslate() {
        walkAndTranslate(document.body);
        translateAttributes(document.body);
    }

    var translateObserver = new MutationObserver(function(mutations) {
        for (var m = 0; m < mutations.length; m++) {
            var added = mutations[m].addedNodes;
            for (var n = 0; n < added.length; n++) {
                if (added[n].nodeType === Node.ELEMENT_NODE) {
                    walkAndTranslate(added[n]);
                    translateAttributes(added[n]);
                }
            }
        }
    });
    translateObserver.observe(document.body, { childList: true, subtree: true });
    runTranslate();
    setInterval(runTranslate, 2000);
}
"""

# ===================== Gradio UI =====================
with gr.Blocks(css="""
.input-row {
    overflow: visible !important;
}
.input-row .gr-image-editor {
    overflow: hidden !important;
}
.input-row .gr-image-editor .image-container,
.input-row .gr-image-editor .canvas-container,
.input-row .gr-image-editor canvas {
    max-width: 100% !important;
    max-height: 100% !important;
    object-fit: contain !important;
}
.ex-section-header {
    display: flex;
    align-items: center;
    gap: 10px;
    margin: 28px 0 14px 0;
    justify-content: center;
}
.ex-section-header::before {
    content: '';
    flex: 1;
    height: 1px;
    max-width: 180px;
    background: #e5e7eb;
}
.ex-section-header::after {
    content: '';
    flex: 1;
    height: 1px;
    max-width: 180px;
    background: #e5e7eb;
}
.ex-container {
    display: flex;
    flex-direction: column;
    align-items: center;
    padding-bottom: 20px;
}
.ex-grid {
    display: flex;
    flex-direction: column;
    align-items: center;
    gap: 12px;
    padding-bottom: 20px;
}
.ex-grid-row {
    display: flex;
    gap: 20px;
    justify-content: center;
    flex-wrap: wrap;
}
.ex-row {
    display: flex;
    align-items: center;
    gap: 20px;
    padding: 14px 28px;
    border: 2px solid #e5e7eb;
    border-radius: 12px;
    cursor: pointer;
    transition: all 0.25s ease;
    background: #ffffff;
    user-select: none;
    width: fit-content;
}
.ex-row:hover {
    border-color: #3b82f6;
    background: #f0f7ff;
    box-shadow: 0 4px 18px rgba(59, 130, 246, 0.15);
    transform: translateY(-2px);
}
.ex-row:active {
    transform: translateY(0);
    box-shadow: 0 2px 8px rgba(59, 130, 246, 0.2);
}
.ex-label {
    font-weight: 700;
    font-size: 15px;
    min-width: 62px;
    color: #1e40af;
    letter-spacing: 0.02em;
}
.ex-thumbs {
    display: flex;
    gap: 14px;
}
.ex-thumb-wrap {
    display: flex;
    flex-direction: column;
    align-items: center;
    gap: 6px;
}
.ex-thumb {
    width: 110px;
    height: 110px;
    object-fit: cover;
    border-radius: 8px;
    border: 2px solid #e5e7eb;
    transition: all 0.25s ease;
    pointer-events: none;
}
.ex-row:hover .ex-thumb {
    border-color: #93c5fd;
}
.ex-thumb-sublabel {
    font-size: 12px;
    color: #6b7280;
    font-weight: 500;
}
""", js=COMBINED_JS) as demo:
    gr.Markdown(
        "<h1 style='text-align: center;'>Reference-Based Anomaly Image Generation</h1>"
        "<h3 style='text-align: center;'>Generate anomaly images similar to the reference anomaly on normal images</h3>"
        "<h3 style='text-align: center;'>GitHub: https://github.com/huan-yin/reference_anomaly_generation</h3>"
    )
    gr.Markdown(
        """
**Instructions:**
1. Upload a background image (a normal object), use the brush tool below the image to mark the region where the anomaly should be generated (the mask), and upload a reference image (the reference anomaly).
2. Alternatively, click any row of thumbnails in the "Examples" section below to automatically load a background image, mask, and reference image.
3. Click the "Generate" button; the result will be displayed below.
"""
    )

    with gr.Row(elem_classes="input-row"):
        base = gr.ImageEditor(
            label="Background Image (Normal Object)",
            type="pil",
            width=420,
            height=450,
            canvas_size=(CANVAS_W, CANVAS_H),
            sources=["upload"],
            brush=gr.Brush(
                default_size=15,
                default_color="#FFFFFF",
                color_mode="fixed",
                colors=["#FFFFFF"],
            ),
        )
        ref = gr.Image(
            label="Reference Image (Reference Anomaly)",
            sources=["upload"],
            type="pil",
            width=420,
            height=380,
        )

    with gr.Row():
        gen_btn = gr.Button("Generate", variant="primary")

    # ==================== Generation Result + Attention Map Side-by-Side ====================
    with gr.Row():
        with gr.Column(scale=1):
            output_image = gr.Image(
                label="Generated Result",
                interactive=False,
            )
        with gr.Column(scale=1):
            attention_map_output = gr.Image(
                label="Attention Map (Attention Visualization for Reference Anomaly)",
                interactive=False,
            )

    with gr.Row():
        error_dialog = gr.HTML(visible=False)

    gr.HTML('<div class="ex-section-header"><span style="font-weight:700;font-size:16px;color:#374151;">Examples (click to load background + mask + reference image)</span></div>')
    ex_btn0 = gr.Button("Example 1", visible=False, elem_id="ex_btn_0")
    ex_btn1 = gr.Button("Example 2", visible=False, elem_id="ex_btn_1")
    ex_btn2 = gr.Button("Example 3", visible=False, elem_id="ex_btn_2")
    ex_btn3 = gr.Button("Example 4", visible=False, elem_id="ex_btn_3")
    gr.HTML('<div class="ex-container">' + build_examples_html() + '</div>')

    # ==================== Event Bindings ====================
    ex_btn0.click(fn=load_ex1, outputs=[base, ref])
    ex_btn1.click(fn=load_ex2, outputs=[base, ref])
    ex_btn2.click(fn=load_ex3, outputs=[base, ref])
    ex_btn3.click(fn=load_ex4, outputs=[base, ref])
    gen_btn.click(
        fn=run_local,
        inputs=[base, ref],
        outputs=[output_image, attention_map_output, error_dialog],
    )

demo.launch(server_name="0.0.0.0", server_port=7860)