Spaces:
Sleeping
Sleeping
File size: 1,776 Bytes
eef8873 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 | import os
import logging
import torch
from src.config import DEVICE, NUM_CLASSES, CHECKPOINT_DIR
from src.models.fusion_model import FusionClassifier
# Checkpoint locations: the source weights to convert, and the
# half-precision copy that is written into the same directory.
INPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model.pt"
OUTPUT_CHECKPOINT = CHECKPOINT_DIR / "best_fusion_model_fp16.pt"

# Module-scoped logger; handler/level configuration is left to whoever runs
# this module (the __main__ guard calls logging.basicConfig).
logger = logging.getLogger(name=__name__)
def convert_fusion_to_fp16(input_path=None, output_path=None):
    """Convert the Fusion model checkpoint to half precision (FP16).

    Builds a ``FusionClassifier``, loads the weights stored at
    ``input_path`` and casts its floating-point parameters and buffers to
    ``torch.float16``. It then writes the resulting ``state_dict`` to
    ``output_path``.

    The saved tensors are moved to the CPU before saving. The FP16
    checkpoint can therefore be loaded on a machine that lacks the
    conversion device (e.g. a CPU-only host loading a file converted on
    CUDA) without passing ``map_location``.

    Args:
        input_path: Checkpoint to convert (str or path-like). Defaults to
            ``INPUT_CHECKPOINT``. The file may contain either a raw
            ``state_dict`` or a dict holding it under ``"model_state_dict"``.
        output_path: Destination for the FP16 ``state_dict`` (str or
            path-like). Defaults to ``OUTPUT_CHECKPOINT``.

    Returns:
        The path the FP16 checkpoint was written to (``output_path``).

    Raises:
        FileNotFoundError: If ``input_path`` does not exist.
    """
    # Resolve defaults at call time so the module-level constants remain the
    # single source of truth (and patching them still takes effect).
    if input_path is None:
        input_path = INPUT_CHECKPOINT
    if output_path is None:
        output_path = OUTPUT_CHECKPOINT

    logger.info("Initializing Fusion model for FP16 conversion...")

    # os.path.exists accepts both plain strings and pathlib paths.
    if not os.path.exists(input_path):
        raise FileNotFoundError(f"Fusion checkpoint not found: {input_path}")

    model = FusionClassifier(num_classes=NUM_CLASSES).to(DEVICE)

    logger.info("Loading checkpoint from: %s", input_path)
    # NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
    checkpoint = torch.load(input_path, map_location=DEVICE)

    # Dict-style checkpoints keep the weights under "model_state_dict" (any
    # other keys are ignored here); otherwise the file is the state_dict.
    if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        model.load_state_dict(checkpoint)
    logger.info("Model weights loaded successfully.")

    model.eval()

    logger.info("Converting model to FP16...")
    # half() casts only floating-point parameters/buffers. Moving to CPU
    # afterwards keeps device-specific (e.g. CUDA) storage out of the file.
    model = model.half().cpu()

    torch.save(model.state_dict(), output_path)

    size_mb = os.path.getsize(output_path) / (1024 * 1024)
    logger.info("FP16 model saved at: %s", output_path)
    logger.info("FP16 model size: %.2f MB", size_mb)
    return output_path
if __name__ == "__main__":
    # Route the converter's progress messages to the console with timestamps.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        level=logging.INFO,
    )

    saved_path = convert_fusion_to_fp16()

    # Blank line, completion notice, then the output location.
    print(
        "\nFusion FP16 conversion completed successfully.",
        f"Saved model: {saved_path}",
        sep="\n",
    )