"""
Gradio demo for the FLUX-VAE image restoration model.

Supports a custom resolution, several samplers (Euler ODE / SDE
Euler-Maruyama and others) and a configurable number of inference steps.
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
from PIL import Image
import timm
from diffusers import AutoencoderKL
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_PATH = "model.safetensors"  # local checkpoint path
MODEL_REPO_ID = "telecomadm1145/img_restore"  # HF repo ID to download from (replace with yours)
MODEL_FILENAME = "model.safetensors"  # checkpoint filename inside the HF repo
VAE_ID = "advokat/AnimePro-FLUX"  # alternative: "black-forest-labs/FLUX.1-schnell"
VAE_SUBFOLDER = "vae"
HF_TOKEN = os.environ.get("HF_TOKEN") or None  # an empty env var is treated as "no token"

# Model hyper-parameters (must match the training configuration).
LATENT_CHANNELS = 16
VAE_SCALE_FACTOR = 8
DIT_HIDDEN_SIZE = 1024
DIT_DEPTH = 16
DIT_NUM_HEADS = 4
PATCH_SIZE = 2
DINO_MODEL_NAME = 'vit_base_patch16_dinov3.lvd1689m'
IMG_SIZE = 384  # default training resolution

# Fallback VAE latent statistics (identity normalization), used when no cached
# statistics can be loaded.
DEFAULT_VAE_MEAN = torch.zeros(LATENT_CHANNELS)
DEFAULT_VAE_STD = torch.ones(LATENT_CHANNELS)


# -----------------------------------------------------------------------------
# Model definition. Module and parameter names must match the training code so
# the checkpoint's state_dict keys line up.
# -----------------------------------------------------------------------------
def modulate(x, shift, scale):
    """Apply adaLN modulation ``x * (1 + scale) + shift``, broadcasting over the token axis."""
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class TimestepEmbedder(nn.Module):
    """Embed per-sample timesteps with sinusoidal features followed by a 2-layer MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return a (B, dim) sinusoidal embedding of the (B,) timesteps ``t``."""
        # The samplers in this file pass t in [0, 1); scale it up before taking sinusoids.
        t = t * 1000.0
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad odd dims with a zero column so the output width is exactly `dim`.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)


class RotaryEmbedding2D(nn.Module):
    """2D rotary position table: the first half of the head dim encodes rows, the rest columns.

    cos/sin tables are cached for a (cache_h, cache_w) grid, flattened row-major,
    and grown on demand when a larger grid is requested.

    NOTE(review): ``apply_rotary_pos_emb`` rotates channel j together with channel
    j + dim/2, which pairs a row frequency with a column frequency rather than
    forming a pure per-axis rotation. It is left as-is because the checkpoint was
    trained with this layout.
    """

    def __init__(self, dim, max_h=64, max_w=64):
        super().__init__()
        self.dim = dim
        dim_h = dim // 2
        dim_w = dim - dim_h
        inv_freq_h = 1.0 / (10000 ** (torch.arange(0, dim_h, 2).float() / dim_h))
        inv_freq_w = 1.0 / (10000 ** (torch.arange(0, dim_w, 2).float() / dim_w))
        self.register_buffer("inv_freq_h", inv_freq_h)
        self.register_buffer("inv_freq_w", inv_freq_w)
        self._set_cos_sin_cache(max_h, max_w)

    def _set_cos_sin_cache(self, h, w):
        """(Re)build the cos/sin tables for an ``h x w`` grid, flattened row-major."""
        t_h = torch.arange(h).type_as(self.inv_freq_h)
        freqs_h = torch.outer(t_h, self.inv_freq_h)
        emb_h = torch.cat((freqs_h, freqs_h), dim=-1)
        t_w = torch.arange(w).type_as(self.inv_freq_w)
        freqs_w = torch.outer(t_w, self.inv_freq_w)
        emb_w = torch.cat((freqs_w, freqs_w), dim=-1)
        emb_h_broad = emb_h.unsqueeze(1).repeat(1, w, 1)
        emb_w_broad = emb_w.unsqueeze(0).repeat(h, 1, 1)
        emb = torch.cat((emb_h_broad, emb_w_broad), dim=-1).flatten(0, 1)
        self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0), persistent=False)
        self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0), persistent=False)
        self.cache_h, self.cache_w = h, w

    def forward(self, x, h, w):
        """Return ``(cos, sin)``, each (1, 1, h*w, dim) in ``x.dtype``, for an ``h x w`` grid."""
        if h > self.cache_h or w > self.cache_w:
            self._set_cos_sin_cache(max(h, self.cache_h), max(w, self.cache_w))
        cos, sin = self.cos_cached, self.sin_cached
        if (h, w) != (self.cache_h, self.cache_w):
            # The cache is row-major over (cache_h, cache_w): a flat [:h*w] slice would pick
            # the wrong positions for a smaller grid, so crop in 2D instead.
            cos = cos.view(self.cache_h, self.cache_w, -1)[:h, :w].reshape(1, 1, h * w, -1)
            sin = sin.view(self.cache_h, self.cache_w, -1)[:h, :w].reshape(1, 1, h * w, -1)
        return cos.to(x.dtype), sin.to(x.dtype)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary embedding to ``q`` and ``k`` using the rotate-half convention."""
    def rotate_half(x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)


