|
|
|
|
|
""" |
|
|
Universal Model Converter for DeepDream-MLX. |
|
|
Converts PyTorch (.pth) and Torch7 (.t7) models to MLX (.npz). |
|
|
Also supports auto-downloading standard Places365 models. |
|
|
Defaults to float16 for optimal performance on Apple Silicon. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import argparse |
|
|
import glob |
|
|
import numpy as np |
|
|
import torch |
|
|
import torchvision.models as models |
|
|
from torch.hub import download_url_to_file |
|
|
|
|
|
|
|
|
# 'torchfile' is an optional dependency, needed only for Torch7 (.t7)
# conversion. When it is missing, convert_torch7 prints a warning and
# skips the file instead of crashing the whole scan.
try:
    import torchfile
except ImportError:
    torchfile = None
|
|
|
|
|
|
|
|
# Official MIT Places365 pretrained checkpoint URLs, keyed by the
# architecture names accepted by --download.
PLACES365_URLS = {
    "alexnet": "http://places2.csail.mit.edu/models_places365/alexnet_places365.pth.tar",
    "resnet50": "http://places2.csail.mit.edu/models_places365/resnet50_places365.pth.tar",
    "vgg16": "http://places2.csail.mit.edu/models_places365/vgg16_places365.pth.tar",
    "googlenet": "http://places2.csail.mit.edu/models_places365/googlenet_places365.pth.tar"
}
|
|
|
|
|
|
|
|
|
|
|
def convert_tensor(tensor, target_dtype=np.float16):
    """Convert a torch tensor, numpy array, or array-like to numpy with *target_dtype*.

    Defaults to float16 to match the converter's Apple Silicon target.
    """
    if isinstance(tensor, torch.Tensor):
        # Detach from the autograd graph and move to host memory before
        # handing the data to numpy.
        return tensor.cpu().detach().numpy().astype(target_dtype)
    if isinstance(tensor, np.ndarray):
        return tensor.astype(target_dtype)
    # Generic fallback for lists, scalars, and other array-likes.
    return np.array(tensor).astype(target_dtype)
|
|
|
|
|
def clean_state_dict(state_dict):
    """Return a new dict of converted tensors with DataParallel's
    'module.' prefix stripped from every key."""
    return {
        key.replace("module.", ""): convert_tensor(value)
        for key, value in state_dict.items()
    }
|
|
|
|
|
def get_places365_model_skeleton(arch):
    """Build an untrained torchvision model shaped for Places365 (365 classes).

    Raises:
        ValueError: if *arch* is not a supported architecture.
    """
    # Lambdas defer construction so unknown archs never touch torchvision.
    builders = {
        "alexnet": lambda: models.alexnet(num_classes=365),
        "resnet50": lambda: models.resnet50(num_classes=365),
        "vgg16": lambda: models.vgg16(num_classes=365),
        "googlenet": lambda: models.googlenet(num_classes=365, aux_logits=False),
    }
    try:
        builder = builders[arch]
    except KeyError:
        raise ValueError(f"Unknown architecture: {arch}") from None
    return builder()
|
|
|
|
|
|
|
|
|
|
|
def convert_torch7(filepath, target_dir):
    """Convert a Torch7 (.t7) checkpoint to an MLX-ready .npz in *target_dir*.

    Requires the optional 'torchfile' package; warns and returns early
    when it is unavailable. Tensors are named by their dotted position
    in the module tree (e.g. '0.1.weight'). Errors are printed, not
    raised, so a batch scan can continue.
    """
    if torchfile is None:
        print(f"⚠️ Skipping {filepath}: 'torchfile' not installed. Run `pip install torchfile`.")
        return

    print(f"Processing Torch7 file: {filepath}")
    try:
        root = torchfile.load(filepath)
        tensors = {}

        def walk(node, prefix=""):
            # Record this node's own parameters, if present.
            if hasattr(node, 'weight') and node.weight is not None:
                tensors[f"{prefix}.weight"] = convert_tensor(node.weight)
            if hasattr(node, 'bias') and node.bias is not None:
                tensors[f"{prefix}.bias"] = convert_tensor(node.bias)
            # Recurse into child modules, extending the dotted prefix.
            if hasattr(node, 'modules') and node.modules:
                for index, child in enumerate(node.modules):
                    walk(child, f"{prefix}.{index}" if prefix else f"{index}")

        walk(root)

        if not tensors:
            print(f"❌ No weights found in {filepath}.")
            return

        stem = os.path.splitext(os.path.basename(filepath))[0]
        out_path = os.path.join(target_dir, f"{stem}_t7_mlx.npz")
        np.savez(out_path, **tensors)
        print(f"✅ Saved {out_path} ({len(tensors)} tensors)")

    except Exception as e:
        print(f"❌ Failed to convert {filepath}: {e}")
|
|
|
|
|
def convert_pytorch(filepath, target_dir):
    """Convert a PyTorch checkpoint (.pth/.pt/.tar/.pkl) to an MLX-ready .npz.

    Handles three checkpoint layouts: a dict wrapping a 'state_dict'
    entry (the Places365 release format), a raw state dict, and a fully
    pickled module exposing .state_dict(). Errors are printed rather
    than raised so a batch scan can continue past one bad file.
    """
    print(f"Processing PyTorch file: {filepath}")
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only
        # convert checkpoints from sources you trust.
        checkpoint = torch.load(filepath, map_location="cpu")

        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        elif isinstance(checkpoint, dict):
            state_dict = checkpoint
        elif hasattr(checkpoint, 'state_dict'):
            # A whole pickled module: pull its parameter dict directly.
            state_dict = checkpoint.state_dict()
        else:
            print(f"❌ Unknown checkpoint format in {filepath}")
            return

        clean_dict = clean_state_dict(state_dict)

        name_base = os.path.splitext(os.path.basename(filepath))[0]
        # '.pth.tar' files leave a '.pth' residue after the first strip,
        # so strip a second time when needed.
        if name_base.endswith(".pth"):
            name_base = os.path.splitext(name_base)[0]

        out_path = os.path.join(target_dir, f"{name_base}_mlx.npz")
        np.savez(out_path, **clean_dict)

        size_mb = os.path.getsize(out_path) / (1024 * 1024)
        print(f"✅ Saved {out_path} ({size_mb:.1f} MB)")

    except Exception as e:
        print(f"❌ Failed to convert {filepath}: {e}")
|
|
|
|
|
def download_and_convert_places365(arch, download_dir, target_dir):
    """Download the official Places365 checkpoint for *arch*, load it into
    a torchvision skeleton, and save the cleaned weights as .npz.

    The download is cached in *download_dir* and reused on later runs.
    Errors are printed rather than raised.
    """
    url = PLACES365_URLS.get(arch)
    if not url:
        print(f"No URL for {arch}")
        return

    filename = os.path.join(download_dir, os.path.basename(url))

    if not os.path.exists(filename):
        print(f"Downloading {arch} from {url}...")
        try:
            download_url_to_file(url, filename)
        except Exception as e:
            print(f"Download failed: {e}")
            return
    else:
        # BUGFIX: previously printed the literal placeholder "(unknown)"
        # instead of the cached file path.
        print(f"Found cached {filename}")

    print(f"Loading {arch} into PyTorch structure...")
    try:
        model = get_places365_model_skeleton(arch)
        checkpoint = torch.load(filename, map_location="cpu")
        # The official .pth.tar releases wrap weights under 'state_dict';
        # the isinstance guard avoids a TypeError on non-dict checkpoints.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            state_dict = checkpoint['state_dict']
        else:
            state_dict = checkpoint

        # Strip the 'module.' prefix left by DataParallel training.
        new_state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        try:
            model.load_state_dict(new_state_dict, strict=True)
        except RuntimeError:
            # Some releases carry extra/missing keys; fall back to a
            # best-effort load rather than failing outright.
            model.load_state_dict(new_state_dict, strict=False)

        model.eval()
        final_dict = clean_state_dict(model.state_dict())
        out_path = os.path.join(target_dir, f"{arch}_places365_mlx.npz")
        np.savez(out_path, **final_dict)
        print(f"✅ Saved {out_path}")

    except Exception as e:
        print(f"Failed to process {arch}: {e}")
|
|
|
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: optionally download Places365 models, then convert
    every supported checkpoint found in the scan directory."""
    parser = argparse.ArgumentParser(description="DeepDream-MLX Model Converter")
    parser.add_argument("--scan", default="toConvert", help="Directory to scan for local files")
    parser.add_argument("--download", choices=["alexnet", "resnet50", "vgg16", "googlenet", "all"],
                        help="Download and convert specific Places365 models")
    parser.add_argument("--dest", default=".", help="Output directory for .npz files")
    args = parser.parse_args()

    # exist_ok avoids the check-then-create race and a crash when the
    # directory already exists.
    os.makedirs(args.dest, exist_ok=True)

    if args.download:
        os.makedirs(args.scan, exist_ok=True)
        targets = ["alexnet", "resnet50", "vgg16", "googlenet"] if args.download == "all" else [args.download]
        for t in targets:
            download_and_convert_places365(t, args.scan, args.dest)

    if os.path.exists(args.scan):
        print(f"\nScanning '{args.scan}' for local models...")
        # Sort for a deterministic processing order across platforms.
        for f in sorted(glob.glob(os.path.join(args.scan, "*"))):
            if os.path.isdir(f):
                continue
            ext = os.path.splitext(f)[1].lower()
            if ext == ".t7":
                convert_torch7(f, args.dest)
            elif ext in [".pth", ".pt", ".tar", ".pkl"]:
                convert_pytorch(f, args.dest)
            elif ext in [".caffemodel"]:
                print(f"⚠️ Skipping Caffe model {os.path.basename(f)} (Not supported)")
|
|
|
# Run the CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|