nghyane/train-lama / scripts /export_lama.py
nghyane's picture
download
raw
20.8 kB
"""Export LaMa-manga from safetensors → ONNX with native DFT ops.
Torch 2.9+ exports torch.fft.rfft2/irfft2 as ONNX DFT operators,
producing a compact graph with dynamic spatial dimensions.
Requirements:
    pip install "torch>=2.9" safetensors huggingface_hub onnx onnxscript
    (onnxruntime is additionally needed for --verify-coreml)
Usage:
python scripts/export_lama.py -o models/lama-manga.onnx
"""
import argparse
import os
import logging
from collections import Counter, OrderedDict
logging.basicConfig(level=logging.INFO)
import torch
import torch.nn as nn
# ── LaMa architecture (FFCResNetGenerator) ──────────────────────────────────
class FourierUnit(nn.Module):
    """Spectral 1x1 conv: rfft2 -> conv/BN/ReLU on (real, imag) channels -> irfft2.

    Uses native ``torch.fft`` so the exporter emits ONNX DFT ops. The
    real/imag interleave is done with rank-4 reshapes only (no 5-D tensors).

    Args:
        in_channels: channels of the spatial input ``x``.
        out_channels: channels of the spatial output.
        groups: groups of the 1x1 spectral convolution.
        fft_norm: normalization mode passed to both ``torch.fft.rfft2`` and
            ``torch.fft.irfft2`` ('backward', 'ortho' or 'forward'). Defaults
            to 'backward', which is what this module has always used.
        **kwargs: accepted for signature compatibility and ignored.
    """

    def __init__(self, in_channels, out_channels, groups=1, fft_norm='backward', **kwargs):
        super().__init__()
        self.groups = groups
        # NOTE(review): upstream LaMa's FourierUnit is believed to default to
        # norm='ortho'. Because BN sits between the two transforms, the mode
        # changes the result, so confirm it matches how the weights were trained.
        self.fft_norm = fft_norm
        self.conv_layer = nn.Conv2d(in_channels * 2, out_channels * 2,
                                    kernel_size=1, stride=1, padding=0,
                                    groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels * 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        batch, channels, h, w = x.shape
        spec = torch.fft.rfft2(x, norm=self.fft_norm)  # complex [B, C, H, W_half]
        w_half = w // 2 + 1  # rfft2 keeps the non-negative frequencies of the last axis
        # Interleave real/imag as channels (c0.re, c0.im, c1.re, ...) without 5D tensors:
        # [B, C, 1, H*W_half] + cat(dim=2) -> [B, C, 2, H*W_half] -> [B, 2C, H, W_half]
        real_flat = spec.real.reshape(batch, channels, 1, h * w_half)
        imag_flat = spec.imag.reshape(batch, channels, 1, h * w_half)
        ffted = torch.cat([real_flat, imag_flat], dim=2).reshape(batch, channels * 2, h, w_half)
        ffted = self.relu(self.bn(self.conv_layer(ffted)))
        # De-interleave back: [B, 2C', H, W_half] -> [B, C', 2, H*W_half] -> real / imag planes
        out_c = ffted.shape[1] // 2
        ffted = ffted.reshape(batch, out_c, 2, h * w_half)
        out_r = ffted[:, :, 0, :].reshape(batch, out_c, h, w_half)
        out_i = ffted[:, :, 1, :].reshape(batch, out_c, h, w_half)
        spec_out = torch.complex(out_r, out_i)
        # s=(h, w) restores the exact input size, including odd widths.
        return torch.fft.irfft2(spec_out, s=(h, w), norm=self.fft_norm)
class SpectralTransform(nn.Module):
    """Global (spectral) path of an FFC layer.

    Pipeline: optional 2x average-pool (when ``stride == 2``) -> 1x1
    conv/BN/ReLU to ``out_channels // 2`` -> FourierUnit, optionally plus a
    local FourierUnit over 2x2 tiles of the first quarter of the channels ->
    1x1 conv up to ``out_channels``. Extra ``**kwargs`` are ignored.
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1,
                 enable_lfu=True, **kwargs):
        super().__init__()
        hidden = out_channels // 2
        self.enable_lfu = enable_lfu
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=2, stride=2)
        else:
            self.downsample = nn.Identity()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(hidden, hidden, groups)
        if enable_lfu:
            self.lfu = FourierUnit(hidden, hidden, groups)
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1,
                               groups=groups, bias=False)

    def forward(self, x):
        x = self.conv1(self.downsample(x))
        global_out = self.fu(x)
        local_out = 0
        if self.enable_lfu:
            _, c, h, _ = x.shape
            tile = h // 2
            # Stack the 2x2 spatial tiles of the first c//4 channels along channels.
            rows = torch.cat(torch.split(x[:, :c // 4], tile, dim=-2), dim=1).contiguous()
            tiles = torch.cat(torch.split(rows, tile, dim=-1), dim=1).contiguous()
            local_out = self.lfu(tiles).repeat(1, 1, 2, 2).contiguous()
        return self.conv2(x + global_out + local_out)
class FFC(nn.Module):
    """Fast Fourier Convolution with four local/global cross paths.

    ``int(in_channels * ratio_gin)`` input channels form the global branch and
    the rest the local branch; outputs are split the same way by
    ``ratio_gout``. A path with zero channels on either side becomes an
    ``nn.Identity`` placeholder, which ignores its constructor arguments.
    The global->global path is a ``SpectralTransform``; the others are
    ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True,
                 padding_type='reflect', **spectral_kwargs):
        super().__init__()
        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        def pick(n_in, n_out, real_cls):
            # Empty paths collapse to Identity so they hold no parameters.
            return nn.Identity if n_in == 0 or n_out == 0 else real_cls

        conv_args = (kernel_size, stride, padding, dilation, groups, bias)
        self.convl2l = pick(in_cl, out_cl, nn.Conv2d)(
            in_cl, out_cl, *conv_args, padding_mode=padding_type)
        self.convl2g = pick(in_cl, out_cg, nn.Conv2d)(
            in_cl, out_cg, *conv_args, padding_mode=padding_type)
        self.convg2l = pick(in_cg, out_cl, nn.Conv2d)(
            in_cg, out_cl, *conv_args, padding_mode=padding_type)
        g2g_groups = 1 if groups == 1 else groups // 2
        self.convg2g = pick(in_cg, out_cg, SpectralTransform)(
            in_cg, out_cg, stride, g2g_groups, enable_lfu, **spectral_kwargs)

    def forward(self, x):
        # ratio_gin/ratio_gout are Python floats — tracer follows the correct branch.
        has_global_input = self.ratio_gin != 0
        if has_global_input:
            x_l, x_g = x
        else:
            x_l = x[0] if isinstance(x, tuple) else x

        out_xl = out_xg = None
        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l)
            if has_global_input:
                out_xl = out_xl + self.convg2l(x_g)
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l)
            if has_global_input:
                out_xg = out_xg + self.convg2g(x_g)
        return out_xl, out_xg
