Spaces:
Sleeping
Sleeping
| """ | |
| Gradio Demo for FLUX VAE Image Restoration Model | |
| 支持自定义分辨率、采样器(Euler ODE / SDE Euler-Maruyama)和推理步数 | |
| """ | |
| import os | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import gradio as gr | |
| from PIL import Image | |
| import timm | |
| from diffusers import AutoencoderKL | |
| from safetensors.torch import load_file | |
| from huggingface_hub import hf_hub_download | |
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_PATH = "model.safetensors"  # local path to the model weights
MODEL_REPO_ID = "telecomadm1145/img_restore"  # Hugging Face repo ID holding the weights
MODEL_FILENAME = "model.safetensors"  # filename inside the HF repo
VAE_ID = "advokat/AnimePro-FLUX"  #"black-forest-labs/FLUX.1-schnell"
VAE_SUBFOLDER = "vae"
HF_TOKEN = os.environ.get("HF_TOKEN") or None  # optional token for gated/private repos

# Model hyperparameters (must match the training configuration)
LATENT_CHANNELS = 16
VAE_SCALE_FACTOR = 8  # pixel-to-latent downscaling factor of the FLUX VAE
DIT_HIDDEN_SIZE = 1024
DIT_DEPTH = 16
DIT_NUM_HEADS = 4
PATCH_SIZE = 2
DINO_MODEL_NAME = 'vit_base_patch16_dinov3.lvd1689m'
IMG_SIZE = 384  # default training resolution

# VAE latent statistics (replaced at load time if a cached copy is found)
DEFAULT_VAE_MEAN = torch.zeros(LATENT_CHANNELS)
DEFAULT_VAE_STD = torch.ones(LATENT_CHANNELS)
# -----------------------------------------------------------------------------
# Model definitions (must match the training code exactly)
# -----------------------------------------------------------------------------
def modulate(x, shift, scale):
    """Apply AdaLN modulation: scale (around 1) then shift, broadcast over tokens."""
    scaled = x * (scale.unsqueeze(1) + 1)
    return scaled + shift.unsqueeze(1)
