Buckets:
"""Export LaMa-manga from safetensors → ONNX with native DFT ops.

Torch 2.9+ exports torch.fft.rfft2/irfft2 as ONNX DFT operators,
producing a compact graph with dynamic spatial dimensions. After export the
graph is post-processed so ONNX Runtime (incl. the CoreML EP) can load it:

  * DFT(inverse=1, onesided=1) nodes are expanded into a Hermitian
    reconstruction of the full spectrum + DFT(inverse=1, onesided=0).
  * Shape nodes with start/end attributes are split into Shape + Slice.
  * All weights are re-embedded so the result is a single .onnx file.

Requirements (quote the version specifier, otherwise the shell treats
``>`` as an output redirection):
    pip install "torch>=2.9" safetensors huggingface_hub onnx onnxscript
    pip install onnxruntime   # only needed for --verify-coreml

Usage:
    python scripts/export_lama.py -o models/lama-manga.onnx
"""
import argparse
import os
import logging
from collections import Counter, OrderedDict
# NOTE(review): logging is configured here, but this script reports progress
# with print(); no logger calls are visible in this file.
logging.basicConfig(level=logging.INFO)
import torch
import torch.nn as nn
# ── LaMa architecture (FFCResNetGenerator) ──────────────────────────────────
class FourierUnit(nn.Module):
    """Spectral 1x1 convolution: rfft2 → conv/BN/ReLU on (re, im) → irfft2.

    Uses native ``torch.fft`` so the exporter emits ONNX DFT operators. Real
    and imaginary parts are interleaved as channels
    ([c0.re, c0.im, c1.re, c1.im, ...]) using only rank-4 reshapes, so no 5-D
    intermediate tensors appear in the exported graph.

    NOTE(review): upstream LaMa's FourierUnit is commonly configured with
    fft_norm='ortho'; this port uses norm='backward' for both transforms.
    Since BatchNorm sits between them the two are not equivalent — confirm
    this matches how the checkpoint was trained.
    """

    def __init__(self, in_channels, out_channels, groups=1, **kwargs):
        super().__init__()
        self.groups = groups
        # Real and imaginary parts travel as separate channels, hence "* 2".
        self.conv_layer = nn.Conv2d(
            in_channels * 2,
            out_channels * 2,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=groups,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels * 2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        n, c, height, width = x.shape
        bins = width // 2 + 1  # one-sided spectrum width
        flat = height * bins

        spectrum = torch.fft.rfft2(x, norm='backward')
        # [N, C, 1, H*bins] x2 --cat(dim=2)--> [N, C, 2, H*bins] --> [N, 2C, H, bins]
        pieces = [
            spectrum.real.reshape(n, c, 1, flat),
            spectrum.imag.reshape(n, c, 1, flat),
        ]
        stacked = torch.cat(pieces, dim=2).reshape(n, c * 2, height, bins)

        mixed = self.conv_layer(stacked)
        mixed = self.bn(mixed)
        mixed = self.relu(mixed)

        # Undo the interleave: [N, 2C', H, bins] -> [N, C', 2, H*bins] -> re / im.
        half = mixed.shape[1] // 2
        pairs = mixed.reshape(n, half, 2, flat)
        re = pairs[:, :, 0, :].reshape(n, half, height, bins)
        im = pairs[:, :, 1, :].reshape(n, half, height, bins)
        return torch.fft.irfft2(torch.complex(re, im), s=(height, width), norm='backward')
class SpectralTransform(nn.Module):
    """Global (spectral) path of an FFC layer.

    Pipeline: 2x2 average-pool when ``stride == 2`` → 1x1 conv/BN/ReLU down to
    ``out_channels // 2`` → FourierUnit (plus an optional local FourierUnit,
    "LFU") → residual sum with the pre-FFT features → 1x1 conv up to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, stride=1, groups=1,
                 enable_lfu=True, **kwargs):
        super().__init__()
        self.enable_lfu = enable_lfu
        hidden = out_channels // 2
        if stride == 2:
            self.downsample = nn.AvgPool2d(2, 2)
        else:
            self.downsample = nn.Identity()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, hidden, 1, groups=groups, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(hidden, hidden, groups)
        if enable_lfu:
            self.lfu = FourierUnit(hidden, hidden, groups)
        self.conv2 = nn.Conv2d(hidden, out_channels, 1, groups=groups, bias=False)

    def forward(self, x):
        y = self.conv1(self.downsample(x))
        global_out = self.fu(y)
        local_out = 0
        if self.enable_lfu:
            # LFU: take the first quarter of the channels, cut them into chunks
            # of h // 2 along height and then along width, stack the chunks on
            # the channel axis, transform, and tile the result back 2x2.
            _, channels, height, _ = y.shape
            chunk = height // 2
            quarter = y[:, :channels // 4]
            folded = torch.cat(torch.split(quarter, chunk, dim=-2), dim=1).contiguous()
            folded = torch.cat(torch.split(folded, chunk, dim=-1), dim=1).contiguous()
            local_out = self.lfu(folded).repeat(1, 1, 2, 2).contiguous()
        return self.conv2(y + global_out + local_out)
class FFC(nn.Module):
    """Fast Fourier Convolution with a local (spatial) and a global (spectral) path.

    Input and output channels are split by ``ratio_gin`` / ``ratio_gout`` into
    local and global parts. Three of the four cross paths (l→l, l→g, g→l) are
    ordinary convolutions and g→g is a SpectralTransform; any path whose input
    or output channel count is zero becomes ``nn.Identity``.

    ``forward`` returns ``(out_local, out_global)``; an output whose ratio
    makes it empty (ratio_gout == 1 for local, 0 for global) is ``None``.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False, enable_lfu=True,
                 padding_type='reflect', **spectral_kwargs):
        super().__init__()
        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        conv_args = (kernel_size, stride, padding, dilation, groups, bias)

        def conv_or_identity(cin, cout):
            # nn.Identity ignores constructor arguments, so none are passed.
            if cin == 0 or cout == 0:
                return nn.Identity()
            return nn.Conv2d(cin, cout, *conv_args, padding_mode=padding_type)

        self.convl2l = conv_or_identity(in_cl, out_cl)
        self.convl2g = conv_or_identity(in_cl, out_cg)
        self.convg2l = conv_or_identity(in_cg, out_cl)
        if in_cg == 0 or out_cg == 0:
            self.convg2g = nn.Identity()
        else:
            self.convg2g = SpectralTransform(
                in_cg, out_cg, stride, 1 if groups == 1 else groups // 2,
                enable_lfu, **spectral_kwargs)

    def forward(self, x):
        # The ratios are plain Python floats, so these branches are resolved
        # at trace time and never appear in the exported graph.
        make_local = self.ratio_gout != 1
        make_global = self.ratio_gout != 0
        if self.ratio_gin == 0:
            x_l = x[0] if isinstance(x, tuple) else x
            out_xl = self.convl2l(x_l) if make_local else None
            out_xg = self.convl2g(x_l) if make_global else None
            return out_xl, out_xg
        x_l, x_g = x
        out_xl = (self.convl2l(x_l) + self.convg2l(x_g)) if make_local else None
        out_xg = (self.convl2g(x_l) + self.convg2g(x_g)) if make_global else None
        return out_xl, out_xg
