|
|
import os |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import init |
|
|
|
|
|
from utils1.config import CONFIGCLASS |
|
|
from utils1.utils import get_network |
|
|
from utils1.warmup import GradualWarmupScheduler |
|
|
|
|
|
|
|
|
class BaseModel(nn.Module): |
|
|
def __init__(self, cfg: CONFIGCLASS): |
|
|
super().__init__() |
|
|
self.cfg = cfg |
|
|
self.total_steps = 0 |
|
|
self.isTrain = cfg.isTrain |
|
|
self.save_dir = cfg.ckpt_dir |
|
|
self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") |
|
|
self.model:nn.Module |
|
|
self.model=nn.Module.to(self.device) |
|
|
|
|
|
|
|
|
self.optimizer: torch.optim.Optimizer |
|
|
|
|
|
def save_networks(self, epoch: int): |
|
|
save_filename = f"model_epoch_{epoch}.pth" |
|
|
save_path = os.path.join(self.save_dir, save_filename) |
|
|
|
|
|
|
|
|
state_dict = { |
|
|
"model": self.model.state_dict(), |
|
|
"optimizer": self.optimizer.state_dict(), |
|
|
"total_steps": self.total_steps, |
|
|
} |
|
|
|
|
|
torch.save(state_dict, save_path) |
|
|
|
|
|
|
|
|
def load_networks(self, epoch: int): |
|
|
load_filename = f"model_epoch_{epoch}.pth" |
|
|
load_path = os.path.join(self.save_dir, load_filename) |
|
|
|
|
|
if epoch==0: |
|
|
|
|
|
load_path="checkpoints/optical.pth" |
|
|
print("loading optical path") |
|
|
else : |
|
|
print(f"loading the model from {load_path}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
state_dict = torch.load(load_path, map_location=self.device) |
|
|
if hasattr(state_dict, "_metadata"): |
|
|
del state_dict._metadata |
|
|
|
|
|
self.model.load_state_dict(state_dict["model"]) |
|
|
self.total_steps = state_dict["total_steps"] |
|
|
|
|
|
if self.isTrain and not self.cfg.new_optim: |
|
|
self.optimizer.load_state_dict(state_dict["optimizer"]) |
|
|
|
|
|
for state in self.optimizer.state.values(): |
|
|
for k, v in state.items(): |
|
|
if torch.is_tensor(v): |
|
|
state[k] = v.to(self.device) |
|
|
|
|
|
for g in self.optimizer.param_groups: |
|
|
g["lr"] = self.cfg.lr |
|
|
|
|
|
def eval(self): |
|
|
self.model.eval() |
|
|
|
|
|
def test(self): |
|
|
with torch.no_grad(): |
|
|
self.forward() |
|
|
|
|
|
|
|
|
def init_weights(net: nn.Module, init_type="normal", gain=0.02): |
|
|
def init_func(m: nn.Module): |
|
|
classname = m.__class__.__name__ |
|
|
if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): |
|
|
if init_type == "normal": |
|
|
init.normal_(m.weight.data, 0.0, gain) |
|
|
elif init_type == "xavier": |
|
|
init.xavier_normal_(m.weight.data, gain=gain) |
|
|
elif init_type == "kaiming": |
|
|
init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") |
|
|
elif init_type == "orthogonal": |
|
|
init.orthogonal_(m.weight.data, gain=gain) |
|
|
else: |
|
|
raise NotImplementedError(f"initialization method [{init_type}] is not implemented") |
|
|
if hasattr(m, "bias") and m.bias is not None: |
|
|
init.constant_(m.bias.data, 0.0) |
|
|
elif classname.find("BatchNorm2d") != -1: |
|
|
init.normal_(m.weight.data, 1.0, gain) |
|
|
init.constant_(m.bias.data, 0.0) |
|
|
|
|
|
print(f"initialize network with {init_type}") |
|
|
net.apply(init_func) |
|
|
|
|
|
|
|
|
class Trainer(BaseModel): |
|
|
def name(self): |
|
|
return "Trainer" |
|
|
|
|
|
def __init__(self, cfg: CONFIGCLASS): |
|
|
super().__init__(cfg) |
|
|
self.arch = cfg.arch |
|
|
self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained) |
|
|
|
|
|
self.loss_fn = nn.BCEWithLogitsLoss() |
|
|
|
|
|
if cfg.optim == "adam": |
|
|
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999)) |
|
|
elif cfg.optim == "sgd": |
|
|
self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4) |
|
|
else: |
|
|
raise ValueError("optim should be [adam, sgd]") |
|
|
if cfg.warmup: |
|
|
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR( |
|
|
self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6 |
|
|
) |
|
|
self.scheduler = GradualWarmupScheduler( |
|
|
self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine |
|
|
) |
|
|
self.scheduler.step() |
|
|
if cfg.continue_train: |
|
|
self.load_networks(cfg.epoch) |
|
|
self.model.to(self.device) |
|
|
|
|
|
|
|
|
|
|
|
def adjust_learning_rate(self, min_lr=1e-6): |
|
|
for param_group in self.optimizer.param_groups: |
|
|
param_group["lr"] /= 10.0 |
|
|
if param_group["lr"] < min_lr: |
|
|
return False |
|
|
return True |
|
|
|
|
|
def set_input(self, input): |
|
|
img, label, meta = input if len(input) == 3 else (input[0], input[1], {}) |
|
|
self.input = img.to(self.device) |
|
|
self.label = label.to(self.device).float() |
|
|
for k in meta.keys(): |
|
|
if isinstance(meta[k], torch.Tensor): |
|
|
meta[k] = meta[k].to(self.device) |
|
|
self.meta = meta |
|
|
|
|
|
def forward(self): |
|
|
self.output = self.model(self.input, self.meta) |
|
|
|
|
|
def get_loss(self): |
|
|
return self.loss_fn(self.output.squeeze(1), self.label) |
|
|
|
|
|
def optimize_parameters(self): |
|
|
self.forward() |
|
|
self.loss = self.loss_fn(self.output.squeeze(1), self.label) |
|
|
self.optimizer.zero_grad() |
|
|
self.loss.backward() |
|
|
self.optimizer.step() |
|
|
|