File size: 3,124 Bytes
2979822
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Extract inference-ready weights from training checkpoint."""

import torch
from collections import OrderedDict
import os
import sys
import argparse

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

from src.minifasv2.model import MultiFTNet
from src.minifasv2.config import get_kernel


def extract_model_weights(checkpoint_path, output_path, input_size=128):
    """Extract inference-ready weights from a training checkpoint.

    Strips the training-only FTGenerator branch, removes DataParallel
    ``module.`` prefixes, and renames legacy layer names
    (``prob`` -> ``logits``, ``drop`` -> ``dropout``) so the cleaned
    state dict loads into the inference-time MultiFTNet.

    Args:
        checkpoint_path: Path to the training checkpoint (.pth).
        output_path: Where to write the cleaned checkpoint.
        input_size: Square input resolution, used to derive the conv6 kernel.

    Returns:
        ``output_path`` after the cleaned checkpoint has been written.
    """
    print(f"Loading checkpoint: {checkpoint_path}")

    # NOTE(review): weights_only=False unpickles arbitrary objects and can
    # execute code — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    # Training scripts store weights under different keys; fall back to
    # treating the whole checkpoint object as the state dict itself.
    if "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    elif "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        state_dict = checkpoint

    clean_state_dict = OrderedDict()
    for key, value in state_dict.items():
        # The FTGenerator auxiliary head is only used during training.
        if "FTGenerator" in key:
            continue

        new_key = key
        # Drop the DataParallel/DistributedDataParallel wrapper prefix.
        if new_key.startswith("module."):
            new_key = new_key[len("module."):]
        # Map legacy layer names onto the current model definition.
        new_key = new_key.replace("model.prob", "model.logits")
        new_key = new_key.replace(".prob", ".logits")
        new_key = new_key.replace("model.drop", "model.dropout")
        new_key = new_key.replace(".drop", ".dropout")

        clean_state_dict[new_key] = value

    # Sanity-check that the cleaned dict loads into the inference model.
    kernel_size = get_kernel(input_size, input_size)
    model = MultiFTNet(
        num_channels=3,
        num_classes=2,
        embedding_size=128,
        conv6_kernel=kernel_size,
    )

    # strict=False previously swallowed mismatches silently; report them so
    # a bad rename map is caught at extraction time rather than at inference.
    load_result = model.load_state_dict(clean_state_dict, strict=False)
    if load_result.missing_keys:
        print(f"[WARN] Missing keys: {load_result.missing_keys}")
    if load_result.unexpected_keys:
        print(f"[WARN] Unexpected keys: {load_result.unexpected_keys}")

    print(f"Saving clean model to: {output_path}")
    torch.save(
        {
            "model_state_dict": clean_state_dict,
            "input_size": input_size,
            "num_classes": 2,
            "architecture": "MiniFASNetV2SE",
        },
        output_path,
    )

    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    original_size = os.path.getsize(checkpoint_path) / (1024 * 1024)
    # Guard against a zero-byte source checkpoint (ZeroDivisionError).
    reduction = (1 - size_mb / original_size) * 100 if original_size else 0.0

    print(f"[OK] Clean model saved: {size_mb:.2f} MB")
    print(f"     Original size: {original_size:.2f} MB")
    print(f"     Size reduction: {reduction:.1f}%")

    return output_path


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract clean model weights from epoch checkpoint"
    )
    parser.add_argument(
        "epoch_checkpoint",
        type=str,
        help="Path to epoch checkpoint",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output path for best model (default: best_model.pth in models/)",
    )
    parser.add_argument(
        "--input_size", type=int, default=128, help="Input image size (default: 128)"
    )

    args = parser.parse_args()

    assert os.path.isfile(
        args.epoch_checkpoint
    ), f"Checkpoint not found: {args.epoch_checkpoint}"

    if args.output is None:
        os.makedirs("models", exist_ok=True)
        args.output = "models/best_model.pth"

    extract_model_weights(args.epoch_checkpoint, args.output, args.input_size)
    print(f"\n[OK] Best model ready: {args.output}")