"""Extract inference-ready weights from training checkpoint.""" import torch from collections import OrderedDict import os import sys import argparse sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from src.minifasv2.model import MultiFTNet from src.minifasv2.config import get_kernel def extract_model_weights(checkpoint_path, output_path, input_size=128): print(f"Loading checkpoint: {checkpoint_path}") device = "cpu" checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) if "model_state_dict" in checkpoint: state_dict = checkpoint["model_state_dict"] elif "state_dict" in checkpoint: state_dict = checkpoint["state_dict"] else: state_dict = checkpoint clean_state_dict = OrderedDict() for key, value in state_dict.items(): if "FTGenerator" in key: continue new_key = key if new_key.startswith("module."): new_key = new_key[7:] new_key = new_key.replace("model.prob", "model.logits") new_key = new_key.replace(".prob", ".logits") new_key = new_key.replace("model.drop", "model.dropout") new_key = new_key.replace(".drop", ".dropout") clean_state_dict[new_key] = value kernel_size = get_kernel(input_size, input_size) model = MultiFTNet( num_channels=3, num_classes=2, embedding_size=128, conv6_kernel=kernel_size, ) model.load_state_dict(clean_state_dict, strict=False) print(f"Saving clean model to: {output_path}") torch.save( { "model_state_dict": clean_state_dict, "input_size": input_size, "num_classes": 2, "architecture": "MiniFASNetV2SE", }, output_path, ) size_mb = os.path.getsize(output_path) / (1024 * 1024) original_size = os.path.getsize(checkpoint_path) / (1024 * 1024) reduction = (1 - size_mb / original_size) * 100 print(f"[OK] Clean model saved: {size_mb:.2f} MB") print(f" Original size: {original_size:.2f} MB") print(f" Size reduction: {reduction:.1f}%") return output_path if __name__ == "__main__": parser = argparse.ArgumentParser( description="Extract clean model weights from epoch checkpoint" ) parser.add_argument( "epoch_checkpoint", type=str, help="Path to epoch checkpoint", ) parser.add_argument( "--output", type=str, default=None, help="Output path for best model (default: best_model.pth in models/)", ) parser.add_argument( "--input_size", type=int, default=128, help="Input image size (default: 128)" ) args = parser.parse_args() assert os.path.isfile( args.epoch_checkpoint ), f"Checkpoint not found: {args.epoch_checkpoint}" if args.output is None: os.makedirs("models", exist_ok=True) args.output = "models/best_model.pth" extract_model_weights(args.epoch_checkpoint, args.output, args.input_size) print(f"\n[OK] Best model ready: {args.output}")