# SLIM-Brain / utils/utils.py
# Page-header residue kept for provenance: uploader "OneMore1",
# commit message "Upload 12 files", commit 538668e (verified).
import torch
import datetime
import time
import torch.distributed as dist
import yaml
import os
class MetricLogger:
    """Collect named metrics and print periodic progress lines for a loop.

    Every metric name is backed by a ``SmoothedValue`` that is created the
    first time the name is passed to :meth:`update`.

    Args:
        delimiter: String placed between the fields of a progress line and
            between the meters in ``str(logger)``.
    """

    def __init__(self, delimiter="\t"):
        self.meters = {}
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one new value per keyword argument.

        ``torch.Tensor`` values are converted with ``.item()`` before being
        pushed into the meter of the same name.
        """
        for name, value in kwargs.items():
            if isinstance(value, torch.Tensor):
                value = value.item()
            if name not in self.meters:
                self.meters[name] = SmoothedValue()
            self.meters[name].update(value)

    def __str__(self):
        """Render all meters as ``name: meter`` joined by the delimiter."""
        return self.delimiter.join(
            f"{name}: {meter}" for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Ask every meter to synchronize its statistics across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    @staticmethod
    def _is_main_process():
        """Return True on rank 0, or when no process group is initialised."""
        if dist.is_available() and dist.is_initialized():
            return dist.get_rank() == 0
        return True

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress lines.

        A line is printed every ``print_freq`` iterations and on the last
        iteration. It contains the position, an ETA, all meters, the average
        iteration time and the average data-loading time.

        Args:
            iterable: Sized iterable (``len()`` must work) to iterate over.
            print_freq: Number of iterations between progress lines.
            header: Optional prefix for every printed line.

        Yields:
            The items of ``iterable``, unchanged.
        """
        if not header:
            header = ''
        total = len(iterable)
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the iteration counter to the width of the total count.
        space_fmt = ':' + str(len(str(total))) + 'd'
        log_msg = self.delimiter.join([
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}',
        ])
        for i, obj in enumerate(iterable):
            # Time spent waiting for the next item (data loading).
            data_time.update(time.time() - end)
            yield obj
            # Full iteration time: data loading plus the consumer's work.
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                # Previously this required CUDA and called dist.get_rank()
                # unconditionally, so CPU runs printed nothing and
                # non-distributed runs could fail on the rank lookup.
                if self._is_main_process():
                    print(log_msg.format(
                        i, total, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1) avoids ZeroDivisionError for an empty iterable.
        # The summary is printed on every rank, as before.
        print(f'{header} Total time: {total_time_str} ({total_time / max(total, 1):.4f} s / it)')
class SmoothedValue:
    """Track a series of values and expose windowed and global statistics.

    The most recent ``window_size`` values are kept in ``self.deque`` (a
    plain list) for the median / windowed average / max / last value, while
    ``total`` and ``count`` accumulate over every update for the global
    average.

    Args:
        window_size: Maximum number of recent values retained.
        fmt: Format string used by ``str()``; may reference ``median``,
            ``avg``, ``global_avg``, ``max`` and ``value``. Defaults to
            ``"{median:.4f} ({global_avg:.4f})"``.
    """

    def __init__(self, window_size=20, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = []
        self.total = 0.0
        self.count = 0
        self.fmt = fmt
        self.window_size = window_size

    def update(self, value, n=1):
        """Add ``value`` to the window and to the running totals.

        The value enters the window once, but contributes ``value * n`` to
        ``total`` and ``n`` to ``count``.
        """
        self.deque.append(value)
        if len(self.deque) > self.window_size:
            # Drop the oldest entry to keep the window bounded.
            self.deque.pop(0)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across all processes.

        Does nothing when torch.distributed is unavailable or no process
        group is initialised. Only ``count`` and ``total`` are reduced; the
        window contents are left as they are.
        """
        if not dist.is_available() or not dist.is_initialized():
            return
        # Do not hard-code 'cuda': CPU-only process groups have no GPU to
        # place the tensor on.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device=device)
        dist.barrier()
        dist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]

    @property
    def median(self):
        """Median of the windowed values, or 0 when the window is empty."""
        d = sorted(self.deque)
        n = len(d)
        if n == 0:
            return 0
        if n % 2 == 0:
            return (d[n // 2 - 1] + d[n // 2]) / 2
        return d[n // 2]

    @property
    def avg(self):
        """Mean of the windowed values, or 0 when the window is empty."""
        if not self.deque:
            return 0
        return sum(self.deque) / len(self.deque)

    @property
    def global_avg(self):
        """``total / count`` over all updates, or 0 before any update."""
        if self.count == 0:
            return 0
        return self.total / self.count

    def __str__(self):
        """Format the current statistics with ``self.fmt``."""
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=max(self.deque) if self.deque else 0,
            value=self.deque[-1] if self.deque else 0,
        )
def load_config(config_path):
    """Load a configuration from a YAML file.

    Args:
        config_path: Path to the YAML file.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file contents.
    """
    # Read as UTF-8 explicitly so the result does not depend on the
    # platform's default encoding.
    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
def log_to_file(log_file, message):
    """Append a timestamped message line to a log file.

    Args:
        log_file: Path of the log file, or ``None`` to do nothing.
        message: Text written after the ``[YYYY-MM-DD HH:MM:SS]`` prefix.
    """
    if log_file is None:
        return
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    # Explicit UTF-8 so non-ASCII messages do not fail on platforms whose
    # default encoding cannot represent them. Leaving the ``with`` block
    # closes the file, which also flushes it.
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(f"[{timestamp}] {message}\n")
def count_parameters(model, verbose=True):
    """Return the number of trainable parameters of ``model``.

    If ``model`` has a ``module`` attribute (e.g. a DDP wrapper), the wrapped
    module is inspected instead. With ``verbose`` set, a table is printed
    with per-component counts for ``patch_embed``, ``blocks``,
    ``encoder_norm`` and ``head`` (whichever exist) plus the totals.

    Args:
        model: Module whose parameters are counted.
        verbose: Whether to print the breakdown table.

    Returns:
        Total count of parameters with ``requires_grad`` set.
    """
    def _trainable(module):
        # Only parameters that will receive gradients are counted.
        return sum(
            param.numel() for param in module.parameters() if param.requires_grad
        )

    def _humanize(value):
        # Largest matching unit wins; values below 1e3 are shown as-is.
        for threshold, suffix in ((1e9, "B"), (1e6, "M"), (1e3, "K")):
            if value >= threshold:
                return f"{value / threshold:.2f}{suffix}"
        return str(value)

    def _row(label, value):
        return f"{label:.<35} {value:>15,} ({_humanize(value):>8})"

    # Unwrap a wrapper that exposes the real model as ``.module``.
    target = model.module if hasattr(model, 'module') else model
    total_params = _trainable(target)

    if not verbose:
        return total_params

    rule = "=" * 80
    print("\n" + rule)
    print("Model Parameter Statistics")
    print(rule)

    # Encoder components, each reported on its own row.
    encoder_params = 0
    for part in ('patch_embed', 'blocks', 'encoder_norm'):
        if not hasattr(target, part):
            continue
        part_params = _trainable(getattr(target, part))
        encoder_params += part_params
        print(_row(part, part_params))

    # Optional prediction head.
    if hasattr(target, 'head'):
        print(_row('Classification/Regression Head', _trainable(target.head)))

    print("\n" + rule)
    print(_row('Encoder Parameters', encoder_params))
    print(_row('TOTAL TRAINABLE PARAMETERS', total_params))
    print(rule + "\n")
    return total_params
def save_checkpoint(state, is_best, checkpoint_dir, filename='checkpoint.pth'):
    """Save a checkpoint, and a second copy when it is the best so far.

    Args:
        state: Object passed to ``torch.save``.
        is_best: If true, ``state`` is additionally saved as
            ``checkpoint_best.pth`` in the same directory.
        checkpoint_dir: Directory to write into; created if missing.
        filename: File name of the regular checkpoint.
    """
    # Create the target directory up front so the first save of a run does
    # not fail. An empty string means the current directory.
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        best_path = os.path.join(checkpoint_dir, 'checkpoint_best.pth')
        torch.save(state, best_path)
def load_checkpoint(checkpoint_path, model, optimizer, scheduler, scaler=None):
    """Restore model, optimizer, scheduler and optional scaler state.

    Args:
        checkpoint_path: Path of the checkpoint file.
        model: Receives ``checkpoint['model_state_dict']``.
        optimizer: Receives ``checkpoint['optimizer_state_dict']``.
        scheduler: Receives ``checkpoint['scheduler_state_dict']``.
        scaler: Optional; receives ``checkpoint['scaler_state_dict']`` when
            both the scaler and that key are present.

    Returns:
        Tuple ``(start_epoch, best_metric, best_loss)``. When the file does
        not exist nothing is loaded and ``(0, 0.0, inf)`` is returned.
    """
    if not os.path.isfile(checkpoint_path):
        print(f"No checkpoint found at '{checkpoint_path}'")
        # Use the same "no best loss yet" value as the loaded path below.
        # Returning 0.0 here would mean no later loss could ever be lower.
        return 0, 0.0, float('inf')

    print(f"Loading checkpoint '{checkpoint_path}'")
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')

    start_epoch = checkpoint['epoch']
    best_metric = checkpoint.get('best_metric', 0.0)
    best_loss = checkpoint.get('best_loss', float('inf'))

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    if scaler is not None and 'scaler_state_dict' in checkpoint:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])

    print(f"Loaded checkpoint from epoch {start_epoch}")
    return start_epoch, best_metric, best_loss
class LabelScaler:
    """Standardize labels with a fixed mean and standard deviation."""

    def __init__(self, mean, std):
        # Statistics are stored as given and applied in transform().
        self.mean, self.std = mean, std

    def transform(self, labels):
        """Standardize: (y - mean) / std."""
        centered = labels - self.mean
        return centered / self.std