# aigv/core/utils1/trainer.py
import os
import torch
import torch.nn as nn
from torch.nn import init
from utils1.config import CONFIGCLASS
from utils1.utils import get_network
from utils1.warmup import GradualWarmupScheduler


class BaseModel(nn.Module):
    def __init__(self, cfg: CONFIGCLASS):
        super().__init__()
        self.cfg = cfg
        self.total_steps = 0
        self.isTrain = cfg.isTrain
        self.save_dir = cfg.ckpt_dir
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        # Declared here for type checkers only; concrete subclasses (e.g. Trainer)
        # build the network and optimizer and move the model to self.device.
        self.model: nn.Module
        self.optimizer: torch.optim.Optimizer

    def save_networks(self, epoch: int):
        """Serialize model, optimizer, and step counter to <ckpt_dir>/model_epoch_<epoch>.pth."""
        save_filename = f"model_epoch_{epoch}.pth"
        save_path = os.path.join(self.save_dir, save_filename)
        state_dict = {
            "model": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "total_steps": self.total_steps,
        }
        torch.save(state_dict, save_path)

    def load_networks(self, epoch: int):
        """Load model/optimizer state from disk; epoch 0 restores the bundled optical checkpoint instead."""
        load_filename = f"model_epoch_{epoch}.pth"
        load_path = os.path.join(self.save_dir, load_filename)
        if epoch == 0:
            load_path = "checkpoints/optical.pth"
            print("loading optical checkpoint")
        else:
            print(f"loading the model from {load_path}")
        state_dict = torch.load(load_path, map_location=self.device)
        if hasattr(state_dict, "_metadata"):
            del state_dict._metadata
        self.model.load_state_dict(state_dict["model"])
        self.total_steps = state_dict["total_steps"]
        if self.isTrain and not self.cfg.new_optim:
            self.optimizer.load_state_dict(state_dict["optimizer"])
            # Move optimizer state tensors onto the active device.
            for state in self.optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
            # Reset the learning rate to the configured value.
            for g in self.optimizer.param_groups:
                g["lr"] = self.cfg.lr

    def eval(self):
        # Overrides nn.Module.eval(); puts only the wrapped model in eval mode.
        self.model.eval()

    def test(self):
        with torch.no_grad():
            self.forward()


def init_weights(net: nn.Module, init_type="normal", gain=0.02):
    def init_func(m: nn.Module):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(f"initialization method [{init_type}] is not implemented")
            if hasattr(m, "bias") and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find("BatchNorm2d") != -1:
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)

    print(f"initialize network with {init_type}")
    net.apply(init_func)
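
# Usage sketch (the module below is hypothetical, not part of this repo):
#   net = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
#                       nn.Flatten(), nn.Linear(16 * 30 * 30, 1))
#   init_weights(net, init_type="kaiming")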


class Trainer(BaseModel):
    def name(self):
        return "Trainer"

    def __init__(self, cfg: CONFIGCLASS):
        super().__init__(cfg)
        self.arch = cfg.arch
        self.model = get_network(self.arch, cfg.isTrain, cfg.continue_train, cfg.init_gain, cfg.pretrained)
        self.loss_fn = nn.BCEWithLogitsLoss()
        # initialize the optimizer
        if cfg.optim == "adam":
            self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
        elif cfg.optim == "sgd":
            self.optimizer = torch.optim.SGD(self.model.parameters(), lr=cfg.lr, momentum=0.9, weight_decay=5e-4)
        else:
            raise ValueError("optim should be [adam, sgd]")
        if cfg.warmup:
            # Linear warmup for warmup_epoch epochs, then cosine annealing down to eta_min.
            scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer, cfg.nepoch - cfg.warmup_epoch, eta_min=1e-6
            )
            self.scheduler = GradualWarmupScheduler(
                self.optimizer, multiplier=1, total_epoch=cfg.warmup_epoch, after_scheduler=scheduler_cosine
            )
            self.scheduler.step()
        if cfg.continue_train:
            self.load_networks(cfg.epoch)
        self.model.to(self.device)

    def adjust_learning_rate(self, min_lr=1e-6):
        """Divide the learning rate by 10; return False once it drops below min_lr."""
        for param_group in self.optimizer.param_groups:
            param_group["lr"] /= 10.0
            if param_group["lr"] < min_lr:
                return False
        return True
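
    # Usage sketch (hypothetical plateau-based decay; `val_acc_plateaued` is
    # assumed to come from the caller's validation loop):
    #   if val_acc_plateaued and not trainer.adjust_learning_rate():
    #       break  # lr fell below min_lr; stop training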

    def set_input(self, input):
        # Batches are (img, label) or (img, label, meta); meta tensors follow the model to the device.
        img, label, meta = input if len(input) == 3 else (input[0], input[1], {})
        self.input = img.to(self.device)
        self.label = label.to(self.device).float()
        for k in meta:
            if isinstance(meta[k], torch.Tensor):
                meta[k] = meta[k].to(self.device)
        self.meta = meta

    def forward(self):
        self.output = self.model(self.input, self.meta)

    def get_loss(self):
        return self.loss_fn(self.output.squeeze(1), self.label)

    def optimize_parameters(self):
        self.forward()
        self.loss = self.get_loss()
        self.optimizer.zero_grad()
        self.loss.backward()
        self.optimizer.step()
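

# Minimal end-to-end sketch (assumptions: `cfg` is a CONFIGCLASS instance with
# the fields referenced above -- arch, isTrain, lr, optim, beta1, warmup,
# nepoch, epoch, continue_train, new_optim, init_gain, pretrained, ckpt_dir --
# and `loader` yields (img, label) batches; neither is defined in this file):
#
#   trainer = Trainer(cfg)
#   for epoch in range(cfg.nepoch):
#       trainer.model.train()
#       for batch in loader:
#           trainer.total_steps += 1
#           trainer.set_input(batch)
#           trainer.optimize_parameters()
#       trainer.save_networks(epoch)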