class TimestepEmbedder(nn.Module):
    """Embed scalar timesteps t in [0, 1] into vectors of size `hidden_size`.

    Sinusoidal frequency embedding followed by a two-layer SiLU MLP, as in
    the reference DiT implementation.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal embedding of t (shape [B]) into shape [B, dim].

        FIX: must be a @staticmethod — it is invoked as
        `self.timestep_embedding(t, dim)`; without the decorator, `self`
        would be bound to the `t` parameter and the call would fail.
        """
        # t is in [0, 1]; rescale to the conventional [0, 1000] range.
        t = t * 1000.0
        half = dim // 2
        freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=t.device) / half)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad with one zero column when dim is odd.
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)
class RotaryEmbedding2D(nn.Module):
    """2D axial rotary position embedding.

    The head dimension is split into a height half and a width half; each
    half receives standard 1D RoPE frequencies over its own axis. cos/sin
    tables for the full (max_h x max_w) grid are cached as non-persistent
    buffers and sliced at call time.
    """

    def __init__(self, dim, max_h=64, max_w=64):
        super().__init__()
        self.dim = dim
        dim_h = dim // 2
        dim_w = dim - dim_h
        self.register_buffer("inv_freq_h", 1.0 / (10000 ** (torch.arange(0, dim_h, 2).float() / dim_h)))
        self.register_buffer("inv_freq_w", 1.0 / (10000 ** (torch.arange(0, dim_w, 2).float() / dim_w)))
        self._set_cos_sin_cache(max_h, max_w)

    def _set_cos_sin_cache(self, h, w):
        # Per-axis angle tables, duplicated along the last dim so that the
        # rotate-half trick sees matching frequencies in both halves.
        angles_h = torch.outer(torch.arange(h).type_as(self.inv_freq_h), self.inv_freq_h)
        angles_w = torch.outer(torch.arange(w).type_as(self.inv_freq_w), self.inv_freq_w)
        emb_h = torch.cat((angles_h, angles_h), dim=-1)
        emb_w = torch.cat((angles_w, angles_w), dim=-1)
        # Broadcast each axis table over the other axis, then flatten the
        # grid to (h*w, dim) in row-major (height-first) order.
        grid_h = emb_h.unsqueeze(1).repeat(1, w, 1)
        grid_w = emb_w.unsqueeze(0).repeat(h, 1, 1)
        full = torch.cat((grid_h, grid_w), dim=-1).flatten(0, 1)
        self.register_buffer("cos_cached", full.cos().unsqueeze(0).unsqueeze(0), persistent=False)
        self.register_buffer("sin_cached", full.sin().unsqueeze(0).unsqueeze(0), persistent=False)

    def forward(self, x, h, w):
        # NOTE(review): taking the first h*w rows is only correct when (h, w)
        # equals the cached grid size; callers re-run _set_cos_sin_cache
        # whenever the grid changes.
        cos = self.cos_cached[:, :, : h * w, :].to(x.dtype)
        sin = self.sin_cached[:, :, : h * w, :].to(x.dtype)
        return cos, sin
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query/key tensors with cached cos/sin tables (rotate-half RoPE)."""

    def _rotate_half(t):
        first, second = t.chunk(2, dim=-1)
        return torch.cat((-second, first), dim=-1)

    q_out = q * cos + _rotate_half(q) * sin
    k_out = k * cos + _rotate_half(k) * sin
    return q_out, k_out
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: w3(silu(w1(x)) * w2(x)).

    The hidden width is 2/3 of `hidden_size * mlp_ratio` (gated-MLP
    convention that keeps parameter count comparable to a plain MLP),
    rounded up to a multiple of 64.
    """

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        width = int(hidden_size * mlp_ratio * 2 / 3)
        width = -(-width // 64) * 64  # ceil to a multiple of 64
        self.w1 = nn.Linear(hidden_size, width, bias=False)
        self.w2 = nn.Linear(hidden_size, width, bias=False)
        self.w3 = nn.Linear(width, hidden_size, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w3(gate * self.w2(x))
class DiTBlock(nn.Module):
    """DiT transformer block with stream-specific adaptive LayerNorm.

    The input sequence is the concatenation [latent tokens | condition
    tokens] (first `num_latents` entries are latents). Both streams attend
    jointly in a single self-attention call, but each stream has its own
    LayerNorms and its own adaLN shift/scale/gate parameters derived from
    the timestep embedding.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Pre-attention norms; affine disabled because shift/scale come from adaLN.
        self.norm1_latent = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm1_cond = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        # Pre-MLP norms, again one per stream.
        self.norm2_latent = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm2_cond = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # Per-head QK normalization.
        self.q_norm = nn.LayerNorm(self.head_dim, eps=1e-6)
        self.k_norm = nn.LayerNorm(self.head_dim, eps=1e-6)
        self.mlp = SwiGLU(hidden_size, mlp_ratio)
        # Each adaLN head emits 6 chunks: (shift, scale, gate) for attention and MLP.
        self.adaLN_latent = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=False)
        )
        self.adaLN_cond = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=False)
        )

    def forward(self, x, t_emb, rope_cos, rope_sin, num_latents):
        """x: (B, L, D); the first `num_latents` tokens are the latent stream."""
        B, L, D = x.shape
        num_cond = L - num_latents
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_latent(t_emb).chunk(6, dim=-1)
        shift_msa_c, scale_msa_c, gate_msa_c, shift_mlp_c, scale_mlp_c, gate_mlp_c = \
            self.adaLN_cond(t_emb).chunk(6, dim=-1)
        # Modulate each stream separately, then rejoin for joint attention.
        x_lat, x_cond = x[:, :num_latents], x[:, num_latents:]
        x_lat_norm = modulate(self.norm1_latent(x_lat), shift_msa, scale_msa)
        x_cond_norm = modulate(self.norm1_cond(x_cond), shift_msa_c, scale_msa_c)
        x_norm = torch.cat([x_lat_norm, x_cond_norm], dim=1)
        qkv = self.qkv(x_norm)
        q, k, v = qkv.reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_pos_emb(q, k, rope_cos, rope_sin)
        q, k = q.to(v.dtype), k.to(v.dtype)
        x_attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        x_attn = x_attn.transpose(1, 2).reshape(B, L, D)
        x_attn = self.proj(x_attn)
        # Gated residual connection, applied per stream.
        x_attn_lat, x_attn_cond = x_attn[:, :num_latents], x_attn[:, num_latents:]
        x_lat = x_lat + gate_msa.unsqueeze(1) * x_attn_lat
        x_cond = x_cond + gate_msa_c.unsqueeze(1) * x_attn_cond
        # MLP path with stream-specific modulation and gates.
        x_lat_norm = modulate(self.norm2_latent(x_lat), shift_mlp, scale_mlp)
        x_cond_norm = modulate(self.norm2_cond(x_cond), shift_mlp_c, scale_mlp_c)
        x_norm = torch.cat([x_lat_norm, x_cond_norm], dim=1)
        mlp_out = self.mlp(x_norm)
        mlp_lat, mlp_cond = mlp_out[:, :num_latents], mlp_out[:, num_latents:]
        x_lat = x_lat + gate_mlp.unsqueeze(1) * mlp_lat
        x_cond = x_cond + gate_mlp_c.unsqueeze(1) * mlp_cond
        return torch.cat([x_lat, x_cond], dim=1)
class FinalLayer(nn.Module):
    """Final AdaLN + linear projection from DiT tokens to patch latents."""

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        # Inline AdaLN modulation: scale-and-shift the normalized tokens.
        h = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.linear(h)
class FluxLatentDINOFlow(nn.Module):
    """Flow-matching DiT predicting velocity in FLUX VAE latent space.

    Conditioning comes from two encoders applied to the low-quality input
    image: a frozen pretrained DINO ViT (semantic features) and a small
    strided-conv pixel encoder (local detail). Each token stream carries a
    learned type embedding; all streams are processed jointly by the DiT
    blocks, and only the latent tokens are projected back to a velocity
    field.
    """

    def __init__(
        self,
        img_size=256,
        patch_size=2,
        latent_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        dino_model_name='vit_base_patch14_dinov2.lvd142m',
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.latent_channels = latent_channels
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # The FLUX VAE downscales pixels by 8; patchify downscales the latent again.
        self.latent_size = img_size // 8
        self.grid_size = self.latent_size // patch_size
        self.num_patches = self.grid_size ** 2
        print(f"Loading DINO: {dino_model_name}")
        # Frozen DINO backbone, used purely as a feature extractor.
        self.dino = timm.create_model(dino_model_name, pretrained=True, img_size=img_size, num_classes=0)
        for p in self.dino.parameters():
            p.requires_grad = False
        self.dino.eval()
        # Projects DINO feature maps to the DiT hidden width.
        self.dino_adapter = nn.Sequential(
            nn.Conv2d(self.dino.embed_dim, hidden_size, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1),
        )
        # Pixel encoder: three stride-2 convs (/8, matching the VAE scale)
        # followed by a patch-sized conv so its token grid matches the
        # latent patch grid.
        self.pixel_adapter = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(256, hidden_size, kernel_size=patch_size, stride=patch_size),
        )
        self.x_embedder = nn.Linear(patch_size * patch_size * latent_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.rope = RotaryEmbedding2D(dim=hidden_size // num_heads, max_h=self.grid_size, max_w=self.grid_size)
        self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, latent_channels)
        # Learned per-stream type embeddings added to each token group.
        self.type_emb_target = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.type_emb_pixel = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.type_emb_dino = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.initialize_weights()
        # Cache of the last computed DINO condition tokens: at inference the
        # conditioning image is constant across sampler steps, so DINO only
        # needs to run once per image.
        self._cached_dino_map = None
        self._cached_lq_hash = None  # hash identifying the cached input

    def initialize_weights(self):
        """Xavier-init all linear/conv layers except the frozen DINO; zero the output head."""
        for name, m in self.named_modules():
            if "dino" in name:
                continue
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        # Zero-init the final projection so the untrained model predicts zero velocity.
        nn.init.zeros_(self.final_layer.linear.weight)
        nn.init.zeros_(self.final_layer.linear.bias)

    def patchify(self, x):
        """(B, C, H, W) -> (B, H/p * W/p, p*p*C) token sequence."""
        p = self.patch_size
        h, w = x.shape[2] // p, x.shape[3] // p
        x = x.reshape(x.shape[0], x.shape[1], h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(x.shape[0], h * w, -1)
        return x

    def unpatchify(self, x):
        """Inverse of patchify; assumes a square token grid."""
        p = self.patch_size
        c = self.latent_channels
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def forward(self, x_t_latent, t, lq_img):
        """Predict the flow velocity for noisy latents at time t.

        Args:
            x_t_latent: (B, C, H_lat, W_lat) noisy latent.
            t: (B,) timesteps in [0, 1].
            lq_img: (B, 3, H, W) low-quality image in [-1, 1] (the code
                rescales it via x*0.5+0.5 before ImageNet normalization).
        """
        B = x_t_latent.shape[0]
        x_patches = self.patchify(x_t_latent)
        x_tokens = self.x_embedder(x_patches)
        x_tokens = x_tokens + self.type_emb_target
        pixel_tokens = self.pixel_adapter(lq_img)
        pixel_tokens = pixel_tokens.flatten(2).transpose(1, 2)
        pixel_tokens = pixel_tokens + self.type_emb_pixel
        # Cheap identity hash of the conditioning input via its storage
        # pointer. NOTE(review): data_ptr() values can be reused after the
        # original tensor is freed, which would serve stale cached DINO
        # features; a content hash (e.g. tensor.sum().item()) would be safer.
        lq_hash = hash(lq_img.data_ptr())
        if self._cached_dino_map is None or self._cached_lq_hash != lq_hash:
            print("recalculating hash...")
            # Recompute DINO features only when the cache is empty or the input changed.
            with torch.no_grad():
                # Map [-1, 1] pixels to ImageNet normalization for DINO.
                mean = torch.tensor([0.485, 0.456, 0.406], device=lq_img.device).view(1, 3, 1, 1)
                std = torch.tensor([0.229, 0.224, 0.225], device=lq_img.device).view(1, 3, 1, 1)
                dino_in = (lq_img * 0.5 + 0.5 - mean) / std
                dino_feats = self.dino.forward_features(dino_in)
                if getattr(self.dino, "num_prefix_tokens", 0) > 0:
                    # Drop CLS/register tokens; keep only the spatial patch tokens.
                    dino_feats = dino_feats[:, self.dino.num_prefix_tokens:]
            d_h = d_w = int(dino_feats.shape[1] ** 0.5)
            dino_map = dino_feats.transpose(1, 2).reshape(B, -1, d_h, d_w)
            # Resample the DINO grid to the latent patch grid.
            dino_map_resized = F.interpolate(dino_map, size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=False)
            dino_tokens = self.dino_adapter(dino_map_resized)
            dino_tokens = dino_tokens.flatten(2).transpose(1, 2)
            dino_tokens = dino_tokens + self.type_emb_dino
            # Update the cache.
            self._cached_dino_map = dino_tokens
            self._cached_lq_hash = lq_hash
        else:
            dino_tokens = self._cached_dino_map
        tokens = torch.cat([x_tokens, pixel_tokens, dino_tokens], dim=1)
        t_emb = self.t_embedder(t)
        # All three token groups share the same spatial grid, so the RoPE
        # table is computed once and tiled per group.
        cos_base, sin_base = self.rope(tokens, self.grid_size, self.grid_size)
        cos = torch.cat([cos_base] * 3, dim=2)
        sin = torch.cat([sin_base] * 3, dim=2)
        for block in self.blocks:
            tokens = block(tokens, t_emb, cos, sin, num_latents=self.num_patches)
        # Only the latent tokens are decoded to the output velocity field.
        out_tokens = tokens[:, :self.num_patches]
        out_patches = self.final_layer(out_tokens, t_emb)
        out_latents = self.unpatchify(out_patches)
        return out_latents
| # ----------------------------------------------------------------------------- | |
| # VAE 管理器 | |
| # ----------------------------------------------------------------------------- | |
class VAEManager:
    """Wraps the frozen FLUX AutoencoderKL plus per-channel latent statistics.

    Latents given to the flow model are normalized as (z - mean) / std;
    `decode` inverts that normalization before running the VAE decoder.
    """

    def __init__(self, model_id, subfolder, device, mean=None, std=None):
        print(f"Loading Flux VAE from {model_id}...")
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder=subfolder, token=HF_TOKEN)
        self.device = device
        self.vae.to(self.device).eval()
        self.vae.requires_grad_(False)
        self.register_stats(
            DEFAULT_VAE_MEAN if mean is None else mean,
            DEFAULT_VAE_STD if std is None else std,
        )

    def register_stats(self, mean, std):
        # Stored in broadcastable (1, C, 1, 1) form; epsilon guards against
        # division by a zero std.
        self.shift = mean.to(self.device).view(1, -1, 1, 1)
        self.scale = (1.0 / (std.to(self.device) + 1e-6)).view(1, -1, 1, 1)
        print(f"VAE Stats Registered")

    def encode(self, pixels):
        """Encode pixels (deterministic mode of the posterior) to normalized latents."""
        latents = self.vae.encode(pixels).latent_dist.mode()
        return (latents - self.shift) * self.scale

    def decode(self, latents):
        """Denormalize latents and decode them back to pixel space."""
        raw = latents / self.scale + self.shift
        return self.vae.decode(raw).sample
| # ----------------------------------------------------------------------------- | |
| # 采样器(添加进度回调支持) | |
| # ----------------------------------------------------------------------------- | |
class FlowMatchingSampler:
    """Flow-matching sampler supporting deterministic ODE and stochastic SDE schemes.

    Every sampler starts from Gaussian noise in VAE latent space, integrates
    the model's predicted velocity field from t=0 to t=1 conditioned on the
    low-quality image `lq`, then decodes the final latent with the VAE.

    FIX: all sampling methods now run under @torch.no_grad() — they are pure
    inference, and without it autograd builds a graph across every step,
    wasting a large amount of memory.
    """

    def __init__(self, model, vae_mgr, device):
        self.model = model      # velocity model: (x_t, t, lq) -> v
        self.vae_mgr = vae_mgr  # VAEManager used to decode the final latent
        self.device = device

    @torch.no_grad()
    def sample_euler_ode(self, lq, steps, progress_callback=None):
        """First-order Euler ODE sampler (deterministic)."""
        B = lq.shape[0]
        H_lat = lq.shape[2] // VAE_SCALE_FACTOR
        W_lat = lq.shape[3] // VAE_SCALE_FACTOR
        x = torch.randn(B, LATENT_CHANNELS, H_lat, W_lat, device=self.device)
        dt = 1.0 / steps
        for i in range(steps):
            t = torch.full((B,), i / steps, device=self.device, dtype=torch.float32)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                v = self.model(x, t, lq)
            x = x + v * dt
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"Euler ODE 采样中... {i+1}/{steps}")
        restored = self.vae_mgr.decode(x)
        return torch.clamp(restored, -1, 1)

    @torch.no_grad()
    def sample_sde_euler_maruyama(self, lq, steps, sigma=0.1, progress_callback=None):
        """Euler-Maruyama SDE sampler: Euler drift plus sigma*sqrt(dt) Gaussian noise."""
        B = lq.shape[0]
        H_lat = lq.shape[2] // VAE_SCALE_FACTOR
        W_lat = lq.shape[3] // VAE_SCALE_FACTOR
        x = torch.randn(B, LATENT_CHANNELS, H_lat, W_lat, device=self.device)
        dt = 1.0 / steps
        sqrt_dt = math.sqrt(dt)
        for i in range(steps):
            t = torch.full((B,), i / steps, device=self.device, dtype=torch.float32)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                v = self.model(x, t, lq)
            noise = torch.randn_like(x)
            x = x + v * dt + sigma * sqrt_dt * noise
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"SDE Euler-Maruyama 采样中... {i+1}/{steps}")
        restored = self.vae_mgr.decode(x)
        return torch.clamp(restored, -1, 1)

    @torch.no_grad()
    def sample_sde_reverse_diffusion(self, lq, steps, sigma_schedule="linear", progress_callback=None):
        """SDE sampler with a decaying noise schedule; no noise on the final step."""
        B = lq.shape[0]
        H_lat = lq.shape[2] // VAE_SCALE_FACTOR
        W_lat = lq.shape[3] // VAE_SCALE_FACTOR
        x = torch.randn(B, LATENT_CHANNELS, H_lat, W_lat, device=self.device)
        dt = 1.0 / steps
        for i in range(steps):
            t_val = i / steps
            t = torch.full((B,), t_val, device=self.device, dtype=torch.float32)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                v = self.model(x, t, lq)
            # Noise magnitude decays as t -> 1 so the trajectory settles.
            if sigma_schedule == "linear":
                sigma = 0.5 * (1 - t_val)
            elif sigma_schedule == "cosine":
                sigma = 0.5 * math.cos(t_val * math.pi / 2)
            else:
                sigma = 0.1
            noise = torch.randn_like(x) if i < steps - 1 else 0
            x = x + v * dt + sigma * math.sqrt(dt) * noise
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"SDE Reverse Diffusion 采样中... {i+1}/{steps}")
        restored = self.vae_mgr.decode(x)
        return torch.clamp(restored, -1, 1)

    @torch.no_grad()
    def sample_heun_ode(self, lq, steps, progress_callback=None):
        """Heun's method (2nd-order Runge-Kutta) ODE sampler; plain Euler on the last step."""
        B = lq.shape[0]
        H_lat = lq.shape[2] // VAE_SCALE_FACTOR
        W_lat = lq.shape[3] // VAE_SCALE_FACTOR
        x = torch.randn(B, LATENT_CHANNELS, H_lat, W_lat, device=self.device)
        dt = 1.0 / steps
        for i in range(steps):
            t = torch.full((B,), i / steps, device=self.device, dtype=torch.float32)
            t_next = torch.full((B,), (i + 1) / steps, device=self.device, dtype=torch.float32)
            with torch.cuda.amp.autocast(dtype=torch.float16):
                v1 = self.model(x, t, lq)
                x_pred = x + v1 * dt
                if i < steps - 1:
                    # Average the slopes at t and t+dt (trapezoidal correction).
                    v2 = self.model(x_pred, t_next, lq)
                    x = x + 0.5 * (v1 + v2) * dt
                else:
                    x = x_pred
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"Heun ODE 采样中... {i+1}/{steps}")
        restored = self.vae_mgr.decode(x)
        return torch.clamp(restored, -1, 1)

    def sample(self, lq, steps, sampler_type="euler_ode", progress_callback=None, **kwargs):
        """Dispatch to one of the samplers by name.

        Raises:
            ValueError: if `sampler_type` is not recognized.
        """
        if sampler_type == "euler_ode":
            return self.sample_euler_ode(lq, steps, progress_callback=progress_callback)
        elif sampler_type == "sde_euler_maruyama":
            sigma = kwargs.get("sigma", 0.1)
            return self.sample_sde_euler_maruyama(lq, steps, sigma=sigma, progress_callback=progress_callback)
        elif sampler_type == "sde_reverse":
            sigma_schedule = kwargs.get("sigma_schedule", "linear")
            return self.sample_sde_reverse_diffusion(lq, steps, sigma_schedule=sigma_schedule, progress_callback=progress_callback)
        elif sampler_type == "heun_ode":
            return self.sample_heun_ode(lq, steps, progress_callback=progress_callback)
        else:
            raise ValueError(f"Unknown sampler type: {sampler_type}")
| # ----------------------------------------------------------------------------- | |
| # 模型加载 | |
| # ----------------------------------------------------------------------------- | |
class ImageRestorer:
    """End-to-end restoration pipeline: weight download, VAE, model, sampler."""

    def __init__(self, model_path=None, device="cuda",
                 repo_id=None, filename="model.safetensors"):
        """
        Args:
            model_path: Local path to the model weights; downloaded from
                `repo_id` when None or missing.
            device: Preferred device ("cuda" falls back to CPU when unavailable).
            repo_id: Hugging Face repo ID, e.g. "username/model-name".
            filename: Model filename inside the HF repo.

        Raises:
            RuntimeError: if the HF download fails.
            FileNotFoundError: if no local weights exist and no repo_id is given.
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # ---- Download model weights from Hugging Face if needed ----
        if model_path is None or not os.path.exists(model_path):
            if repo_id is not None:
                # FIX: message previously printed a garbled "(unknown)" instead
                # of the actual filename.
                print(f"Downloading model from Hugging Face: {repo_id}/{filename}")
                try:
                    model_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        token=HF_TOKEN,
                        cache_dir="./hf_cache"  # optional local cache directory
                    )
                    print(f"Model downloaded to: {model_path}")
                except Exception as e:
                    raise RuntimeError(f"Failed to download model from HF: {e}") from e
            else:
                raise FileNotFoundError(
                    f"Model not found at {model_path} and no repo_id provided"
                )
        # ---- VAE latent statistics: try the HF repo first, then local ----
        vae_mean = DEFAULT_VAE_MEAN
        vae_std = DEFAULT_VAE_STD
        if repo_id is not None:
            try:
                vae_stats_path = hf_hub_download(
                    repo_id=repo_id,
                    filename="vae_stats.npy",
                    token=HF_TOKEN,
                    cache_dir="./hf_cache"
                )
                stats = np.load(vae_stats_path, allow_pickle=True).item()
                vae_mean = torch.from_numpy(stats['mean'])
                vae_std = torch.from_numpy(stats['std'])
                print("Loaded VAE stats from HF repo")
            except Exception:
                # Best-effort: fall through to the local file / defaults.
                print("No vae_stats.npy in HF repo, checking local...")
        # A local vae_stats.npy (if present) takes precedence.
        if os.path.exists("vae_stats.npy"):
            try:
                stats = np.load("vae_stats.npy", allow_pickle=True).item()
                vae_mean = torch.from_numpy(stats['mean'])
                vae_std = torch.from_numpy(stats['std'])
                print("Loaded cached VAE stats from local")
            except Exception as e:
                print(f"Failed to load local VAE stats: {e}")
        # ---- Build the VAE wrapper, the flow model, and the sampler ----
        self.vae_mgr = VAEManager(VAE_ID, VAE_SUBFOLDER, self.device, vae_mean, vae_std)
        print(f"Loading model from {model_path}...")
        self.model = FluxLatentDINOFlow(
            hidden_size=DIT_HIDDEN_SIZE,
            depth=DIT_DEPTH,
            num_heads=DIT_NUM_HEADS,
            patch_size=PATCH_SIZE,
            latent_channels=LATENT_CHANNELS,
            img_size=IMG_SIZE,
            dino_model_name=DINO_MODEL_NAME
        ).to(self.device)
        state_dict = load_file(model_path)
        # strict=False tolerates missing/unexpected keys (e.g. the frozen DINO).
        self.model.load_state_dict(state_dict, strict=False)
        print("Model loaded successfully")
        self.model.eval()
        self.sampler = FlowMatchingSampler(self.model, self.vae_mgr, self.device)

    def preprocess(self, image: Image.Image, target_size: int) -> torch.Tensor:
        """Resize to a model-friendly square size and scale pixels to [-1, 1]."""
        # The size must be a multiple of VAE_SCALE_FACTOR * PATCH_SIZE.
        min_unit = VAE_SCALE_FACTOR * PATCH_SIZE
        target_size = (target_size // min_unit) * min_unit
        if target_size < min_unit:
            target_size = min_unit
        image = image.convert("RGB")
        image = image.resize((target_size, target_size), Image.Resampling.BICUBIC)
        arr = np.array(image).astype(np.float32) / 127.5 - 1.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
        return tensor.to(self.device)

    def postprocess(self, tensor: torch.Tensor) -> Image.Image:
        """Convert a (1, 3, H, W) tensor in [-1, 1] back to a PIL image."""
        tensor = tensor.squeeze(0).cpu()
        tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
        arr = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return Image.fromarray(arr)

    def restore(
        self,
        image: Image.Image,
        resolution: int = 384,
        steps: int = 25,
        sampler_type: str = "euler_ode",
        sigma: float = 0.1,
        seed: int = -1,
        progress_callback=None
    ) -> Image.Image:
        """Run the full restoration pipeline on one image.

        `progress_callback`, if given, is called as (current, total, desc);
        the total includes two extra ticks for pre/post-processing.
        """
        # Seed for reproducibility; a negative seed means "random".
        if seed >= 0:
            torch.manual_seed(seed)
            np.random.seed(seed)
        if progress_callback:
            progress_callback(0, steps + 2, "预处理图像...")
        lq = self.preprocess(image, resolution)
        # Re-cache RoPE tables and grid sizes for this resolution.
        self._adjust_model_for_resolution(resolution)
        if progress_callback:
            progress_callback(1, steps + 2, "开始采样...")

        def wrapped_progress(current, total, desc):
            # Offset sampler progress by the preprocessing tick.
            if progress_callback:
                progress_callback(current + 1, steps + 2, desc)

        restored_tensor = self.sampler.sample(
            lq, steps,
            sampler_type=sampler_type,
            sigma=sigma,
            progress_callback=wrapped_progress
        )
        if progress_callback:
            progress_callback(steps + 2, steps + 2, "后处理...")
        return self.postprocess(restored_tensor)

    def _adjust_model_for_resolution(self, resolution: int):
        """Resize the model's token grid and RoPE cache for a new resolution."""
        min_unit = VAE_SCALE_FACTOR * PATCH_SIZE
        resolution = (resolution // min_unit) * min_unit
        new_latent_size = resolution // VAE_SCALE_FACTOR
        new_grid_size = new_latent_size // PATCH_SIZE
        if new_grid_size != self.model.grid_size:
            print(f"Adjusting model for resolution {resolution} (grid: {new_grid_size})")
            self.model.latent_size = new_latent_size
            self.model.grid_size = new_grid_size
            self.model.num_patches = new_grid_size ** 2
            # Rebuild the RoPE cos/sin cache for the new grid.
            self.model.rope._set_cos_sin_cache(new_grid_size, new_grid_size)
def create_demo(restorer: ImageRestorer):
    """Build the Gradio UI around an initialized ImageRestorer."""

    def process_image(
        image,
        resolution,
        steps,
        sampler,
        sigma,
        seed,
        progress=gr.Progress(track_tqdm=True)  # Gradio-injected progress handle
    ):
        """Button handler: returns the restored PIL image, or None on error/no input."""
        if image is None:
            return None
        # Map the UI labels to internal sampler identifiers.
        sampler_map = {
            "Euler ODE (确定性)": "euler_ode",
            "Heun ODE (二阶,更准确)": "heun_ode",
            "SDE Euler-Maruyama (随机性)": "sde_euler_maruyama",
            "SDE Reverse Diffusion (逆向扩散)": "sde_reverse",
        }
        sampler_type = sampler_map.get(sampler, "euler_ode")

        # Bridge the restorer's (current, total, desc) callback into Gradio's
        # fractional progress API.
        def progress_callback(current, total, desc):
            progress(current / total, desc=desc)

        try:
            result = restorer.restore(
                image,
                resolution=int(resolution),
                steps=int(steps),
                sampler_type=sampler_type,
                sigma=float(sigma),
                seed=int(seed),
                progress_callback=progress_callback  # forward progress updates
            )
            return result
        except Exception as e:
            # Errors are logged to the console; the UI simply shows no output.
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
            return None

    with gr.Blocks(title="Image Restoration Demo", css="""
    .progress-bar {
        height: 20px !important;
    }
    """) as demo:
        gr.Markdown("""
        # 🖼️ FLUX VAE 图像修复 Demo
        使用 Flow Matching + DiT + DINO 的图像修复模型。上传一张图像,选择参数后点击"修复"按钮。
        > 💡 提示:进度条会显示当前处理步骤,如果有多人同时使用会显示排队状态。
        """)
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="输入图像")
                with gr.Group():
                    resolution = gr.Slider(
                        minimum=128,
                        maximum=1024,
                        value=384,
                        step=16,
                        label="分辨率",
                        info="图像会被 resize 到此分辨率(会自动对齐到 16 的倍数)"
                    )
                    steps = gr.Slider(
                        minimum=5,
                        maximum=100,
                        value=25,
                        step=1,
                        label="推理步数",
                        info="更多步数 = 更好质量,但更慢"
                    )
                    sampler = gr.Dropdown(
                        choices=[
                            "Euler ODE (确定性)",
                            "Heun ODE (二阶,更准确)",
                            "SDE Euler-Maruyama (随机性)",
                            "SDE Reverse Diffusion (逆向扩散)",
                        ],
                        value="Euler ODE (确定性)",
                        label="采样器",
                        info="ODE 是确定性的,SDE 会添加随机噪声"
                    )
                    sigma = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.1,
                        step=0.01,
                        label="SDE 噪声强度 (sigma)",
                        info="仅对 SDE 采样器有效,越大随机性越强"
                    )
                    seed = gr.Number(
                        value=-1,
                        label="随机种子",
                        info="-1 表示随机种子"
                    )
                submit_btn = gr.Button("🚀 开始修复", variant="primary")
                # Queue-status hint for concurrent users.
                gr.Markdown("""
                <small>⏳ 如果按钮显示"排队中...",说明有其他用户正在使用,请稍候。</small>
                """)
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="修复结果")
        # Wire the button to the handler.
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, resolution, steps, sampler, sigma, seed],
            outputs=output_image,
            show_progress="full"  # show the full progress bar
        )
        gr.Markdown("""
        ### 📝 说明
        **采样器类型**:
        - **Euler ODE**: 标准 Flow Matching 采样,确定性,速度快
        - **Heun ODE**: 二阶 Runge-Kutta 方法,更准确但需要双倍计算
        - **SDE Euler-Maruyama**: 添加随机噪声的 SDE 采样,可以增加多样性
        - **SDE Reverse Diffusion**: 使用衰减噪声的逆向扩散 SDE
        **参数建议**:
        - 一般情况:Euler ODE, 25 步
        - 更高质量:Heun ODE, 30-50 步
        - 需要多样性/创意修复:SDE 采样器, sigma=0.1-0.2
        """)
    return demo
| # ----------------------------------------------------------------------------- | |
| # 主函数 | |
| # ----------------------------------------------------------------------------- | |
def main():
    """Build the restorer and serve the Gradio demo behind a request queue."""
    restorer = ImageRestorer(model_path=MODEL_PATH, repo_id=MODEL_REPO_ID)
    demo = create_demo(restorer)
    demo.queue(max_size=10, default_concurrency_limit=2).launch()


if __name__ == "__main__":
    main()