class FFC_BN_ACT(nn.Module):
    """FFC followed by separate normalization and activation per branch.

    The local and global FFC outputs each get their own ``norm_layer`` and
    ``activation_layer``; a branch that is empty for the given ``ratio_gout``
    gets ``nn.Identity`` instead. ``None`` branches pass through unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 ratio_gin, ratio_gout, stride=1, padding=0,
                 dilation=1, groups=1, bias=False,
                 norm_layer=nn.BatchNorm2d, activation_layer=nn.Identity,
                 padding_type='reflect', enable_lfu=True, **kwargs):
        super().__init__()
        self.ffc = FFC(in_channels, out_channels, kernel_size,
                       ratio_gin, ratio_gout, stride, padding, dilation,
                       groups, bias, enable_lfu, padding_type, **kwargs)
        global_channels = int(out_channels * ratio_gout)
        local_channels = out_channels - global_channels
        no_local = ratio_gout == 1
        no_global = ratio_gout == 0
        self.bn_l = nn.Identity() if no_local else norm_layer(local_channels)
        self.bn_g = nn.Identity() if no_global else norm_layer(global_channels)
        self.act_l = nn.Identity() if no_local else activation_layer(inplace=True)
        self.act_g = nn.Identity() if no_global else activation_layer(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        if x_l is not None:
            x_l = self.act_l(self.bn_l(x_l))
        if x_g is not None:
            x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g
class FFCResnetBlock(nn.Module):
    """Residual block of two 3x3 FFC_BN_ACT layers on (local, global) pairs.

    Each branch adds its skip connection independently; a branch that is
    ``None`` after the convolutions stays ``None``. A bare tensor input is
    treated as ``(tensor, None)``.

    NOTE(review): ``inline`` is stored but never read by ``forward``. Upstream
    LaMa uses it to concatenate the two outputs — confirm whether that
    behaviour is needed here.
    """

    def __init__(self, dim, padding_type, norm_layer, activation_layer=nn.ReLU,
                 dilation=1, inline=False, **conv_kwargs):
        super().__init__()
        shared = dict(
            padding=dilation,
            dilation=dilation,
            norm_layer=norm_layer,
            activation_layer=activation_layer,
            padding_type=padding_type,
        )
        self.conv1 = FFC_BN_ACT(dim, dim, 3, **shared, **conv_kwargs)
        self.conv2 = FFC_BN_ACT(dim, dim, 3, **shared, **conv_kwargs)
        self.inline = inline

    def forward(self, x):
        if not isinstance(x, tuple):
            x = (x, None)
        skip_l, skip_g = x
        y_l, y_g = self.conv2(self.conv1(x))
        out_l = None if y_l is None else skip_l + y_l
        out_g = None if y_g is None else skip_g + y_g
        return out_l, out_g
class ConcatTupleLayer(nn.Module):
    """Merge an FFC ``(local, global)`` pair into one tensor.

    The branches are concatenated along the channel axis (dim 1). If one
    branch is ``None`` (its split ratio made it empty, as FFC_BN_ACT reports
    for ratio_gout 0 or 1), the other branch is returned unchanged.
    """

    def forward(self, x):
        x_l, x_g = x
        if x_g is None:
            return x_l
        if x_l is None:
            # A global-only pair (ratio_gout == 1) has no local branch.
            return x_g
        return torch.cat([x_l, x_g], dim=1)
class LaMaModel(nn.Module):
    """LaMa FFCResNetGenerator with separate image/mask inputs.

    Everything lives in one nn.Sequential (``self.generator``) whose integer
    indices must line up with the checkpoint's ``model.N.*`` keys (load_weights
    maps them to ``generator.N.*``). The order in which layers are appended
    below is therefore part of the weight-loading contract:

        0       ReflectionPad2d(3)
        1       7x7 FFC_BN_ACT, 4 -> 64 channels (local branch only)
        2-4     stride-2 FFC_BN_ACT downsampling (64->128->256->512)
        5-22    18 FFCResnetBlock on 512 channels (ratio_gin = ratio_gout = 0.75)
        23      ConcatTupleLayer
        24-32   3x (ConvTranspose2d, BatchNorm2d, ReLU) upsampling (512->256->128->64)
        33-35   ReflectionPad2d(3), 7x7 Conv2d 64 -> 3, Sigmoid
    """
    def __init__(self):
        super().__init__()
        ngf = 64
        n_downsampling = 3
        n_blocks = 18
        norm_layer = nn.BatchNorm2d
        # FFC split ratios per stage: the global (spectral) branch only exists
        # from the output of the last downsampling layer onward.
        resnet_kw = dict(ratio_gin=0.75, ratio_gout=0.75, enable_lfu=False)
        init_kw = dict(ratio_gin=0, ratio_gout=0, enable_lfu=False)
        down_kw = dict(ratio_gin=0, ratio_gout=0, enable_lfu=False)
        # 4 input channels: masked image + mask, concatenated in forward().
        layers = [
            nn.ReflectionPad2d(3),
            FFC_BN_ACT(4, ngf, 7, padding=0, norm_layer=norm_layer,
                       activation_layer=nn.ReLU, **init_kw),
        ]
        # Downsample
        for i in range(n_downsampling):
            mult = 2 ** i
            kw = dict(down_kw)
            if i == n_downsampling - 1:
                # The last downsampling layer already emits the (local, global)
                # split that the ResNet blocks consume.
                kw['ratio_gout'] = resnet_kw['ratio_gin']
            layers.append(FFC_BN_ACT(
                min(1024, ngf * mult), min(1024, ngf * mult * 2),
                3, stride=2, padding=1,
                norm_layer=norm_layer, activation_layer=nn.ReLU, **kw))
        # ResNet blocks
        mult = 2 ** n_downsampling
        feats = min(1024, ngf * mult)
        for _ in range(n_blocks):
            layers.append(FFCResnetBlock(feats, padding_type='reflect',
                                         norm_layer=norm_layer,
                                         activation_layer=nn.ReLU,
                                         **resnet_kw))
        # Merge the (local, global) pair back into a single tensor.
        layers.append(ConcatTupleLayer())
        # Upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            layers += [
                nn.ConvTranspose2d(min(1024, ngf * mult),
                                   min(1024, ngf * mult // 2),
                                   3, stride=2, padding=1, output_padding=1),
                norm_layer(min(1024, ngf * mult // 2)),
                nn.ReLU(True),
            ]
        layers += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, 3, 7), nn.Sigmoid()]
        self.generator = nn.Sequential(*layers)

    def forward(self, image, mask):
        """Run the generator on an image and its mask.

        Args:
            image: image tensor; the export/verify dummies use [B, 3, H, W].
            mask: mask tensor; the dummies use [B, 1, H, W]. Pixels where
                mask == 1 are zeroed in the image before it enters the
                generator (presumably the holes to fill).

        Returns:
            Generator output after a Sigmoid, i.e. values in (0, 1). For H and
            W divisible by 8 the spatial size matches the input.
        """
        masked_img = image * (1.0 - mask)
        x = torch.cat([masked_img, mask], dim=1)
        return self.generator(x)
def load_weights(model, sf_path):
    """Load safetensors weights into ``model``, remapping key prefixes.

    Safetensors keys: model.N.xxx      (nn.Sequential index)
    Our model keys:   generator.N.xxx

    Only a *leading* ``model.`` is rewritten. Other keys pass through
    unchanged and are reported as unexpected if the model has no such entry.
    Loading is non-strict. Missing keys other than BatchNorm
    ``num_batches_tracked`` buffers are printed as a warning (first 20, plus a
    count of the rest).

    Args:
        model: module to load into (modified in place).
        sf_path: path to a ``.safetensors`` file.

    Returns:
        The same ``model`` instance.
    """
    from safetensors.torch import load_file

    src_prefix, dst_prefix = "model.", "generator."
    state = load_file(sf_path)
    # str.replace(..., 1) would rewrite the first "model." anywhere in a key;
    # anchor the rewrite to the prefix instead.
    mapped = OrderedDict(
        (dst_prefix + k[len(src_prefix):] if k.startswith(src_prefix) else k, v)
        for k, v in state.items()
    )
    result = model.load_state_dict(mapped, strict=False)

    real_missing = [k for k in result.missing_keys if 'num_batches_tracked' not in k]
    if real_missing:
        shown = 20
        print(f"WARNING: {len(real_missing)} genuinely missing keys:")
        for k in real_missing[:shown]:
            print(f" {k}")
        if len(real_missing) > shown:
            print(f" ... and {len(real_missing) - shown} more")
    if result.unexpected_keys:
        print(f"NOTE: {len(result.unexpected_keys)} unexpected keys (ignored)")
    loaded = len(state) - len(result.unexpected_keys)
    print(f"Loaded {loaded}/{len(state)} weights")
    return model
def resolve_safetensors_path(local_path):
    """Return the path of the safetensors checkpoint to load.

    A truthy ``local_path`` is returned unchanged. Otherwise
    ``lama-manga.safetensors`` is fetched from the ``mayocream/lama-manga``
    Hugging Face repo with ``hf_hub_download``, and the path it returns is used.
    """
    if not local_path:
        from huggingface_hub import hf_hub_download
        print("Downloading safetensors from mayocream/lama-manga...")
        return hf_hub_download("mayocream/lama-manga", "lama-manga.safetensors")
    return local_path
def verify_coreml_session(model_path):
    """Check that ONNX Runtime opens ``model_path`` on the CoreML EP (MLProgram).

    Prints a "skipped" message and returns None when onnxruntime cannot be
    imported or CoreMLExecutionProvider is not available. Otherwise it builds
    an InferenceSession that prefers CoreML (MLProgram, CPU + Neural Engine)
    and falls back to CPU.

    Raises:
        RuntimeError: the session was created, but CoreML is not among the
            active providers (ORT silently fell back).
    """
    try:
        import onnxruntime as ort
    except Exception as err:
        print(f"CoreML verify skipped: onnxruntime unavailable ({err})")
        return

    available = ort.get_available_providers()
    if "CoreMLExecutionProvider" not in available:
        print(f"CoreML verify skipped: CoreMLExecutionProvider not available ({available})")
        return

    coreml = (
        "CoreMLExecutionProvider",
        {"ModelFormat": "MLProgram", "MLComputeUnits": "CPUAndNeuralEngine"},
    )
    sess = ort.InferenceSession(model_path, providers=[coreml, "CPUExecutionProvider"])
    active = sess.get_providers()
    if "CoreMLExecutionProvider" not in active:
        raise RuntimeError(
            f"CoreML MLProgram init fallback detected; active providers={active}"
        )
    print("CoreML MLProgram session init: OK")
def _onnx_constant_int(graph, name):
    """Return the first element of initializer/Constant ``name`` as int, else None."""
    from onnx import numpy_helper

    for init in graph.initializer:
        if init.name == name:
            return int(numpy_helper.to_array(init).reshape(-1)[0])
    for node in graph.node:
        if node.op_type != "Constant" or node.output[0] != name:
            continue
        for attr in node.attribute:
            if attr.name == "value":
                return int(numpy_helper.to_array(attr.t).reshape(-1)[0])
            if attr.name == "value_int":
                return int(attr.i)
            if attr.name == "value_ints" and attr.ints:
                return int(attr.ints[0])
    return None


def _dft_axis(graph, node, opset):
    """Return the axis a DFT node transforms along, or None if it is not static.

    DFT-17 stores ``axis`` as an attribute (default 1). DFT-20 takes it as an
    optional third input (default -2).
    """
    for attr in node.attribute:
        if attr.name == "axis":
            return attr.i
    if len(node.input) > 2 and node.input[2]:
        return _onnx_constant_int(graph, node.input[2])
    return -2 if opset >= 20 else 1


def _rewrite_inverse_onesided_dft(graph, opset):
    """Replace every DFT(inverse=1, onesided=1) with Hermitian pad + DFT(inverse=1).

    ORT does not support the inverse + onesided combination. The half spectrum
    X[0 .. n//2] is expanded to the full spectrum and fed to
    DFT(inverse=1, onesided=0):

        full = concat(X[0 .. n//2], conj(reverse(X[1 .. n - n//2 - 1])))

    n comes from the node's ``dft_length`` input, so both parities work. For
    even n the Nyquist bin is not mirrored; for odd n (e.g. odd H/8 or W/8
    feature maps) every non-DC bin is mirrored. If ``dft_length`` is absent,
    n is assumed even and bins [1:-1] are mirrored.

    Replacement nodes are inserted where the original node was, so the graph
    stays topologically sorted. The new DFT keeps every input and attribute
    of the old one except ``onesided``.

    Returns:
        Number of DFT nodes rewritten.
    """
    from onnx import TensorProto, helper

    def add_const(name, values, dtype=TensorProto.INT64):
        graph.initializer.append(helper.make_tensor(name, dtype, [len(values)], values))
        return name

    rewritten = 0
    new_nodes = []
    for node in graph.node:
        if node.op_type != "DFT":
            new_nodes.append(node)
            continue
        attrs = {a.name: a.i for a in node.attribute}
        if not (attrs.get("inverse", 0) == 1 and attrs.get("onesided", 0) == 1):
            new_nodes.append(node)
            continue

        uid = f"_irfft_{rewritten}"
        inp = node.input[0]  # half spectrum along `axis`, trailing [re, im] axis
        axis = _dft_axis(graph, node, opset)
        if axis is None:
            print(f"WARNING: DFT axis of {node.name or node.output[0]} is not constant; assuming -2")
            axis = -2
        dft_length = node.input[1] if len(node.input) > 1 else ""

        if dft_length:
            # Mirrored bins are [1, n - half_len]; the exclusive Slice end is
            # n - half_len + 1. It is computed in-graph because n is dynamic.
            shape_all = f"shape{uid}"
            half_len = f"half_len{uid}"
            n_i64 = f"n_i64{uid}"
            n_vec = f"n_vec{uid}"
            n_minus_half = f"n_minus_half{uid}"
            mid_end = f"mid_end{uid}"
            new_nodes += [
                helper.make_node("Shape", [inp], [shape_all]),
                helper.make_node(
                    "Gather", [shape_all, add_const(f"axis_idx{uid}", [axis])],
                    [half_len], axis=0),
                # dft_length may be int32 or int64 per the DFT spec.
                helper.make_node("Cast", [dft_length], [n_i64], to=TensorProto.INT64),
                helper.make_node("Reshape", [n_i64, add_const(f"vec1{uid}", [1])], [n_vec]),
                helper.make_node("Sub", [n_vec, half_len], [n_minus_half]),
                helper.make_node(
                    "Add", [n_minus_half, add_const(f"one{uid}", [1])], [mid_end]),
            ]
        else:
            # No signal length available: assume even n (mirror bins [1:-1]).
            mid_end = add_const(f"ends{uid}", [-1])

        mid = f"mid{uid}"
        new_nodes.append(helper.make_node(
            "Slice",
            [inp, add_const(f"starts{uid}", [1]), mid_end, add_const(f"axes{uid}", [axis])],
            [mid]))
        # Reverse the mirrored bins: Slice with step -1 (the large negative end
        # is clamped to "before index 0").
        flipped = f"flip{uid}"
        new_nodes.append(helper.make_node(
            "Slice",
            [mid,
             add_const(f"flip_s{uid}", [-1]),
             add_const(f"flip_e{uid}", [-2147483648]),
             add_const(f"flip_ax{uid}", [axis]),
             add_const(f"flip_st{uid}", [-1])],
            [flipped]))
        # Conjugate: negate the imaginary component of the trailing [re, im] axis.
        conj = f"conj_out{uid}"
        new_nodes.append(helper.make_node(
            "Mul", [flipped, add_const(f"conj{uid}", [1.0, -1.0], TensorProto.FLOAT)],
            [conj]))
        full = f"full{uid}"
        new_nodes.append(helper.make_node("Concat", [inp, conj], [full], axis=axis))

        new_dft = helper.make_node(
            "DFT", [full, *node.input[1:]], list(node.output),
            name=f"idft{uid}", domain=node.domain)
        # Keep axis/inverse (attribute form) and all extra inputs; only flip onesided.
        new_dft.attribute.extend([a for a in node.attribute if a.name != "onesided"])
        new_dft.attribute.append(helper.make_attribute("onesided", 0))
        new_nodes.append(new_dft)
        rewritten += 1

    if rewritten:
        del graph.node[:]
        graph.node.extend(new_nodes)
    return rewritten


def _rewrite_shape_start_end(graph):
    """Split Shape(start=..., end=...) into a plain Shape followed by Slice.

    CoreML EP's MLProgram path currently rejects Shape with end == rank
    (e.g. start=3, end=4 on rank-4 tensors) with
    "axis 4 not in valid range [-4,3]". Node order is preserved.

    Returns:
        Number of Shape nodes rewritten.
    """
    from onnx import TensorProto, helper

    rewrites = 0
    new_nodes = []
    for node in graph.node:
        attrs = {a.name: a.i for a in node.attribute}
        if node.op_type != "Shape" or ("start" not in attrs and "end" not in attrs):
            new_nodes.append(node)
            continue
        start = attrs.get("start", 0)
        end = attrs.get("end", 9223372036854775807)
        uid = f"_shapefix_{rewrites}"
        full_shape_out = f"{node.output[0]}{uid}_all"
        shape_name = f"{node.name}_all" if node.name else f"shape_all{uid}"
        slice_name = f"{node.name}_slice" if node.name else f"shape_slice{uid}"
        starts = helper.make_tensor(f"starts{uid}", TensorProto.INT64, [1], [start])
        ends = helper.make_tensor(f"ends{uid}", TensorProto.INT64, [1], [end])
        axes = helper.make_tensor(f"axes{uid}", TensorProto.INT64, [1], [0])
        graph.initializer.extend([starts, ends, axes])
        new_nodes.append(
            helper.make_node("Shape", [node.input[0]], [full_shape_out], name=shape_name))
        new_nodes.append(helper.make_node(
            "Slice",
            [full_shape_out, starts.name, ends.name, axes.name],
            list(node.output),
            name=slice_name,
        ))
        rewrites += 1

    if rewrites:
        del graph.node[:]
        graph.node.extend(new_nodes)
    return rewrites


def _print_export_summary(model_onnx, path):
    """Print the op histogram, graph I/O shapes and on-disk size of the saved model."""
    ops = Counter(n.op_type for n in model_onnx.graph.node)
    print(f"\nFinal: {sum(ops.values())} nodes, {len(ops)} unique ops")
    for op, count in ops.most_common():
        print(f" {op}: {count}")
    for t in [*model_onnx.graph.input, *model_onnx.graph.output]:
        dims = [d.dim_param or str(d.dim_value) for d in t.type.tensor_type.shape.dim]
        print(f" {t.name}: [{', '.join(dims)}]")
    size_mb = os.path.getsize(path) / (1024 * 1024)
    print(f"\nSaved: {path} ({size_mb:.1f} MB)")


def main():
    """Build LaMa, load weights, export to ONNX and patch the graph for ORT/CoreML.

    1. Build LaMaModel and load safetensors weights (local path or HF Hub).
    2. Optionally (--verify) run one eager forward pass as a smoke test.
    3. Export with dynamic H/W constrained to multiples of 8 in [64, 4096].
    4. Rewrite inverse one-sided DFTs and Shape(start/end) nodes.
    5. Re-embed weights, save a single .onnx file and print a summary;
       optionally (--verify-coreml) open it on the CoreML EP.
    """
    parser = argparse.ArgumentParser(description="Export LaMa-manga ONNX for CoreML/ANE")
    parser.add_argument("-o", "--output", default="models/lama-manga.onnx")
    parser.add_argument("--safetensors", default=None, help="Local safetensors path")
    parser.add_argument("--verify", action="store_true", help="Run test inference")
    parser.add_argument(
        "--verify-coreml",
        action="store_true",
        help="Try creating an ONNX Runtime CoreML MLProgram session after export",
    )
    args = parser.parse_args()
    sf_path = resolve_safetensors_path(args.safetensors)

    # 1. Build model
    print("Building LaMa model (rank-4-safe DFT)...")
    model = LaMaModel()

    # 2. Load weights
    print(f"Loading weights from {sf_path}")
    model = load_weights(model, sf_path)
    model.eval()

    # 3. Optional: verify inference
    if args.verify:
        with torch.no_grad():
            img = torch.randn(1, 3, 512, 512)
            mask = torch.zeros(1, 1, 512, 512)
            mask[:, :, 100:200, 100:200] = 1.0
            out = model(img, mask)
            print(f"Test output: shape={out.shape}, range=[{out.min():.3f}, {out.max():.3f}]")

    # 4. Export ONNX
    print("Exporting ONNX...")
    # Make sure the output directory exists (the default lives under models/).
    out_dir = os.path.dirname(args.output)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    dummy_img = torch.randn(1, 3, 512, 512)
    dummy_mask = torch.zeros(1, 1, 512, 512)
    # LaMa has 3× stride-2 downsampling; input must be multiple of 8.
    # Rust inpaint.rs already pads to multiple of 8.
    h_base = torch.export.Dim("h_blocks", min=8, max=512)
    w_base = torch.export.Dim("w_blocks", min=8, max=512)
    height = h_base * 8
    width = w_base * 8
    dynamic_shapes = {
        "image": {2: height, 3: width},
        "mask": {2: height, 3: width},
    }
    with torch.no_grad():
        torch.onnx.export(
            model,
            (dummy_img, dummy_mask),
            args.output,
            input_names=["image", "mask"],
            output_names=["output"],
            dynamic_shapes=dynamic_shapes,
        )

    # 5. Post-process the graph for ORT / CoreML.
    import onnx

    model_onnx = onnx.load(args.output, load_external_data=True)
    graph = model_onnx.graph
    opset = max(
        (o.version for o in model_onnx.opset_import if o.domain in ("", "ai.onnx")),
        default=0,
    )
    fixed = _rewrite_inverse_onesided_dft(graph, opset)
    print(f"Rewrote {fixed} DFT(inverse+onesided) → hermitian pad + DFT(inverse)")

    shape_rewrites = _rewrite_shape_start_end(graph)
    if shape_rewrites:
        print(f"Rewrote {shape_rewrites} Shape(start/end) nodes → Shape+Slice")

    # Re-embed weights into single file
    for tensor in graph.initializer:
        tensor.ClearField("data_location")
    onnx.save(model_onnx, args.output, save_as_external_data=False)

    _print_export_summary(model_onnx, args.output)

    if args.verify_coreml:
        verify_coreml_session(args.output)
# Script entry point: run the export only when executed directly, not on import.
if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 20.8 kB
- Xet hash:
- 88117aa924abf714288926d0be7e71ee690546bb72c30d06b0e66ab679cef1cb
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.