class SwiGLU(nn.Module):
    """SwiGLU feed-forward; hidden width is 2/3 of ``hidden_size * mlp_ratio``, rounded up to 64."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        mlp_hidden = int(hidden_size * mlp_ratio * 2 / 3)
        mlp_hidden = ((mlp_hidden + 63) // 64) * 64
        self.w1 = nn.Linear(hidden_size, mlp_hidden, bias=False)
        self.w2 = nn.Linear(hidden_size, mlp_hidden, bias=False)
        self.w3 = nn.Linear(mlp_hidden, hidden_size, bias=False)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))


class DiTBlock(nn.Module):
    """Joint-attention DiT block over [latent | condition] tokens.

    Attention and MLP weights are shared by both streams, while each stream has
    its own adaLN modulation (shift/scale/gate) derived from the timestep embedding.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.norm1_latent = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm1_cond = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.norm2_latent = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm2_cond = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.q_norm = nn.LayerNorm(self.head_dim, eps=1e-6)
        self.k_norm = nn.LayerNorm(self.head_dim, eps=1e-6)
        self.mlp = SwiGLU(hidden_size, mlp_ratio)
        self.adaLN_latent = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=False)
        )
        self.adaLN_cond = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=False)
        )

    def forward(self, x, t_emb, rope_cos, rope_sin, num_latents):
        """Run one block; the first ``num_latents`` tokens of ``x`` are the latent stream."""
        B, L, D = x.shape
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_latent(t_emb).chunk(6, dim=-1)
        (shift_msa_c, scale_msa_c, gate_msa_c,
         shift_mlp_c, scale_mlp_c, gate_mlp_c) = self.adaLN_cond(t_emb).chunk(6, dim=-1)

        x_lat, x_cond = x[:, :num_latents], x[:, num_latents:]

        # Joint self-attention with per-stream modulation.
        x_norm = torch.cat([
            modulate(self.norm1_latent(x_lat), shift_msa, scale_msa),
            modulate(self.norm1_cond(x_cond), shift_msa_c, scale_msa_c),
        ], dim=1)
        qkv = self.qkv(x_norm)
        q, k, v = qkv.reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_pos_emb(q, k, rope_cos, rope_sin)
        # Under autocast the norm/RoPE outputs may be a different dtype than v.
        q, k = q.to(v.dtype), k.to(v.dtype)
        x_attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        x_attn = x_attn.transpose(1, 2).reshape(B, L, D)
        x_attn = self.proj(x_attn)
        x_lat = x_lat + gate_msa.unsqueeze(1) * x_attn[:, :num_latents]
        x_cond = x_cond + gate_msa_c.unsqueeze(1) * x_attn[:, num_latents:]

        # Shared SwiGLU MLP, again with per-stream modulation and gating.
        x_norm = torch.cat([
            modulate(self.norm2_latent(x_lat), shift_mlp, scale_mlp),
            modulate(self.norm2_cond(x_cond), shift_mlp_c, scale_mlp_c),
        ], dim=1)
        mlp_out = self.mlp(x_norm)
        x_lat = x_lat + gate_mlp.unsqueeze(1) * mlp_out[:, :num_latents]
        x_cond = x_cond + gate_mlp_c.unsqueeze(1) * mlp_out[:, num_latents:]
        return torch.cat([x_lat, x_cond], dim=1)


class FinalLayer(nn.Module):
    """adaLN-modulated LayerNorm followed by a projection to patch_size**2 * out_channels."""

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift, scale)
        return self.linear(x)


