File size: 1,776 Bytes
eef8873
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import logging
import torch

from src.config import DEVICE, NUM_CLASSES, CHECKPOINT_DIR
from src.models.fusion_model import FusionClassifier

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Checkpoint read by convert_fusion_to_fp16() and the file its FP16 output
# is written to; both live directly under CHECKPOINT_DIR.
INPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model.pt"
OUTPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model_fp16.pt"


def convert_fusion_to_fp16(input_checkpoint=None, output_checkpoint=None):
    """Convert a Fusion checkpoint to half precision and save its weights.

    Builds a ``FusionClassifier``, loads the weights found in the input
    checkpoint, casts the model with ``model.half()`` and writes the
    resulting ``state_dict`` to the output path.

    Args:
        input_checkpoint: Path of the checkpoint to convert. The file may
            hold either a bare state dict or a dict that stores it under
            the ``"model_state_dict"`` key. Defaults to
            ``INPUT_CHECKPOINT``.
        output_checkpoint: Path the FP16 state dict is written to.
            Defaults to ``OUTPUT_CHECKPOINT``.

    Returns:
        The output path that was written (the default constant when no
        ``output_checkpoint`` was given).

    Raises:
        FileNotFoundError: If the input checkpoint does not exist.
    """
    if input_checkpoint is None:
        input_checkpoint = INPUT_CHECKPOINT
    if output_checkpoint is None:
        output_checkpoint = OUTPUT_CHECKPOINT

    logger.info("Initializing Fusion model for FP16 conversion...")

    # os.path accepts both str and path-like objects, so callers are not
    # forced to pass a pathlib.Path.
    if not os.path.exists(input_checkpoint):
        raise FileNotFoundError(
            f"Fusion checkpoint not found: {input_checkpoint}"
        )

    model = FusionClassifier(num_classes=NUM_CLASSES).to(DEVICE)

    logger.info("Loading checkpoint from: %s", input_checkpoint)

    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be converted. If the checkpoint is known to hold
    # only tensors and plain containers, consider weights_only=True.
    checkpoint = torch.load(input_checkpoint, map_location=DEVICE)

    # Training checkpoints wrap the weights in a dict; plain exports are
    # the state dict itself.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        state_dict = checkpoint["model_state_dict"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)

    logger.info("Model weights loaded successfully.")

    model.eval()

    logger.info("Converting model to FP16...")

    # Save from CPU so the exported file is not tagged with DEVICE's
    # storage location and can be loaded on hosts without that device.
    model = model.half().cpu()

    torch.save(model.state_dict(), output_checkpoint)

    size_mb = os.path.getsize(output_checkpoint) / (1024 * 1024)

    logger.info("FP16 model saved at: %s", output_checkpoint)
    logger.info("FP16 model size: %.2f MB", size_mb)

    return output_checkpoint


if __name__ == "__main__":
    # Script entry point: show INFO-level progress with timestamps, run the
    # conversion, then report where the FP16 weights were written.
    log_format = "%(asctime)s - %(levelname)s - %(message)s"
    logging.basicConfig(level=logging.INFO, format=log_format)

    saved_checkpoint = convert_fusion_to_fp16()

    print("\nFusion FP16 conversion completed successfully.")
    print(f"Saved model: {saved_checkpoint}")