# ImageRestore / app.py
# telecomadm1145's picture
# Update app.py
# fd59366 verified
"""
Gradio Demo for FLUX VAE Image Restoration Model
支持自定义分辨率、采样器(Euler ODE / SDE Euler-Maruyama)和推理步数
"""
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
from PIL import Image
import timm
from diffusers import AutoencoderKL
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download
# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------
MODEL_PATH = "model.safetensors" # local path to the model weights
MODEL_REPO_ID = "telecomadm1145/img_restore" # Hugging Face repo ID hosting the weights
MODEL_FILENAME = "model.safetensors" # filename inside the HF repo
VAE_ID = "advokat/AnimePro-FLUX" #"black-forest-labs/FLUX.1-schnell"
VAE_SUBFOLDER = "vae"
HF_TOKEN = os.environ.get("HF_TOKEN") or None
# Model hyperparameters (must match the training configuration)
LATENT_CHANNELS = 16
VAE_SCALE_FACTOR = 8
DIT_HIDDEN_SIZE = 1024
DIT_DEPTH = 16
DIT_NUM_HEADS = 4
PATCH_SIZE = 2
DINO_MODEL_NAME = 'vit_base_patch16_dinov3.lvd1689m'
IMG_SIZE = 384 # default training resolution
# VAE latent statistics (loaded from a cached file when available, otherwise identity)
DEFAULT_VAE_MEAN = torch.zeros(LATENT_CHANNELS)
DEFAULT_VAE_STD = torch.ones(LATENT_CHANNELS)
# -----------------------------------------------------------------------------
# Model definitions (identical to the training code)
# -----------------------------------------------------------------------------
def modulate(x, shift, scale):
    """Apply adaLN modulation to token features: x * (1 + scale) + shift."""
    scale_b = scale.unsqueeze(1)
    shift_b = shift.unsqueeze(1)
    return x * (scale_b + 1) + shift_b
class TimestepEmbedder(nn.Module):
    """Embed scalar flow-matching times into vectors: sinusoidal features + MLP."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Sinusoidal embedding of t (internally scaled by 1000), shape (B, dim)."""
        scaled = t * 1000.0
        half = dim // 2
        exponent = torch.arange(0, half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(max_period) * exponent)
        angles = scaled[:, None].float() * freqs[None]
        emb = torch.cat([angles.cos(), angles.sin()], dim=-1)
        # Odd dims get a zero pad so the output width is exactly `dim`.
        if dim % 2 == 1:
            emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
        return emb

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))
class RotaryEmbedding2D(nn.Module):
    """Axial 2D rotary embedding: half the head dim encodes rows, half columns."""

    def __init__(self, dim, max_h=64, max_w=64):
        super().__init__()
        self.dim = dim
        half_h = dim // 2
        half_w = dim - half_h
        inv_freq_h = 1.0 / (10000 ** (torch.arange(0, half_h, 2).float() / half_h))
        inv_freq_w = 1.0 / (10000 ** (torch.arange(0, half_w, 2).float() / half_w))
        self.register_buffer("inv_freq_h", inv_freq_h)
        self.register_buffer("inv_freq_w", inv_freq_w)
        self._set_cos_sin_cache(max_h, max_w)

    def _set_cos_sin_cache(self, h, w):
        """Precompute cos/sin tables for an h x w token grid (row-major flatten)."""
        pos_h = torch.arange(h).type_as(self.inv_freq_h)
        pos_w = torch.arange(w).type_as(self.inv_freq_w)
        ang_h = torch.outer(pos_h, self.inv_freq_h)
        ang_w = torch.outer(pos_w, self.inv_freq_w)
        ang_h = torch.cat((ang_h, ang_h), dim=-1)
        ang_w = torch.cat((ang_w, ang_w), dim=-1)
        # Broadcast row angles across columns and vice versa, then flatten the grid.
        grid_h = ang_h.unsqueeze(1).repeat(1, w, 1)
        grid_w = ang_w.unsqueeze(0).repeat(h, 1, 1)
        full = torch.cat((grid_h, grid_w), dim=-1).flatten(0, 1)
        self.register_buffer("cos_cached", full.cos()[None, None], persistent=False)
        self.register_buffer("sin_cached", full.sin()[None, None], persistent=False)

    def forward(self, x, h, w):
        """Return (cos, sin) tables for the first h*w positions, cast to x's dtype."""
        n = h * w
        cos = self.cos_cached[:, :, :n, :].to(x.dtype)
        sin = self.sin_cached[:, :, :n, :].to(x.dtype)
        return cos, sin