class FluxLatentDINOFlow(nn.Module):
    """Flow-matching DiT over FLUX VAE latents, conditioned on a degraded image.

    The token sequence is ``[latent patches | pixel-adapter tokens | DINO tokens]``.
    All three streams share one 2D RoPE grid, and only the latent tokens are
    projected back to a latent-shaped output (used as a velocity by the samplers).
    """

    # ImageNet mean/std used to normalize the DINO backbone's input.
    _DINO_MEAN = (0.485, 0.456, 0.406)
    _DINO_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        img_size=256,
        patch_size=2,
        latent_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        dino_model_name='vit_base_patch14_dinov2.lvd142m',
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.latent_channels = latent_channels
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.latent_size = img_size // 8
        self.grid_size = self.latent_size // patch_size
        self.num_patches = self.grid_size ** 2

        print(f"Loading DINO: {dino_model_name}")
        self.dino = timm.create_model(dino_model_name, pretrained=True, img_size=img_size, num_classes=0)
        for p in self.dino.parameters():
            p.requires_grad = False
        self.dino.eval()

        self.dino_adapter = nn.Sequential(
            nn.Conv2d(self.dino.embed_dim, hidden_size, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1),
        )
        # Three stride-2 convs (/8, matching the VAE) plus a patch_size conv -> one token per latent patch.
        self.pixel_adapter = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(256, hidden_size, kernel_size=patch_size, stride=patch_size),
        )
        self.x_embedder = nn.Linear(patch_size * patch_size * latent_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.rope = RotaryEmbedding2D(dim=hidden_size // num_heads, max_h=self.grid_size, max_w=self.grid_size)
        self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, latent_channels)
        self.type_emb_target = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.type_emb_pixel = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.type_emb_dino = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.initialize_weights()

        # Inference-only DINO cache: (lq tensor, (grid_h, grid_w), tokens) or None.
        self._dino_cache = None

    def initialize_weights(self):
        """Xavier-init Linear/Conv2d layers and zero the final projection."""
        for name, m in self.named_modules():
            # NOTE: the substring test skips the frozen backbone and also dino_adapter.*,
            # which therefore keeps PyTorch's default init.
            if "dino" in name:
                continue
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # A zero output layer makes the untrained model predict zero.
        nn.init.zeros_(self.final_layer.linear.weight)
        nn.init.zeros_(self.final_layer.linear.bias)

    def clear_dino_cache(self):
        """Drop the cached DINO tokens and the input tensor reference they are keyed on."""
        self._dino_cache = None

    def patchify(self, x):
        """(N, C, H, W) -> (N, (H/p)*(W/p), p*p*C) row-major patch tokens."""
        p = self.patch_size
        h, w = x.shape[2] // p, x.shape[3] // p
        x = x.reshape(x.shape[0], x.shape[1], h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(x.shape[0], h * w, -1)
        return x

    def unpatchify(self, x, h=None, w=None):
        """Invert ``patchify``. ``h``/``w`` default to a square grid inferred from the token count."""
        p = self.patch_size
        c = self.latent_channels
        if h is None or w is None:
            h = w = math.isqrt(x.shape[1])
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def _compute_dino_tokens(self, lq_img, grid_h, grid_w):
        """Run the frozen DINO backbone on ``lq_img`` and adapt its patch map to (B, grid_h*grid_w, hidden)."""
        with torch.no_grad():
            mean = torch.tensor(self._DINO_MEAN, device=lq_img.device).view(1, 3, 1, 1)
            std = torch.tensor(self._DINO_STD, device=lq_img.device).view(1, 3, 1, 1)
            # lq_img is assumed to be in [-1, 1]; map to [0, 1] and normalize.
            dino_in = (lq_img * 0.5 + 0.5 - mean) / std
            # NOTE(review): the backbone was created with img_size=self.img_size; whether other
            # input sizes are accepted depends on the timm model config - confirm.
            feats = self.dino.forward_features(dino_in)
            num_prefix = getattr(self.dino, "num_prefix_tokens", 0)
            if num_prefix > 0:
                feats = feats[:, num_prefix:]  # keep only the patch tokens
            side = math.isqrt(feats.shape[1])
            dino_map = feats.transpose(1, 2).reshape(lq_img.shape[0], -1, side, side)
            dino_map = F.interpolate(dino_map, size=(grid_h, grid_w), mode='bilinear', align_corners=False)
        # The adapter is trainable (only the backbone is frozen), so it runs outside no_grad.
        dino_tokens = self.dino_adapter(dino_map).flatten(2).transpose(1, 2)
        return dino_tokens + self.type_emb_dino

    def _get_dino_tokens(self, lq_img, grid_h, grid_w):
        """Return DINO tokens for ``lq_img``, reusing the cache when autograd is off.

        The cache is keyed on tensor identity plus grid size, and keeps a reference to
        ``lq_img``. That reference stops a different image from ever matching the key.
        The old ``data_ptr()`` key had no such guarantee, because the allocator reuses
        freed memory. Caching is skipped with grad enabled so training never reuses a graph.
        """
        use_cache = not torch.is_grad_enabled()
        if use_cache:
            cache = self._dino_cache  # read once so key and tokens come from the same entry
            if cache is not None and cache[0] is lq_img and cache[1] == (grid_h, grid_w):
                return cache[2]
            print("recomputing DINO features...")
        tokens = self._compute_dino_tokens(lq_img, grid_h, grid_w)
        if use_cache:
            self._dino_cache = (lq_img, (grid_h, grid_w), tokens)
        return tokens

    def forward(self, x_t_latent, t, lq_img):
        """Predict the flow output for latents ``x_t_latent`` at time ``t`` given ``lq_img``.

        Args:
            x_t_latent: (B, latent_channels, H/8, W/8) latents; spatial dims divisible by patch_size.
            t: (B,) timesteps.
            lq_img: (B, 3, H, W) conditioning image, expected in [-1, 1].

        The token grid is derived from ``x_t_latent``'s shape rather than from
        ``self.grid_size``, so a stale grid attribute cannot desynchronize the streams.
        """
        p = self.patch_size
        grid_h, grid_w = x_t_latent.shape[2] // p, x_t_latent.shape[3] // p
        num_latents = grid_h * grid_w

        x_tokens = self.x_embedder(self.patchify(x_t_latent)) + self.type_emb_target

        pixel_tokens = self.pixel_adapter(lq_img).flatten(2).transpose(1, 2)
        pixel_tokens = pixel_tokens + self.type_emb_pixel

        dino_tokens = self._get_dino_tokens(lq_img, grid_h, grid_w)

        tokens = torch.cat([x_tokens, pixel_tokens, dino_tokens], dim=1)
        t_emb = self.t_embedder(t)

        # All three streams lie on the same (grid_h, grid_w) grid and share its RoPE table.
        cos_base, sin_base = self.rope(tokens, grid_h, grid_w)
        cos = torch.cat([cos_base] * 3, dim=2)
        sin = torch.cat([sin_base] * 3, dim=2)

        for block in self.blocks:
            tokens = block(tokens, t_emb, cos, sin, num_latents=num_latents)

        out_patches = self.final_layer(tokens[:, :num_latents], t_emb)
        return self.unpatchify(out_patches, grid_h, grid_w)
# -----------------------------------------------------------------------------
# VAE manager
# -----------------------------------------------------------------------------
class VAEManager:
    """Wrap the FLUX AutoencoderKL and apply per-channel latent normalization.

    ``encode`` returns ``(z - mean) / (std + 1e-6)``; ``decode`` inverts it.
    """

    def __init__(self, model_id, subfolder, device, mean=None, std=None):
        print(f"Loading Flux VAE from {model_id}...")
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder=subfolder, token=HF_TOKEN)
        self.device = device
        self.vae.to(self.device).eval()
        self.vae.requires_grad_(False)
        self.register_stats(
            DEFAULT_VAE_MEAN if mean is None else mean,
            DEFAULT_VAE_STD if std is None else std,
        )

    def register_stats(self, mean, std):
        """Store normalization tensors shaped (1, C, 1, 1) on the VAE device."""
        # Cast to float32: stats loaded from .npy may be float64, which would promote the
        # latents to float64 and make the float32 VAE reject them in decode().
        mean = mean.to(device=self.device, dtype=torch.float32)
        std = std.to(device=self.device, dtype=torch.float32)
        self.shift = mean.view(1, -1, 1, 1)
        self.scale = (1.0 / (std + 1e-6)).view(1, -1, 1, 1)
        print("VAE Stats Registered")

    @torch.no_grad()
    def encode(self, pixels):
        """Encode pixels (deterministic posterior mode) into normalized latents."""
        latents = self.vae.encode(pixels).latent_dist.mode()
        return (latents - self.shift) * self.scale

    @torch.no_grad()
    def decode(self, latents):
        """Undo the normalization and decode latents to pixels."""
        latents = latents / self.scale + self.shift
        return self.vae.decode(latents).sample


# -----------------------------------------------------------------------------
# Samplers (with progress-callback support)
# -----------------------------------------------------------------------------
class FlowMatchingSampler:
    """Flow-matching samplers (ODE and SDE) that integrate from Gaussian noise at t=0 to t=1.

    Every sampler returns the VAE-decoded image clamped to [-1, 1] and, if given,
    calls ``progress_callback(step, steps, description)`` after each step.
    """

    def __init__(self, model, vae_mgr, device):
        self.model = model
        self.vae_mgr = vae_mgr
        self.device = device

    def _autocast(self):
        """fp16 autocast for CUDA runs; disabled elsewhere (replaces deprecated torch.cuda.amp.autocast)."""
        return torch.autocast(
            device_type="cuda",
            dtype=torch.float16,
            enabled=torch.device(self.device).type == "cuda",
        )

    def _initial_noise(self, lq):
        """Draw the starting latent: Gaussian noise at the latent resolution of ``lq``."""
        return torch.randn(
            lq.shape[0], LATENT_CHANNELS,
            lq.shape[2] // VAE_SCALE_FACTOR, lq.shape[3] // VAE_SCALE_FACTOR,
            device=self.device,
        )

    def _velocity(self, x, t_val, lq):
        """Evaluate the model at scalar time ``t_val`` for every sample in the batch."""
        t = torch.full((x.shape[0],), t_val, device=self.device, dtype=torch.float32)
        with self._autocast():
            return self.model(x, t, lq)

    def _decode(self, x):
        """Decode the final latent and clamp to the [-1, 1] pixel range."""
        return torch.clamp(self.vae_mgr.decode(x), -1, 1)

    @torch.no_grad()
    def sample_euler_ode(self, lq, steps, progress_callback=None):
        """Euler ODE sampler (deterministic given the initial noise)."""
        x = self._initial_noise(lq)
        dt = 1.0 / steps
        for i in range(steps):
            x = x + self._velocity(x, i / steps, lq) * dt
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"Euler ODE 采样中... {i+1}/{steps}")
        return self._decode(x)

    @torch.no_grad()
    def sample_sde_euler_maruyama(self, lq, steps, sigma=0.1, progress_callback=None):
        """Euler-Maruyama SDE sampler: Euler step plus ``sigma * sqrt(dt)`` Gaussian noise every step."""
        x = self._initial_noise(lq)
        dt = 1.0 / steps
        noise_scale = sigma * math.sqrt(dt)
        for i in range(steps):
            v = self._velocity(x, i / steps, lq)
            x = x + v * dt + noise_scale * torch.randn_like(x)
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"SDE Euler-Maruyama 采样中... {i+1}/{steps}")
        return self._decode(x)

    @torch.no_grad()
    def sample_sde_reverse_diffusion(self, lq, steps, sigma_schedule="linear", progress_callback=None):
        """SDE sampler whose noise level decays with t; no noise is added on the last step.

        ``sigma_schedule``: "linear" -> 0.5 * (1 - t), "cosine" -> 0.5 * cos(t * pi / 2),
        anything else -> constant 0.1.
        """
        x = self._initial_noise(lq)
        dt = 1.0 / steps
        for i in range(steps):
            t_val = i / steps
            v = self._velocity(x, t_val, lq)
            if sigma_schedule == "linear":
                sigma = 0.5 * (1 - t_val)
            elif sigma_schedule == "cosine":
                sigma = 0.5 * math.cos(t_val * math.pi / 2)
            else:
                sigma = 0.1
            noise = torch.randn_like(x) if i < steps - 1 else 0
            x = x + v * dt + sigma * math.sqrt(dt) * noise
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"SDE Reverse Diffusion 采样中... {i+1}/{steps}")
        return self._decode(x)

    @torch.no_grad()
    def sample_heun_ode(self, lq, steps, progress_callback=None):
        """Heun (2nd-order Runge-Kutta) ODE sampler; the final step falls back to Euler."""
        x = self._initial_noise(lq)
        dt = 1.0 / steps
        for i in range(steps):
            v1 = self._velocity(x, i / steps, lq)
            x_pred = x + v1 * dt
            if i < steps - 1:
                v2 = self._velocity(x_pred, (i + 1) / steps, lq)
                x = x + 0.5 * (v1 + v2) * dt
            else:
                x = x_pred
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"Heun ODE 采样中... {i+1}/{steps}")
        return self._decode(x)

    def sample(self, lq, steps, sampler_type="euler_ode", progress_callback=None, **kwargs):
        """Dispatch to a sampler by name.

        Raises:
            ValueError: unknown ``sampler_type`` or ``steps < 1``.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        if sampler_type == "euler_ode":
            return self.sample_euler_ode(lq, steps, progress_callback=progress_callback)
        if sampler_type == "sde_euler_maruyama":
            return self.sample_sde_euler_maruyama(
                lq, steps, sigma=kwargs.get("sigma", 0.1), progress_callback=progress_callback
            )
        if sampler_type == "sde_reverse":
            return self.sample_sde_reverse_diffusion(
                lq, steps,
                sigma_schedule=kwargs.get("sigma_schedule", "linear"),
                progress_callback=progress_callback,
            )
        if sampler_type == "heun_ode":
            return self.sample_heun_ode(lq, steps, progress_callback=progress_callback)
        raise ValueError(f"Unknown sampler type: {sampler_type}")


# -----------------------------------------------------------------------------
# Model loading
# -----------------------------------------------------------------------------
class ImageRestorer:
    """End-to-end restoration pipeline: load weights, preprocess, sample and decode.

    ``restore`` mutates shared state (the model's grid size and RoPE/DINO caches, and
    the global RNG when seeded), so calls are serialized with a lock. The Gradio app
    may run several requests concurrently.
    """

    def __init__(self, model_path=None, device="cuda", repo_id=None, filename="model.safetensors"):
        """
        Args:
            model_path: local checkpoint path; if None or missing, the model is downloaded from HF.
            device: requested device; falls back to CPU when CUDA is unavailable.
            repo_id: Hugging Face repo ID, e.g. "username/model-name".
            filename: checkpoint filename inside the HF repo.

        Raises:
            FileNotFoundError: no local checkpoint and no ``repo_id`` to download from.
            RuntimeError: the HF download failed.
        """
        import threading  # only needed for the restore lock

        self._lock = threading.Lock()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        model_path = self._resolve_model_path(model_path, repo_id, filename)
        vae_mean, vae_std = self._load_vae_stats(repo_id)

        self.vae_mgr = VAEManager(VAE_ID, VAE_SUBFOLDER, self.device, vae_mean, vae_std)

        print(f"Loading model from {model_path}...")
        self.model = FluxLatentDINOFlow(
            hidden_size=DIT_HIDDEN_SIZE,
            depth=DIT_DEPTH,
            num_heads=DIT_NUM_HEADS,
            patch_size=PATCH_SIZE,
            latent_channels=LATENT_CHANNELS,
            img_size=IMG_SIZE,
            dino_model_name=DINO_MODEL_NAME,
        ).to(self.device)
        state_dict = load_file(model_path)
        # strict=False because the DINO backbone weights may be absent from the checkpoint
        # (timm loads them pretrained). Report any other mismatch instead of hiding it.
        incompatible = self.model.load_state_dict(state_dict, strict=False)
        missing = [k for k in incompatible.missing_keys if not k.startswith("dino.")]
        if missing:
            print(f"Warning: {len(missing)} model keys missing from checkpoint, e.g. {missing[:5]}")
        if incompatible.unexpected_keys:
            unexpected = incompatible.unexpected_keys
            print(f"Warning: {len(unexpected)} unexpected checkpoint keys, e.g. {unexpected[:5]}")
        print("Model loaded successfully")
        self.model.eval()

        self.sampler = FlowMatchingSampler(self.model, self.vae_mgr, self.device)

    @staticmethod
    def _resolve_model_path(model_path, repo_id, filename):
        """Return an existing local checkpoint path, downloading from HF if necessary."""
        if model_path is not None and os.path.exists(model_path):
            return model_path
        if repo_id is None:
            raise FileNotFoundError(f"Model not found at {model_path} and no repo_id provided")
        print(f"Downloading model from Hugging Face: {repo_id}/{filename}")
        try:
            path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                token=HF_TOKEN,
                cache_dir="./hf_cache",
            )
        except Exception as e:
            raise RuntimeError(f"Failed to download model from HF: {e}") from e
        print(f"Model downloaded to: {path}")
        return path

    @staticmethod
    def _read_vae_stats(path):
        """Load a ``{'mean': ndarray, 'std': ndarray}`` dict saved with ``np.save``; return both as tensors."""
        # NOTE(security): allow_pickle is required for the dict format, so only load
        # stats files from sources you trust.
        stats = np.load(path, allow_pickle=True).item()
        return torch.from_numpy(stats['mean']), torch.from_numpy(stats['std'])

    @classmethod
    def _load_vae_stats(cls, repo_id):
        """Return (mean, std): HF repo first, then a local vae_stats.npy, else the defaults.

        Both values come from the same source; a partially-read file never mixes with defaults.
        """
        if repo_id is not None:
            try:
                path = hf_hub_download(
                    repo_id=repo_id,
                    filename="vae_stats.npy",
                    token=HF_TOKEN,
                    cache_dir="./hf_cache",
                )
                mean, std = cls._read_vae_stats(path)
                print("Loaded VAE stats from HF repo")
                return mean, std
            except Exception:
                print("No vae_stats.npy in HF repo, checking local...")
        if os.path.exists("vae_stats.npy"):
            try:
                mean, std = cls._read_vae_stats("vae_stats.npy")
                print("Loaded cached VAE stats from local")
                return mean, std
            except Exception as e:
                print(f"Failed to load local VAE stats: {e}")
        return DEFAULT_VAE_MEAN, DEFAULT_VAE_STD

    def preprocess(self, image: Image.Image, target_size: int) -> torch.Tensor:
        """Resize to a square multiple of VAE_SCALE_FACTOR * PATCH_SIZE; return (1, 3, S, S) in [-1, 1]."""
        min_unit = VAE_SCALE_FACTOR * PATCH_SIZE
        target_size = max((target_size // min_unit) * min_unit, min_unit)

        image = image.convert("RGB")
        image = image.resize((target_size, target_size), Image.Resampling.BICUBIC)

        arr = np.array(image).astype(np.float32) / 127.5 - 1.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
        return tensor.to(self.device)

    def postprocess(self, tensor: torch.Tensor) -> Image.Image:
        """Convert a (1, 3, H, W) tensor in [-1, 1] to an RGB PIL image."""
        tensor = tensor.squeeze(0).cpu()
        tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
        arr = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return Image.fromarray(arr)

    @torch.no_grad()
    def restore(
        self,
        image: Image.Image,
        resolution: int = 384,
        steps: int = 25,
        sampler_type: str = "euler_ode",
        sigma: float = 0.1,
        seed: int = -1,
        progress_callback=None,
    ) -> Image.Image:
        """Restore ``image`` and return the result as a PIL image.

        Progress is reported as ``progress_callback(current, steps + 2, desc)``: one tick
        for preprocessing, ``steps`` sampler ticks, then one for post-processing.
        A negative ``seed`` leaves the RNG state untouched.
        """
        total = steps + 2

        def report(current, desc):
            if progress_callback:
                progress_callback(current, total, desc)

        with self._lock:
            if seed >= 0:
                torch.manual_seed(seed)
                np.random.seed(seed % 2**32)  # numpy only accepts 32-bit seeds
            report(0, "预处理图像...")
            lq = self.preprocess(image, resolution)

            # Size the model from the aligned resolution preprocess() actually produced.
            self._adjust_model_for_resolution(lq.shape[-1])
            report(1, "开始采样...")

            restored_tensor = self.sampler.sample(
                lq,
                steps,
                sampler_type=sampler_type,
                sigma=sigma,
                # Sampler ticks are 1..steps; shift by one for the preprocessing tick.
                progress_callback=lambda current, _total, desc: report(current + 1, desc),
            )

            report(total, "后处理...")
            return self.postprocess(restored_tensor)

    def _adjust_model_for_resolution(self, resolution: int):
        """Update the model's grid size, patch count and RoPE cache for ``resolution``."""
        min_unit = VAE_SCALE_FACTOR * PATCH_SIZE
        # Same alignment and lower bound as preprocess(), so both agree on the grid.
        resolution = max((resolution // min_unit) * min_unit, min_unit)
        new_latent_size = resolution // VAE_SCALE_FACTOR
        new_grid_size = new_latent_size // PATCH_SIZE

        if new_grid_size != self.model.grid_size:
            print(f"Adjusting model for resolution {resolution} (grid: {new_grid_size})")
            self.model.latent_size = new_latent_size
            self.model.grid_size = new_grid_size
            self.model.num_patches = new_grid_size ** 2
            self.model.rope._set_cos_sin_cache(new_grid_size, new_grid_size)


_INTRO_MARKDOWN = """
# 🖼️ FLUX VAE 图像修复 Demo

使用 Flow Matching + DiT + DINO 的图像修复模型。上传一张图像,选择参数后点击"修复"按钮。

> 💡 提示:进度条会显示当前处理步骤,如果有多人同时使用会显示排队状态。
"""

_QUEUE_MARKDOWN = """
⏳ 如果按钮显示"排队中...",说明有其他用户正在使用,请稍候。
"""

_HELP_MARKDOWN = """
### 📝 说明

**采样器类型**:
- **Euler ODE**: 标准 Flow Matching 采样,确定性,速度快
- **Heun ODE**: 二阶 Runge-Kutta 方法,更准确但需要双倍计算
- **SDE Euler-Maruyama**: 添加随机噪声的 SDE 采样,可以增加多样性
- **SDE Reverse Diffusion**: 使用衰减噪声的逆向扩散 SDE

**参数建议**:
- 一般情况:Euler ODE, 25 步
- 更高质量:Heun ODE, 30-50 步
- 需要多样性/创意修复:SDE 采样器, sigma=0.1-0.2
"""


def create_demo(restorer: ImageRestorer):
    """Build the Gradio Blocks UI around ``restorer``."""
    # Dropdown label -> sampler_type; insertion order is also the dropdown order.
    sampler_map = {
        "Euler ODE (确定性)": "euler_ode",
        "Heun ODE (二阶,更准确)": "heun_ode",
        "SDE Euler-Maruyama (随机性)": "sde_euler_maruyama",
        "SDE Reverse Diffusion (逆向扩散)": "sde_reverse",
    }

    def process_image(
        image,
        resolution,
        steps,
        sampler,
        sigma,
        seed,
        progress=gr.Progress(track_tqdm=True),
    ):
        if image is None:
            return None

        sampler_type = sampler_map.get(sampler, "euler_ode")

        def progress_callback(current, total, desc):
            progress(current / total, desc=desc)

        try:
            return restorer.restore(
                image,
                resolution=int(resolution),
                steps=int(steps),
                sampler_type=sampler_type,
                sigma=float(sigma),
                # A cleared gr.Number yields None; treat it like -1 (random seed).
                seed=-1 if seed is None else int(seed),
                progress_callback=progress_callback,
            )
        except Exception as e:
            import traceback
            print(f"Error: {e}")
            traceback.print_exc()
            # Show the failure in the UI instead of silently returning an empty output.
            raise gr.Error(f"修复失败: {e}") from e

    with gr.Blocks(title="Image Restoration Demo", css="""
    .progress-bar { height: 20px !important; }
    """) as demo:
        gr.Markdown(_INTRO_MARKDOWN)

        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="输入图像")

                with gr.Group():
                    resolution = gr.Slider(
                        minimum=128, maximum=1024, value=384, step=16,
                        label="分辨率",
                        info="图像会被 resize 到此分辨率(会自动对齐到 16 的倍数)",
                    )
                    steps = gr.Slider(
                        minimum=5, maximum=100, value=25, step=1,
                        label="推理步数",
                        info="更多步数 = 更好质量,但更慢",
                    )
                    sampler = gr.Dropdown(
                        choices=list(sampler_map),
                        value="Euler ODE (确定性)",
                        label="采样器",
                        info="ODE 是确定性的,SDE 会添加随机噪声",
                    )
                    sigma = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.1, step=0.01,
                        label="SDE 噪声强度 (sigma)",
                        info="仅对 SDE 采样器有效,越大随机性越强",
                    )
                    seed = gr.Number(value=-1, label="随机种子", info="-1 表示随机种子")

                submit_btn = gr.Button("🚀 开始修复", variant="primary")
                # Queue-status hint for users waiting behind other requests.
                gr.Markdown(_QUEUE_MARKDOWN)

            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="修复结果")

        submit_btn.click(
            fn=process_image,
            inputs=[input_image, resolution, steps, sampler, sigma, seed],
            outputs=output_image,
            show_progress="full",
        )

        gr.Markdown(_HELP_MARKDOWN)

    return demo


# -----------------------------------------------------------------------------
# Entry point
# -----------------------------------------------------------------------------
def main():
    restorer = ImageRestorer(model_path=MODEL_PATH, repo_id=MODEL_REPO_ID, filename=MODEL_FILENAME)
    # restore() serializes on a lock, so extra concurrent workers queue behind it.
    create_demo(restorer).queue(
        max_size=10,
        default_concurrency_limit=2,
    ).launch()


if __name__ == "__main__":
    main()