class FFC_BN_ACT(nn.Module):
    """FFC followed by per-branch normalization and activation.

    A branch with no output channels (local when ``ratio_gout == 1``, global
    when ``ratio_gout == 0``) gets ``nn.Identity`` for both its norm and its
    activation. A ``None`` branch coming out of the FFC is passed through.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 padding_type='reflect', enable_lfu=True, **kwargs):
        super().__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout,
                       stride=stride, padding=padding, dilation=dilation,
                       groups=groups, bias=bias, enable_lfu=enable_lfu,
                       padding_type=padding_type, **kwargs)
        n_global = int(out_channels * ratio_gout)
        n_local = out_channels - n_global
        no_local = ratio_gout == 1
        no_global = ratio_gout == 0
        self.bn_l = nn.Identity() if no_local else norm_layer(n_local)
        self.bn_g = nn.Identity() if no_global else norm_layer(n_global)
        self.act_l = nn.Identity() if no_local else activation_layer(inplace=True)
        self.act_g = nn.Identity() if no_global else activation_layer(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        if x_l is not None:
            x_l = self.act_l(self.bn_l(x_l))
        if x_g is not None:
            x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g
class FFCResnetBlock(nn.Module):
    """Two 3x3 FFC_BN_ACT layers with a residual add on each branch.

    Accepts either a ``(local, global)`` tuple or a single tensor, which is
    treated as ``(tensor, None)``. A branch whose output is ``None`` stays
    ``None``.
    """

    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU,
                 dilation=1, inline=False, **conv_kwargs):
        super().__init__()
        layer_kwargs = dict(padding=dilation, dilation=dilation,
                            norm_layer=norm_layer, activation_layer=activation_layer,
                            padding_type=padding_type, **conv_kwargs)
        self.conv1 = FFC_BN_ACT(dim, dim, 3, **layer_kwargs)
        self.conv2 = FFC_BN_ACT(dim, dim, 3, **layer_kwargs)
        # Kept for interface compatibility; forward() never reads it.
        self.inline = inline

    def forward(self, x):
        identity_l, identity_g = x if isinstance(x, tuple) else (x, None)
        x_l, x_g = self.conv2(self.conv1((identity_l, identity_g)))
        out_l = None if x_l is None else identity_l + x_l
        out_g = None if x_g is None else identity_g + x_g
        return out_l, out_g
class ConcatTupleLayer(nn.Module):
    """Merge an FFC ``(local, global)`` tuple into a single tensor along channels.

    A missing branch (``None``, produced by FFC when ``ratio_gout`` is 0 or 1)
    is skipped and the other branch is returned unchanged. If both are
    ``None``, ``None`` is returned.
    """

    def forward(self, x):
        x_l, x_g = x
        if x_g is None:
            return x_l
        if x_l is None:
            # ratio_gout == 1: only the global branch exists.
            return x_g
        return torch.cat([x_l, x_g], dim=1)
class LaMaModel(nn.Module):
    """LaMa FFCResNetGenerator with separate image/mask inputs.

    ``self.generator`` is an ``nn.Sequential``. Its indices must line up with
    the checkpoint's ``model.N`` keys (remapped to ``generator.N`` when
    loading), so the order in which layers are appended below matters.

    Keyword-only args (defaults reproduce the configuration this script exports):
        ngf: base channel width.
        n_downsampling: number of stride-2 FFC stages (>= 1; the global FFC
            branch is switched on by the last one). Output spatial size equals
            input size only when H and W are multiples of ``2 ** n_downsampling``.
        n_blocks: number of FFC ResNet blocks at the bottleneck.
        max_features: cap on the channel count of every stage.
    """

    def __init__(self, *, ngf=64, n_downsampling=3, n_blocks=18, max_features=1024):
        super().__init__()
        norm_layer = nn.BatchNorm2d
        resnet_kw = dict(ratio_gin=0.75, ratio_gout=0.75, enable_lfu=False)
        init_kw = dict(ratio_gin=0, ratio_gout=0, enable_lfu=False)
        down_kw = dict(ratio_gin=0, ratio_gout=0, enable_lfu=False)

        # Input: 3 masked image channels + 1 mask channel.
        layers = [
            nn.ReflectionPad2d(3),
            FFC_BN_ACT(4, ngf, 7, padding=0, norm_layer=norm_layer,
                       activation_layer=nn.ReLU, **init_kw),
        ]
        # Downsample. The last stage already emits a global branch sized for the ResNet blocks.
        for i in range(n_downsampling):
            mult = 2 ** i
            kw = dict(down_kw)
            if i == n_downsampling - 1:
                kw['ratio_gout'] = resnet_kw['ratio_gin']
            layers.append(FFC_BN_ACT(
                min(max_features, ngf * mult), min(max_features, ngf * mult * 2),
                3, stride=2, padding=1,
                norm_layer=norm_layer, activation_layer=nn.ReLU, **kw))
        # ResNet blocks
        feats = min(max_features, ngf * 2 ** n_downsampling)
        for _ in range(n_blocks):
            layers.append(FFCResnetBlock(feats, padding_type='reflect',
                                         norm_layer=norm_layer,
                                         activation_layer=nn.ReLU,
                                         **resnet_kw))
        layers.append(ConcatTupleLayer())
        # Upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            layers += [
                nn.ConvTranspose2d(min(max_features, ngf * mult),
                                   min(max_features, ngf * mult // 2),
                                   3, stride=2, padding=1, output_padding=1),
                norm_layer(min(max_features, ngf * mult // 2)),
                nn.ReLU(True),
            ]
        # Output: 3 channels squashed to [0, 1] by the sigmoid.
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, 7), nn.Sigmoid()]
        self.generator = nn.Sequential(*layers)

    def forward(self, image, mask):
        """Zero the pixels where ``mask`` is 1, append the mask as a channel, run the generator.

        Returns the raw generator output; no compositing with ``image`` is done here.
        """
        masked_img = image * (1.0 - mask)
        x = torch.cat([masked_img, mask], dim=1)
        return self.generator(x)
def load_weights(model, sf_path):
    """Load safetensors weights into ``model``, remapping the key prefix.

    Safetensors keys: model.N.xxx (nn.Sequential index)
    Our model keys:   generator.N.xxx

    Only a *leading* ``model.`` is rewritten; keys without that prefix are
    passed through unchanged. Loading is non-strict: genuinely missing keys
    (anything except BatchNorm's ``num_batches_tracked``) are printed as a
    warning, and unexpected keys are counted and ignored.

    Args:
        model: module to load into (updated in place).
        sf_path: path to a ``.safetensors`` file.

    Returns:
        The same ``model`` instance.
    """
    from safetensors.torch import load_file

    src_prefix, dst_prefix = "model.", "generator."
    max_listed = 20
    state = load_file(sf_path)
    mapped = OrderedDict(
        (dst_prefix + k[len(src_prefix):] if k.startswith(src_prefix) else k, v)
        for k, v in state.items()
    )
    result = model.load_state_dict(mapped, strict=False)
    real_missing = [k for k in result.missing_keys if 'num_batches_tracked' not in k]
    if real_missing:
        print(f"WARNING: {len(real_missing)} genuinely missing keys:")
        for k in real_missing[:max_listed]:
            print(f" {k}")
        if len(real_missing) > max_listed:
            print(f" ... and {len(real_missing) - max_listed} more")
    if result.unexpected_keys:
        print(f"NOTE: {len(result.unexpected_keys)} unexpected keys (ignored)")
    loaded = len(state) - len(result.unexpected_keys)
    print(f"Loaded {loaded}/{len(state)} weights")
    return model
def resolve_safetensors_path(local_path):
    """Return ``local_path`` if given, otherwise download the weights from the Hub.

    The download fetches ``lama-manga.safetensors`` from the
    ``mayocream/lama-manga`` repository and returns the cached file path.
    """
    if not local_path:
        from huggingface_hub import hf_hub_download

        repo_id = "mayocream/lama-manga"
        filename = "lama-manga.safetensors"
        print("Downloading safetensors from mayocream/lama-manga...")
        return hf_hub_download(repo_id, filename)
    return local_path
def verify_coreml_session(model_path):
    """Check that ONNX Runtime can open ``model_path`` on the CoreML MLProgram EP.

    Prints a message and returns early, without error, when onnxruntime is not
    installed or ``CoreMLExecutionProvider`` is not available.

    Raises:
        RuntimeError: the session was created but CoreML is not among its
            active providers.
    """
    try:
        import onnxruntime as ort
    except ImportError as e:
        # Optional dependency: only skip when it genuinely cannot be imported.
        print(f"CoreML verify skipped: onnxruntime unavailable ({e})")
        return
    providers = ort.get_available_providers()
    if "CoreMLExecutionProvider" not in providers:
        print(f"CoreML verify skipped: CoreMLExecutionProvider not available ({providers})")
        return
    ml_opts = {
        "ModelFormat": "MLProgram",
        "MLComputeUnits": "CPUAndNeuralEngine",
    }
    sess = ort.InferenceSession(
        model_path,
        providers=[("CoreMLExecutionProvider", ml_opts), "CPUExecutionProvider"],
    )
    active = sess.get_providers()
    if "CoreMLExecutionProvider" not in active:
        raise RuntimeError(
            f"CoreML MLProgram init fallback detected; active providers={active}"
        )
    print("CoreML MLProgram session init: OK")
def _parse_args():
    """Parse the command-line options of the export script."""
    parser = argparse.ArgumentParser(description="Export LaMa-manga ONNX for CoreML/ANE")
    parser.add_argument("-o", "--output", default="models/lama-manga.onnx")
    parser.add_argument("--safetensors", default=None, help="Local safetensors path")
    parser.add_argument("--verify", action="store_true", help="Run test inference")
    parser.add_argument(
        "--verify-coreml",
        action="store_true",
        help="Try creating an ONNX Runtime CoreML MLProgram session after export",
    )
    return parser.parse_args()


def _smoke_test(model):
    """Run one 512x512 PyTorch forward pass and print the output shape and range."""
    with torch.no_grad():
        img = torch.randn(1, 3, 512, 512)
        mask = torch.zeros(1, 1, 512, 512)
        mask[:, :, 100:200, 100:200] = 1.0
        out = model(img, mask)
        print(f"Test output: shape={out.shape}, range=[{out.min():.3f}, {out.max():.3f}]")


def _export_onnx(model, output_path):
    """Export ``model`` to ``output_path`` with dynamic H/W (multiples of 8 in [64, 4096])."""
    dummy_img = torch.randn(1, 3, 512, 512)
    dummy_mask = torch.zeros(1, 1, 512, 512)
    # LaMa has 3× stride-2 downsampling; input must be multiple of 8.
    # Rust inpaint.rs already pads to multiple of 8.
    h_base = torch.export.Dim("h_blocks", min=8, max=512)
    w_base = torch.export.Dim("w_blocks", min=8, max=512)
    height = h_base * 8
    width = w_base * 8
    dynamic_shapes = {
        "image": {2: height, 3: width},
        "mask": {2: height, 3: width},
    }
    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy_img, dummy_mask),
            output_path,
            input_names=["image", "mask"],
            output_names=["output"],
            dynamic_shapes=dynamic_shapes,
            # Keep weights inline: the model is re-saved as a single file below,
            # so an external "<output>.data" side file would just be left stale.
            external_data=False,
        )


def _add_int64_initializer(graph, name, values):
    """Append a 1-D INT64 initializer called ``name`` to ``graph``; return ``name``."""
    from onnx import helper, TensorProto

    graph.initializer.append(
        helper.make_tensor(name, TensorProto.INT64, [len(values)], list(values)))
    return name


def _set_graph_nodes(graph, nodes):
    """Replace ``graph.node`` with ``nodes``, preserving their order.

    Nodes are copied first because entries of ``nodes`` may be the very
    messages that are being cleared out of ``graph.node``.
    """
    import onnx

    detached = []
    for node in nodes:
        clone = onnx.NodeProto()
        clone.CopyFrom(node)
        detached.append(clone)
    del graph.node[:]
    graph.node.extend(detached)


def _rewrite_inverse_onesided_dft(graph):
    """Rewrite DFT(inverse=1, onesided=1) → hermitian extension + DFT(inverse=1, onesided=0).

    ORT does not support the inverse+onesided combo. For a real output of length
    n, the half spectrum on axis -2 (the last axis is the [real, imag] pair)
    holds bins 0..n//2. The missing bins n//2+1..n-1 equal conj(X[k]) for
    k = (n-1)//2 down to 1, i.e. a reversed slice of the half spectrum.

    n is taken from the node's ``dft_length`` input when present, which is
    required for odd n (odd bottleneck widths are reachable with the dynamic
    ``w_blocks`` dim). Without ``dft_length``, n is assumed even,
    n = 2 * (n_half - 1).

    The replacement nodes are inserted where the original node was, keeping the
    graph topologically sorted. All other inputs and attributes of the original
    node (``dft_length``, ``axis``, ``inverse``) are carried over unchanged.

    Returns the number of nodes rewritten.
    """
    from onnx import helper, TensorProto

    new_nodes = []
    fixed = 0
    for node in graph.node:
        if node.op_type != "DFT":
            new_nodes.append(node)
            continue
        attrs = {a.name: a.i for a in node.attribute}
        if not (attrs.get("inverse", 0) == 1 and attrs.get("onesided", 0) == 1):
            new_nodes.append(node)
            continue

        inp = node.input[0]  # half-spectrum [..., n_half, 2]
        out = node.output[0]
        uid = f"_irfft_{fixed}"
        dft_length = node.input[1] if len(node.input) > 1 else ""

        # Index of the last bin to mirror (mirrored bins run from it down to 1).
        if dft_length:
            # Assumes dft_length holds the real output length n (the same value the
            # replacement full inverse DFT receives): start = (n - 1) // 2.
            n64, n1, nm1 = f"n64{uid}", f"n1{uid}", f"nm1{uid}"
            start = f"mirror_start{uid}"
            new_nodes += [
                helper.make_node("Cast", [dft_length], [n64], to=TensorProto.INT64),
                helper.make_node(
                    "Reshape", [n64, _add_int64_initializer(graph, f"shape1{uid}", [1])], [n1]),
                helper.make_node(
                    "Sub", [n1, _add_int64_initializer(graph, f"one{uid}", [1])], [nm1]),
                helper.make_node(
                    "Div", [nm1, _add_int64_initializer(graph, f"two{uid}", [2])], [start]),
            ]
        else:
            # Even-n fallback: start = n_half - 2, i.e. mirror bins [1:-1].
            start = _add_int64_initializer(graph, f"mirror_start{uid}", [-2])

        # Reverse slice along axis -2: indices start, start-1, ..., 1 (end 0 is exclusive).
        flip_name = f"flip{uid}"
        new_nodes.append(helper.make_node(
            "Slice",
            [inp, start,
             _add_int64_initializer(graph, f"flip_e{uid}", [0]),
             _add_int64_initializer(graph, f"flip_ax{uid}", [-2]),
             _add_int64_initializer(graph, f"flip_st{uid}", [-1])],
            [flip_name]))

        # Conjugate = negate the imag part: multiply the trailing [re, im] pair by [1, -1].
        conj_scale = f"conj{uid}"
        graph.initializer.append(
            helper.make_tensor(conj_scale, TensorProto.FLOAT, [2], [1.0, -1.0]))
        conj_name = f"conj_out{uid}"
        new_nodes.append(helper.make_node("Mul", [flip_name, conj_scale], [conj_name]))

        # Concat: [half, conj_mirror] along axis -2 -> full spectrum of length n.
        full_name = f"full{uid}"
        new_nodes.append(helper.make_node("Concat", [inp, conj_name], [full_name], axis=-2))

        # DFT(inverse=1, onesided=0) on the full spectrum. Remaining inputs are kept
        # positionally (empty names included) and all attributes except onesided are copied.
        new_dft = helper.make_node(
            "DFT", [full_name, *node.input[1:]], [out], name=f"idft{uid}", onesided=0)
        new_dft.attribute.extend(a for a in node.attribute if a.name != "onesided")
        new_nodes.append(new_dft)
        fixed += 1

    if fixed:
        _set_graph_nodes(graph, new_nodes)
    return fixed


def _rewrite_shape_slices(graph):
    """Rewrite Shape(start=..., end=...) to Shape + Slice, in place; return the count.

    CoreML EP's MLProgram path currently rejects Shape with end == rank
    (e.g. start=3,end=4 on rank-4 tensors) with:
    "axis 4 not in valid range [-4,3]".
    """
    from onnx import helper

    shape_rewrites = 0
    rewritten_nodes = []
    for node in graph.node:
        if node.op_type != "Shape":
            rewritten_nodes.append(node)
            continue
        attrs = {a.name: a.i for a in node.attribute}
        if "start" not in attrs and "end" not in attrs:
            rewritten_nodes.append(node)
            continue
        start = attrs.get("start", 0)
        end = attrs.get("end", 9223372036854775807)  # INT64_MAX == "to the end"
        uid = f"_shapefix_{shape_rewrites}"
        full_shape_out = f"{node.output[0]}{uid}_all"
        shape_name = f"{node.name}_all" if node.name else f"shape_all{uid}"
        slice_name = f"{node.name}_slice" if node.name else f"shape_slice{uid}"
        rewritten_nodes.append(
            helper.make_node("Shape", [node.input[0]], [full_shape_out], name=shape_name))
        rewritten_nodes.append(helper.make_node(
            "Slice",
            [full_shape_out,
             _add_int64_initializer(graph, f"starts{uid}", [start]),
             _add_int64_initializer(graph, f"ends{uid}", [end]),
             _add_int64_initializer(graph, f"axes{uid}", [0])],
            list(node.output),
            name=slice_name,
        ))
        shape_rewrites += 1

    if shape_rewrites:
        _set_graph_nodes(graph, rewritten_nodes)
    return shape_rewrites


def _save_single_file(model_onnx, output_path):
    """Save ``model_onnx`` to ``output_path`` with all weights embedded."""
    import onnx

    for tensor in model_onnx.graph.initializer:
        tensor.ClearField("data_location")
    onnx.save(model_onnx, output_path, save_as_external_data=False)


def _print_summary(model_onnx, output_path):
    """Print op counts, graph input/output shapes and the saved file size."""
    ops = Counter(n.op_type for n in model_onnx.graph.node)
    print(f"\nFinal: {sum(ops.values())} nodes, {len(ops)} unique ops")
    for op, count in ops.most_common():
        print(f" {op}: {count}")
    for t in list(model_onnx.graph.input) + list(model_onnx.graph.output):
        shape = t.type.tensor_type.shape
        dims = [d.dim_param or str(d.dim_value) for d in shape.dim]
        print(f" {t.name}: [{', '.join(dims)}]")
    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"\nSaved: {output_path} ({size_mb:.1f} MB)")


def main():
    """Build LaMa, load weights, export to ONNX, patch the graph for ORT/CoreML, save."""
    import onnx

    args = _parse_args()
    sf_path = resolve_safetensors_path(args.safetensors)

    # 1. Build model
    print("Building LaMa model (rank-4-safe DFT)...")
    model = LaMaModel()

    # 2. Load weights
    print(f"Loading weights from {sf_path}")
    model = load_weights(model, sf_path)
    model.eval()

    # 3. Optional: verify inference
    if args.verify:
        _smoke_test(model)

    # 4. Export ONNX (creating the output directory if needed)
    print("Exporting ONNX...")
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    _export_onnx(model, args.output)

    model_onnx = onnx.load(args.output, load_external_data=True)
    graph = model_onnx.graph

    # 5. DFT(inverse=1, onesided=1) → hermitian extension + DFT(inverse=1, onesided=0)
    fixed = _rewrite_inverse_onesided_dft(graph)
    print(f"Rewrote {fixed} DFT(inverse+onesided) → hermitian pad + DFT(inverse)")

    # 6. Shape(start/end) → Shape + Slice for CoreML EP
    shape_rewrites = _rewrite_shape_slices(graph)
    if shape_rewrites:
        print(f"Rewrote {shape_rewrites} Shape(start/end) nodes → Shape+Slice")

    # Re-embed weights into single file
    _save_single_file(model_onnx, args.output)
    _print_summary(model_onnx, args.output)

    if args.verify_coreml:
        verify_coreml_session(args.output)
if __name__ == "__main__":
    # Script entry point; main() returns None, so the process exits with status 0.
    raise SystemExit(main())

Xet Storage Details

Size:
20.8 kB
·
Xet hash:
88117aa924abf714288926d0be7e71ee690546bb72c30d06b0e66ab679cef1cb

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.