def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate query/key vectors by the cached cos/sin position tables."""
    def _rot(t):
        a, b = t.chunk(2, dim=-1)
        return torch.cat((-b, a), dim=-1)
    q_out = q * cos + _rot(q) * sin
    k_out = k * cos + _rot(k) * sin
    return q_out, k_out
class SwiGLU(nn.Module):
    """SwiGLU feed-forward: silu(W1 x) * (W2 x), projected back by W3."""

    def __init__(self, hidden_size: int, mlp_ratio: float = 4.0):
        super().__init__()
        # 2/3 scaling keeps parameter count comparable to a plain 4x MLP;
        # round the inner width up to a multiple of 64.
        inner = int(hidden_size * mlp_ratio * 2 / 3)
        inner = -(-inner // 64) * 64
        self.w1 = nn.Linear(hidden_size, inner, bias=False)
        self.w2 = nn.Linear(hidden_size, inner, bias=False)
        self.w3 = nn.Linear(inner, hidden_size, bias=False)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        return self.w3(gate * self.w2(x))
class DiTBlock(nn.Module):
    """DiT transformer block with joint attention over latent + condition tokens.

    Latent tokens and condition tokens share a single attention pass but
    receive separate adaLN (shift/scale/gate) parameters derived from the
    timestep embedding.
    """

    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        # Non-affine pre-norms, one per stream (affine comes from adaLN).
        self.norm1_latent = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm1_cond = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.norm2_latent = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm2_cond = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # Per-head QK normalization for attention stability.
        self.q_norm = nn.LayerNorm(self.head_dim, eps=1e-6)
        self.k_norm = nn.LayerNorm(self.head_dim, eps=1e-6)
        self.mlp = SwiGLU(hidden_size, mlp_ratio)
        # Each adaLN head emits 6*hidden: shift/scale/gate for attn and for MLP.
        self.adaLN_latent = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=False)
        )
        self.adaLN_cond = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=False)
        )

    def forward(self, x, t_emb, rope_cos, rope_sin, num_latents):
        """Joint pass over [latent | condition] tokens.

        Args:
            x: (B, L, D) concatenated tokens; the first num_latents are latents.
            t_emb: (B, D) timestep embedding driving both adaLN heads.
            rope_cos, rope_sin: RoPE tables covering all L positions.
            num_latents: number of latent tokens at the front of x.

        Returns:
            (B, L, D) updated tokens in the same [latent | condition] order.
        """
        B, L, D = x.shape
        # Per-stream adaLN parameters.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_latent(t_emb).chunk(6, dim=-1)
        shift_msa_c, scale_msa_c, gate_msa_c, shift_mlp_c, scale_mlp_c, gate_mlp_c = \
            self.adaLN_cond(t_emb).chunk(6, dim=-1)
        x_lat, x_cond = x[:, :num_latents], x[:, num_latents:]
        # Modulate each stream with its own parameters, then attend jointly.
        x_lat_norm = modulate(self.norm1_latent(x_lat), shift_msa, scale_msa)
        x_cond_norm = modulate(self.norm1_cond(x_cond), shift_msa_c, scale_msa_c)
        x_norm = torch.cat([x_lat_norm, x_cond_norm], dim=1)
        qkv = self.qkv(x_norm)
        q, k, v = qkv.reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        q, k = apply_rotary_pos_emb(q, k, rope_cos, rope_sin)
        q, k = q.to(v.dtype), k.to(v.dtype)
        x_attn = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0)
        x_attn = x_attn.transpose(1, 2).reshape(B, L, D)
        x_attn = self.proj(x_attn)
        # Gated residual, per stream.
        x_attn_lat, x_attn_cond = x_attn[:, :num_latents], x_attn[:, num_latents:]
        x_lat = x_lat + gate_msa.unsqueeze(1) * x_attn_lat
        x_cond = x_cond + gate_msa_c.unsqueeze(1) * x_attn_cond
        # MLP path, again with per-stream modulation and gates.
        x_lat_norm = modulate(self.norm2_latent(x_lat), shift_mlp, scale_mlp)
        x_cond_norm = modulate(self.norm2_cond(x_cond), shift_mlp_c, scale_mlp_c)
        x_norm = torch.cat([x_lat_norm, x_cond_norm], dim=1)
        mlp_out = self.mlp(x_norm)
        mlp_lat, mlp_cond = mlp_out[:, :num_latents], mlp_out[:, num_latents:]
        x_lat = x_lat + gate_mlp.unsqueeze(1) * mlp_lat
        x_cond = x_cond + gate_mlp_c.unsqueeze(1) * mlp_cond
        return torch.cat([x_lat, x_cond], dim=1)
class FinalLayer(nn.Module):
    """Project DiT tokens to per-patch latent values with adaLN conditioning."""

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x, c):
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        # adaLN modulation inlined: x * (1 + scale) + shift
        h = self.norm_final(x)
        h = h * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        return self.linear(h)
class FluxLatentDINOFlow(nn.Module):
    """Flow-matching DiT over FLUX VAE latents, conditioned on the LQ image.

    Conditioning tokens come from two sources:
      * a small strided-conv "pixel adapter" over the raw LQ image, and
      * features from a frozen DINO backbone, resampled to the latent grid.
    Latent, pixel and DINO tokens are concatenated and processed jointly by
    the DiT blocks; only the latent tokens are projected to the output.
    """

    def __init__(
        self,
        img_size=256,
        patch_size=2,
        latent_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        dino_model_name='vit_base_patch14_dinov2.lvd142m',
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.latent_channels = latent_channels
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # The VAE downsamples pixels by 8; patchify further reduces the grid.
        self.latent_size = img_size // 8
        self.grid_size = self.latent_size // patch_size
        self.num_patches = self.grid_size ** 2
        print(f"Loading DINO: {dino_model_name}")
        # DINO is a frozen feature extractor.
        self.dino = timm.create_model(dino_model_name, pretrained=True, img_size=img_size, num_classes=0)
        for p in self.dino.parameters():
            p.requires_grad = False
        self.dino.eval()
        self.dino_adapter = nn.Sequential(
            nn.Conv2d(self.dino.embed_dim, hidden_size, kernel_size=1),
            nn.SiLU(),
            nn.Conv2d(hidden_size, hidden_size, kernel_size=3, padding=1),
        )
        # Three stride-2 convs (x8) + a patch-sized conv align pixels to the token grid.
        self.pixel_adapter = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.SiLU(),
            nn.Conv2d(256, hidden_size, kernel_size=patch_size, stride=patch_size),
        )
        self.x_embedder = nn.Linear(patch_size * patch_size * latent_channels, hidden_size)
        self.t_embedder = TimestepEmbedder(hidden_size)
        self.rope = RotaryEmbedding2D(dim=hidden_size // num_heads, max_h=self.grid_size, max_w=self.grid_size)
        self.blocks = nn.ModuleList([DiTBlock(hidden_size, num_heads) for _ in range(depth)])
        self.final_layer = FinalLayer(hidden_size, patch_size, latent_channels)
        # Learned embeddings distinguishing the three token streams.
        self.type_emb_target = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.type_emb_pixel = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.type_emb_dino = nn.Parameter(torch.randn(1, 1, hidden_size) * 0.02)
        self.initialize_weights()
        # Cache of DINO condition tokens: lq_img is constant across a sampling
        # loop, so the (expensive) DINO pass can be reused between steps.
        self._cached_dino_map = None
        self._cached_lq_key = None

    def initialize_weights(self):
        """Xavier-init all trainable linear/conv layers (DINO left untouched);
        zero the final projection so the model starts by predicting zero velocity."""
        for name, m in self.named_modules():
            if "dino" in name:
                continue
            if isinstance(m, (nn.Linear, nn.Conv2d)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        nn.init.zeros_(self.final_layer.linear.weight)
        nn.init.zeros_(self.final_layer.linear.bias)

    def patchify(self, x):
        """(B, C, H, W) latents -> (B, N, p*p*C) patch tokens."""
        p = self.patch_size
        h, w = x.shape[2] // p, x.shape[3] // p
        x = x.reshape(x.shape[0], x.shape[1], h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        x = x.reshape(x.shape[0], h * w, -1)
        return x

    def unpatchify(self, x):
        """(B, N, p*p*C) patch tokens -> (B, C, H, W) latents (square grid assumed)."""
        p = self.patch_size
        c = self.latent_channels
        h = w = int(x.shape[1] ** 0.5)
        x = x.reshape(x.shape[0], h, w, p, p, c)
        x = torch.einsum('nhwpqc->nchpwq', x)
        return x.reshape(x.shape[0], c, h * p, w * p)

    def forward(self, x_t_latent, t, lq_img):
        """Predict the flow velocity for noisy latents x_t_latent at time t,
        conditioned on the low-quality image lq_img (expected in [-1, 1])."""
        B = x_t_latent.shape[0]
        x_patches = self.patchify(x_t_latent)
        x_tokens = self.x_embedder(x_patches)
        x_tokens = x_tokens + self.type_emb_target
        pixel_tokens = self.pixel_adapter(lq_img)
        pixel_tokens = pixel_tokens.flatten(2).transpose(1, 2)
        pixel_tokens = pixel_tokens + self.type_emb_pixel
        # Cache key for the DINO features. A bare data_ptr() is unsafe: the
        # allocator can hand the same address to a brand-new tensor, which would
        # silently serve stale features. Including the shape and a cheap content
        # checksum makes collisions practically impossible.
        lq_key = (lq_img.data_ptr(), tuple(lq_img.shape), float(lq_img.sum()))
        if self._cached_dino_map is None or self._cached_lq_key != lq_key:
            print("recomputing DINO features...")
            # Only recompute DINO when the cache is empty or the input changed.
            with torch.no_grad():
                # DINO expects ImageNet-normalized inputs; lq_img is in [-1, 1].
                mean = torch.tensor([0.485, 0.456, 0.406], device=lq_img.device).view(1, 3, 1, 1)
                std = torch.tensor([0.229, 0.224, 0.225], device=lq_img.device).view(1, 3, 1, 1)
                dino_in = (lq_img * 0.5 + 0.5 - mean) / std
                dino_feats = self.dino.forward_features(dino_in)
                # Drop CLS/register tokens if the backbone prepends any.
                if getattr(self.dino, "num_prefix_tokens", 0) > 0:
                    dino_feats = dino_feats[:, self.dino.num_prefix_tokens:]
                d_h = d_w = int(dino_feats.shape[1] ** 0.5)
                dino_map = dino_feats.transpose(1, 2).reshape(B, -1, d_h, d_w)
                dino_map_resized = F.interpolate(dino_map, size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=False)
            dino_tokens = self.dino_adapter(dino_map_resized)
            dino_tokens = dino_tokens.flatten(2).transpose(1, 2)
            dino_tokens = dino_tokens + self.type_emb_dino
            # Refresh the cache.
            self._cached_dino_map = dino_tokens
            self._cached_lq_key = lq_key
        else:
            dino_tokens = self._cached_dino_map
        tokens = torch.cat([x_tokens, pixel_tokens, dino_tokens], dim=1)
        t_emb = self.t_embedder(t)
        # The same 2D RoPE table applies to each of the three aligned token grids.
        cos_base, sin_base = self.rope(tokens, self.grid_size, self.grid_size)
        cos = torch.cat([cos_base] * 3, dim=2)
        sin = torch.cat([sin_base] * 3, dim=2)
        for block in self.blocks:
            tokens = block(tokens, t_emb, cos, sin, num_latents=self.num_patches)
        out_tokens = tokens[:, :self.num_patches]
        out_patches = self.final_layer(out_tokens, t_emb)
        out_latents = self.unpatchify(out_patches)
        return out_latents
# -----------------------------------------------------------------------------
# VAE manager
# -----------------------------------------------------------------------------
class VAEManager:
    """Wrap the pretrained FLUX VAE and apply channel-wise latent normalization."""

    def __init__(self, model_id, subfolder, device, mean=None, std=None):
        print(f"Loading Flux VAE from {model_id}...")
        self.vae = AutoencoderKL.from_pretrained(model_id, subfolder=subfolder, token=HF_TOKEN)
        self.device = device
        self.vae.to(self.device).eval()
        self.vae.requires_grad_(False)
        stats_mean = DEFAULT_VAE_MEAN if mean is None else mean
        stats_std = DEFAULT_VAE_STD if std is None else std
        self.register_stats(stats_mean, stats_std)

    def register_stats(self, mean, std):
        """Store normalization stats as (1, C, 1, 1) broadcastable tensors."""
        self.shift = mean.to(self.device).view(1, -1, 1, 1)
        inv_std = 1.0 / (std.to(self.device) + 1e-6)
        self.scale = inv_std.view(1, -1, 1, 1)
        print(f"VAE Stats Registered")

    @torch.no_grad()
    def encode(self, pixels):
        """Pixels -> normalized latents (mode of the posterior)."""
        raw = self.vae.encode(pixels).latent_dist.mode()
        return (raw - self.shift) * self.scale

    @torch.no_grad()
    def decode(self, latents):
        """Normalized latents -> pixels."""
        raw = latents / self.scale + self.shift
        return self.vae.decode(raw).sample
# -----------------------------------------------------------------------------
# Samplers (with progress-callback support)
# -----------------------------------------------------------------------------
class FlowMatchingSampler:
    """Flow Matching sampler with deterministic (ODE) and stochastic (SDE) integrators.

    Every sampler integrates the learned velocity field from t=0 to t=1
    starting from Gaussian noise in latent space, then decodes with the VAE.
    The duplicated init/autocast/decode code of the four samplers is factored
    into private helpers.
    """

    def __init__(self, model, vae_mgr, device):
        self.model = model        # velocity model: v = model(x_t, t, lq)
        self.vae_mgr = vae_mgr    # VAEManager used to decode the final latents
        self.device = device

    def _init_latents(self, lq):
        """Sample the t=0 Gaussian latent matching lq's spatial dims.

        Returns (x, B): latent of shape (B, C, H/8, W/8) and the batch size.
        """
        B = lq.shape[0]
        H_lat = lq.shape[2] // VAE_SCALE_FACTOR
        W_lat = lq.shape[3] // VAE_SCALE_FACTOR
        x = torch.randn(B, LATENT_CHANNELS, H_lat, W_lat, device=self.device)
        return x, B

    def _velocity(self, x, t, lq):
        """Evaluate the model under fp16 autocast.

        NOTE: torch.cuda.amp.autocast is deprecated in newer torch in favor of
        torch.amp.autocast("cuda", ...); kept as-is for compatibility. Routing
        every model call through this helper also fixes the original Heun
        sampler, whose second slope evaluation ran *outside* autocast and thus
        in a different precision than the first.
        """
        with torch.cuda.amp.autocast(dtype=torch.float16):
            return self.model(x, t, lq)

    def _decode(self, x):
        """Decode latents to pixels and clamp to the model's [-1, 1] range."""
        restored = self.vae_mgr.decode(x)
        return torch.clamp(restored, -1, 1)

    @torch.no_grad()
    def sample_euler_ode(self, lq, steps, progress_callback=None):
        """Euler ODE sampler (deterministic)."""
        x, B = self._init_latents(lq)
        dt = 1.0 / steps
        for i in range(steps):
            t = torch.full((B,), i / steps, device=self.device, dtype=torch.float32)
            v = self._velocity(x, t, lq)
            x = x + v * dt
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"Euler ODE 采样中... {i+1}/{steps}")
        return self._decode(x)

    @torch.no_grad()
    def sample_sde_euler_maruyama(self, lq, steps, sigma=0.1, progress_callback=None):
        """SDE Euler-Maruyama sampler (stochastic, constant noise scale sigma)."""
        x, B = self._init_latents(lq)
        dt = 1.0 / steps
        sqrt_dt = math.sqrt(dt)
        for i in range(steps):
            t = torch.full((B,), i / steps, device=self.device, dtype=torch.float32)
            v = self._velocity(x, t, lq)
            noise = torch.randn_like(x)
            x = x + v * dt + sigma * sqrt_dt * noise
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"SDE Euler-Maruyama 采样中... {i+1}/{steps}")
        return self._decode(x)

    @torch.no_grad()
    def sample_sde_reverse_diffusion(self, lq, steps, sigma_schedule="linear", progress_callback=None):
        """Reverse-diffusion style SDE sampler with a decaying noise schedule."""
        x, B = self._init_latents(lq)
        dt = 1.0 / steps
        for i in range(steps):
            t_val = i / steps
            t = torch.full((B,), t_val, device=self.device, dtype=torch.float32)
            v = self._velocity(x, t, lq)
            # Noise magnitude decays to zero as t -> 1.
            if sigma_schedule == "linear":
                sigma = 0.5 * (1 - t_val)
            elif sigma_schedule == "cosine":
                sigma = 0.5 * math.cos(t_val * math.pi / 2)
            else:
                sigma = 0.1
            # The last step is deterministic so the output is noise-free.
            noise = torch.randn_like(x) if i < steps - 1 else 0
            x = x + v * dt + sigma * math.sqrt(dt) * noise
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"SDE Reverse Diffusion 采样中... {i+1}/{steps}")
        return self._decode(x)

    @torch.no_grad()
    def sample_heun_ode(self, lq, steps, progress_callback=None):
        """Heun's method (2nd-order Runge-Kutta) ODE sampler."""
        x, B = self._init_latents(lq)
        dt = 1.0 / steps
        for i in range(steps):
            t = torch.full((B,), i / steps, device=self.device, dtype=torch.float32)
            t_next = torch.full((B,), (i + 1) / steps, device=self.device, dtype=torch.float32)
            v1 = self._velocity(x, t, lq)
            x_pred = x + v1 * dt
            if i < steps - 1:
                # Trapezoidal correction: average the slopes at t and t_next.
                v2 = self._velocity(x_pred, t_next, lq)
                x = x + 0.5 * (v1 + v2) * dt
            else:
                x = x_pred
            if progress_callback is not None:
                progress_callback(i + 1, steps, f"Heun ODE 采样中... {i+1}/{steps}")
        return self._decode(x)

    def sample(self, lq, steps, sampler_type="euler_ode", progress_callback=None, **kwargs):
        """Unified sampling entry point; dispatches on sampler_type."""
        if sampler_type == "euler_ode":
            return self.sample_euler_ode(lq, steps, progress_callback=progress_callback)
        elif sampler_type == "sde_euler_maruyama":
            sigma = kwargs.get("sigma", 0.1)
            return self.sample_sde_euler_maruyama(lq, steps, sigma=sigma, progress_callback=progress_callback)
        elif sampler_type == "sde_reverse":
            sigma_schedule = kwargs.get("sigma_schedule", "linear")
            return self.sample_sde_reverse_diffusion(lq, steps, sigma_schedule=sigma_schedule, progress_callback=progress_callback)
        elif sampler_type == "heun_ode":
            return self.sample_heun_ode(lq, steps, progress_callback=progress_callback)
        else:
            raise ValueError(f"Unknown sampler type: {sampler_type}")
