# File size: 804 Bytes
# 36c78b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import gzip
import torch
import torch.nn as nn


def save_model(model: torch.nn.Module, path: str) -> None:
    """Save a whole model object to ``path`` as a gzip-compressed ``torch.save`` archive.

    The parent directory of ``path`` is created if it does not exist yet. The
    entire module object is pickled (not just its ``state_dict``), so loading
    it back requires the model's class to be importable; use ``load_model``.

    Args:
        model: Module to serialize.
        path: Destination file path. A bare filename (no directory part) is
            written relative to the current working directory.
    """
    parent = os.path.dirname(path)
    # ``os.path.dirname`` returns "" for a bare filename, and
    # ``os.makedirs("")`` raises FileNotFoundError, so only create a
    # directory when the path actually has one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    with gzip.open(path, "wb") as f:
        torch.save(model, f)


def load_model(path: str, *, map_location="cpu") -> torch.nn.Module:
    """Load a model saved with ``save_model``.

    Args:
        path: Path to a gzip-compressed ``torch.save`` archive.
        map_location: Forwarded to ``torch.load``; controls where tensors are
            placed. Defaults to ``"cpu"`` (the previous fixed behavior). Any
            value ``torch.load`` accepts (device string, ``torch.device``,
            dict or callable) may be used.

    Returns:
        The unpickled module object.

    Warning:
        ``weights_only=False`` is required because the whole module object
        is pickled, and it means unpickling can execute arbitrary code.
        Only load files from trusted sources.
    """
    with gzip.open(path, "rb") as f:
        # SECURITY: full pickle load; never call this on untrusted files.
        return torch.load(f, map_location=map_location, weights_only=False)


def set_dropout(model: torch.nn.Module, p: float) -> None:
    """Set the drop probability ``p`` on every dropout layer in ``model``, in place.

    Besides ``nn.Dropout`` this covers ``nn.Dropout1d``, ``nn.Dropout2d``,
    ``nn.Dropout3d``, ``nn.AlphaDropout`` and ``nn.FeatureAlphaDropout``.
    These are sibling classes, not subclasses of ``nn.Dropout``, so an
    ``isinstance(m, nn.Dropout)`` check alone would miss them. Dropout applied
    functionally inside a ``forward`` (e.g. ``F.dropout``) is not a module
    and is not affected.

    Args:
        model: Module whose submodules (including itself) are updated.
        p: New drop probability; must satisfy ``0 <= p <= 1``.

    Raises:
        ValueError: If ``p`` is outside ``[0, 1]`` (or is NaN).
    """
    # Validate up front, so the error surfaces here rather than at the next forward.
    if not 0.0 <= p <= 1.0:
        raise ValueError(
            f"dropout probability has to be between 0 and 1, but got {p}"
        )
    # Resolve classes by name so torch versions lacking some of them
    # (e.g. ``Dropout1d``) still work.
    names = (
        "Dropout",
        "Dropout1d",
        "Dropout2d",
        "Dropout3d",
        "AlphaDropout",
        "FeatureAlphaDropout",
    )
    dropout_types = tuple(getattr(nn, n) for n in names if hasattr(nn, n))
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.p = p


# Public API of this module.
__all__ = [
    "save_model",
    "load_model",
    "set_dropout",
]