# -----------------------------------------------------------------------------
# Model loading
# -----------------------------------------------------------------------------
class ImageRestorer:
    """End-to-end wrapper: fetches weights, loads VAE + DiT, runs restoration."""

    def __init__(self, model_path=None, device="cuda",
                 repo_id=None, filename="model.safetensors"):
        """
        Args:
            model_path: local path to the model weights; downloaded from HF when None or missing
            device: execution device ("cuda" falls back to CPU if unavailable)
            repo_id: Hugging Face repo ID, e.g. "username/model-name"
            filename: model filename inside the HF repo
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")
        # ========== Download the model from Hugging Face if needed ==========
        if model_path is None or not os.path.exists(model_path):
            if repo_id is not None:
                print(f"Downloading model from Hugging Face: {repo_id}/(unknown)")
                try:
                    model_path = hf_hub_download(
                        repo_id=repo_id,
                        filename=filename,
                        token=HF_TOKEN,
                        cache_dir="./hf_cache"  # optional: pin the cache directory
                    )
                    print(f"Model downloaded to: {model_path}")
                except Exception as e:
                    raise RuntimeError(f"Failed to download model from HF: {e}")
            else:
                raise FileNotFoundError(
                    f"Model not found at {model_path} and no repo_id provided"
                )
        # ========== Also fetch the VAE latent statistics, if published ==========
        vae_mean = DEFAULT_VAE_MEAN
        vae_std = DEFAULT_VAE_STD
        # First try to download vae_stats.npy from the HF repo
        if repo_id is not None:
            try:
                vae_stats_path = hf_hub_download(
                    repo_id=repo_id,
                    filename="vae_stats.npy",
                    token=HF_TOKEN,
                    cache_dir="./hf_cache"
                )
                stats = np.load(vae_stats_path, allow_pickle=True).item()
                vae_mean = torch.from_numpy(stats['mean'])
                vae_std = torch.from_numpy(stats['std'])
                print("Loaded VAE stats from HF repo")
            except Exception:
                print("No vae_stats.npy in HF repo, checking local...")
                # Fall back to a local vae_stats.npy
                if os.path.exists("vae_stats.npy"):
                    try:
                        stats = np.load("vae_stats.npy", allow_pickle=True).item()
                        vae_mean = torch.from_numpy(stats['mean'])
                        vae_std = torch.from_numpy(stats['std'])
                        print("Loaded cached VAE stats from local")
                    except Exception as e:
                        print(f"Failed to load local VAE stats: {e}")
        # Load the VAE
        self.vae_mgr = VAEManager(VAE_ID, VAE_SUBFOLDER, self.device, vae_mean, vae_std)
        # Load the DiT model
        print(f"Loading model from {model_path}...")
        self.model = FluxLatentDINOFlow(
            hidden_size=DIT_HIDDEN_SIZE,
            depth=DIT_DEPTH,
            num_heads=DIT_NUM_HEADS,
            patch_size=PATCH_SIZE,
            latent_channels=LATENT_CHANNELS,
            img_size=IMG_SIZE,
            dino_model_name=DINO_MODEL_NAME
        ).to(self.device)
        state_dict = load_file(model_path)
        # strict=False tolerates missing/unexpected keys (e.g. the frozen DINO weights)
        self.model.load_state_dict(state_dict, strict=False)
        print("Model loaded successfully")
        self.model.eval()
        # Build the sampler
        self.sampler = FlowMatchingSampler(self.model, self.vae_mgr, self.device)

    def preprocess(self, image: Image.Image, target_size: int) -> torch.Tensor:
        """Resize/convert a PIL image to a (1, 3, H, W) tensor in [-1, 1]."""
        # Ensure the size is a multiple of VAE_SCALE_FACTOR * PATCH_SIZE
        min_unit = VAE_SCALE_FACTOR * PATCH_SIZE
        target_size = (target_size // min_unit) * min_unit
        if target_size < min_unit:
            target_size = min_unit
        # Resize
        image = image.convert("RGB")
        image = image.resize((target_size, target_size), Image.Resampling.BICUBIC)
        # To tensor
        arr = np.array(image).astype(np.float32) / 127.5 - 1.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)
        return tensor.to(self.device)

    def postprocess(self, tensor: torch.Tensor) -> Image.Image:
        """Convert a (1, 3, H, W) tensor in [-1, 1] back to a PIL image."""
        tensor = tensor.squeeze(0).cpu()
        tensor = (tensor * 0.5 + 0.5).clamp(0, 1)
        arr = (tensor.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
        return Image.fromarray(arr)

    @torch.no_grad()
    def restore(
        self,
        image: Image.Image,
        resolution: int = 384,
        steps: int = 25,
        sampler_type: str = "euler_ode",
        sigma: float = 0.1,
        seed: int = -1,
        progress_callback=None  # optional (current, total, desc) progress hook
    ) -> Image.Image:
        """Run the full restoration pipeline on a single image."""
        # Seed RNGs for reproducibility (seed < 0 means random)
        if seed >= 0:
            torch.manual_seed(seed)
            np.random.seed(seed)
        # Preprocess
        if progress_callback:
            progress_callback(0, steps + 2, "预处理图像...")
        lq = self.preprocess(image, resolution)
        # Adapt the model's token grid / RoPE cache to the requested resolution
        self._adjust_model_for_resolution(resolution)
        if progress_callback:
            progress_callback(1, steps + 2, "开始采样...")
        # Wrap the callback to offset sampling progress by the two extra pipeline steps
        def wrapped_progress(current, total, desc):
            if progress_callback:
                progress_callback(current + 1, steps + 2, desc)
        # Sample
        restored_tensor = self.sampler.sample(
            lq, steps,
            sampler_type=sampler_type,
            sigma=sigma,
            progress_callback=wrapped_progress
        )
        if progress_callback:
            progress_callback(steps + 2, steps + 2, "后处理...")
        # Postprocess
        return self.postprocess(restored_tensor)

    def _adjust_model_for_resolution(self, resolution: int):
        """Resize the model's token grid and RoPE cache for a new input resolution."""
        min_unit = VAE_SCALE_FACTOR * PATCH_SIZE
        resolution = (resolution // min_unit) * min_unit
        new_latent_size = resolution // VAE_SCALE_FACTOR
        new_grid_size = new_latent_size // PATCH_SIZE
        if new_grid_size != self.model.grid_size:
            print(f"Adjusting model for resolution {resolution} (grid: {new_grid_size})")
            self.model.latent_size = new_latent_size
            self.model.grid_size = new_grid_size
            self.model.num_patches = new_grid_size ** 2
            # Rebuild the RoPE cos/sin cache for the new grid
            self.model.rope._set_cos_sin_cache(new_grid_size, new_grid_size)
def create_demo(restorer: ImageRestorer):
    """Build the Gradio UI wired to the given restorer."""
    def process_image(
        image,
        resolution,
        steps,
        sampler,
        sigma,
        seed,
        progress=gr.Progress(track_tqdm=True)  # Gradio-injected progress tracker
    ):
        # No-op when no image was uploaded
        if image is None:
            return None
        # Map UI labels to internal sampler identifiers
        sampler_map = {
            "Euler ODE (确定性)": "euler_ode",
            "Heun ODE (二阶,更准确)": "heun_ode",
            "SDE Euler-Maruyama (随机性)": "sde_euler_maruyama",
            "SDE Reverse Diffusion (逆向扩散)": "sde_reverse",
        }
        sampler_type = sampler_map.get(sampler, "euler_ode")
        # Bridge restorer progress to the Gradio progress bar
        def progress_callback(current, total, desc):
            progress(current / total, desc=desc)
        try:
            result = restorer.restore(
                image,
                resolution=int(resolution),
                steps=int(steps),
                sampler_type=sampler_type,
                sigma=float(sigma),
                seed=int(seed),
                progress_callback=progress_callback  # forward progress updates
            )
            return result
        except Exception as e:
            # Log the failure and return None so the UI shows an empty result
            print(f"Error: {e}")
            import traceback
            traceback.print_exc()
            return None
    with gr.Blocks(title="Image Restoration Demo", css="""
        .progress-bar {
            height: 20px !important;
        }
    """) as demo:
        gr.Markdown("""
        # 🖼️ FLUX VAE 图像修复 Demo
        使用 Flow Matching + DiT + DINO 的图像修复模型。上传一张图像,选择参数后点击"修复"按钮。
        > 💡 提示:进度条会显示当前处理步骤,如果有多人同时使用会显示排队状态。
        """)
        with gr.Row():
            with gr.Column(scale=1):
                input_image = gr.Image(type="pil", label="输入图像")
                with gr.Group():
                    resolution = gr.Slider(
                        minimum=128,
                        maximum=1024,
                        value=384,
                        step=16,
                        label="分辨率",
                        info="图像会被 resize 到此分辨率(会自动对齐到 16 的倍数)"
                    )
                    steps = gr.Slider(
                        minimum=5,
                        maximum=100,
                        value=25,
                        step=1,
                        label="推理步数",
                        info="更多步数 = 更好质量,但更慢"
                    )
                    sampler = gr.Dropdown(
                        choices=[
                            "Euler ODE (确定性)",
                            "Heun ODE (二阶,更准确)",
                            "SDE Euler-Maruyama (随机性)",
                            "SDE Reverse Diffusion (逆向扩散)",
                        ],
                        value="Euler ODE (确定性)",
                        label="采样器",
                        info="ODE 是确定性的,SDE 会添加随机噪声"
                    )
                    sigma = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.1,
                        step=0.01,
                        label="SDE 噪声强度 (sigma)",
                        info="仅对 SDE 采样器有效,越大随机性越强"
                    )
                    seed = gr.Number(
                        value=-1,
                        label="随机种子",
                        info="-1 表示随机种子"
                    )
                submit_btn = gr.Button("🚀 开始修复", variant="primary")
                # Queue-status hint for concurrent users
                gr.Markdown("""
                <small>⏳ 如果按钮显示"排队中...",说明有其他用户正在使用,请稍候。</small>
                """)
            with gr.Column(scale=1):
                output_image = gr.Image(type="pil", label="修复结果")
        # Wire the button click to the processing function
        submit_btn.click(
            fn=process_image,
            inputs=[input_image, resolution, steps, sampler, sigma, seed],
            outputs=output_image,
            show_progress="full"  # show the full progress bar
        )
        gr.Markdown("""
        ### 📝 说明
        **采样器类型**:
        - **Euler ODE**: 标准 Flow Matching 采样,确定性,速度快
        - **Heun ODE**: 二阶 Runge-Kutta 方法,更准确但需要双倍计算
        - **SDE Euler-Maruyama**: 添加随机噪声的 SDE 采样,可以增加多样性
        - **SDE Reverse Diffusion**: 使用衰减噪声的逆向扩散 SDE
        **参数建议**:
        - 一般情况:Euler ODE, 25 步
        - 更高质量:Heun ODE, 30-50 步
        - 需要多样性/创意修复:SDE 采样器, sigma=0.1-0.2
        """)
    return demo
# -----------------------------------------------------------------------------
# Entry point
# -----------------------------------------------------------------------------
def main():
    """Build the restorer and serve the Gradio demo with a bounded request queue."""
    restorer = ImageRestorer(model_path=MODEL_PATH, repo_id=MODEL_REPO_ID)
    demo = create_demo(restorer)
    demo.queue(max_size=10, default_concurrency_limit=2).launch()


if __name__ == "__main__":
    main()