File size: 4,849 Bytes
2651102 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 | """
工具函数
"""
import time
import math
from typing import Optional
from datetime import datetime
class Timer:
    """Accumulating stopwatch used to measure training speed.

    Each ``start()`` / ``stop()`` pair adds one interval to ``elapsed`` and
    increments ``count``. Intervals are measured with ``time.perf_counter()``,
    a monotonic high-resolution clock, so system wall-clock adjustments
    cannot produce negative or inflated durations.
    """

    def __init__(self):
        # perf_counter() reading for the running interval; None when stopped.
        self.start_time: Optional[float] = None
        # Total seconds accumulated over all completed intervals.
        self.elapsed: float = 0
        # Number of completed start()/stop() intervals.
        self.count: int = 0

    def start(self):
        """Begin timing an interval (restarts it if one is already running)."""
        self.start_time = time.perf_counter()

    def stop(self):
        """Finish the running interval and accumulate it; no-op if not started."""
        if self.start_time is not None:
            self.elapsed += time.perf_counter() - self.start_time
            self.count += 1
            self.start_time = None

    def reset(self):
        """Discard all accumulated time, the interval count and any running interval."""
        self.elapsed = 0
        self.count = 0
        self.start_time = None

    @property
    def avg_time(self) -> float:
        """Mean seconds per completed interval (0.0 before any interval)."""
        if self.count == 0:
            return 0.0
        return self.elapsed / self.count

    @property
    def speed(self) -> float:
        """Completed intervals per accumulated second (0.0 when no time recorded)."""
        if self.elapsed == 0:
            return 0.0
        return self.count / self.elapsed
class ProgressTracker:
    """Training progress tracker that renders a one-line text progress bar."""

    def __init__(self, total_steps: int, desc: str = "Training"):
        self.total_steps = total_steps
        self.desc = desc
        self.current_step = 0
        # Wall-clock creation time (time.time()); elapsed time and ETA start here.
        self.start_time = time.time()
        # Every loss passed to update(), in call order.
        self.loss_history = []

    @property
    def elapsed(self) -> float:
        """Seconds elapsed since the tracker was created."""
        return time.time() - self.start_time

    @property
    def count(self) -> int:
        """Most recently reported step number."""
        return self.current_step

    def update(self, step: int, loss: Optional[float] = None):
        """Set the current step and, if ``loss`` is given, append it to ``loss_history``."""
        self.current_step = step
        if loss is not None:
            self.loss_history.append(loss)

    def format_progress(self, current_loss: Optional[float] = None) -> str:
        """Render the progress line: bar, step counts, elapsed<ETA, speed, loss.

        Args:
            current_loss: Optional loss value shown as ``loss=<value>``.

        Returns:
            A single-line status string.
        """
        elapsed = self.elapsed
        # A non-positive total would divide by zero; treat it as "no progress yet".
        progress = self.current_step / self.total_steps if self.total_steps > 0 else 0.0

        # Estimated time remaining, extrapolated linearly from the elapsed time.
        if progress > 0:
            eta = elapsed / progress - elapsed
            eta_str = self._format_time(eta)
        else:
            eta_str = "--:--:--"

        # Steps per second since creation.
        speed = self.current_step / elapsed if elapsed > 0 else 0

        # Clamp the fill fraction so overshooting or negative steps cannot make
        # the bar longer or shorter than bar_len characters.
        bar_len = 30
        filled = int(bar_len * min(max(progress, 0.0), 1.0))
        bar = "█" * filled + "░" * (bar_len - filled)

        loss_str = f"loss={current_loss:.4f}" if current_loss is not None else ""
        return f"{self.desc}: |{bar}| {self.current_step}/{self.total_steps} [{self._format_time(elapsed)}<{eta_str}, {speed:.2f}it/s] {loss_str}"

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format seconds as HH:MM:SS; negative or non-finite input gives ``--:--:--``."""
        if not math.isfinite(seconds) or seconds < 0:
            return "--:--:--"
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"
def cosine_similarity(a, b, dim: int = -1):
    """Return the cosine similarity of tensors ``a`` and ``b`` along ``dim``.

    Thin wrapper over ``torch.nn.functional.cosine_similarity``. ``dim``
    defaults to the last axis, which was the previously hard-coded behaviour.
    """
    # Imported locally so that importing this module does not require torch.
    import torch
    return torch.nn.functional.cosine_similarity(a, b, dim=dim)
def count_parameters(model, *, trainable_only: bool = True) -> int:
    """Count the scalar parameters of a model.

    Args:
        model: Object exposing ``parameters()`` yielding tensors
            (e.g. a ``torch.nn.Module``).
        trainable_only: When True (the default, matching the original
            behaviour), count only parameters with ``requires_grad``; when
            False, count every parameter, including frozen ones.

    Returns:
        Total number of elements across the selected parameters.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )
def format_number(n: int) -> str:
    """Abbreviate a count with a K (thousand) or M (million) suffix.

    Magnitudes below 1,000 are returned as ``str(n)``. Larger values are
    scaled and shown with one decimal place, e.g. ``1_234 -> "1.2K"`` and
    ``-2_500_000 -> "-2.5M"``. Values that would round up to ``"1000.0K"``
    are shown in the M unit instead (``999_950 -> "1.0M"``).
    """
    magnitude = abs(n)
    if magnitude < 1_000:
        return str(n)
    sign = "-" if n < 0 else ""
    if magnitude < 1_000_000:
        thousands = f"{magnitude / 1_000:.1f}"
        # Rounding can roll e.g. 999.95 up to "1000.0"; promote to M then.
        if thousands != "1000.0":
            return f"{sign}{thousands}K"
    return f"{sign}{magnitude / 1_000_000:.1f}M"
def get_timestamp() -> str:
    """Return the current local time as a ``YYYYmmdd_HHMMSS`` string."""
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
def ensure_dir(path: str):
    """Create directory ``path`` together with any missing parents.

    An already-existing directory is left as is.
    """
    from os import makedirs
    makedirs(path, exist_ok=True)
def save_checkpoint(model, optimizer, epoch: int, step: int, loss: float, path: str):
    """Write training progress and model/optimizer state to ``path`` via ``torch.save``.

    The saved dict has the keys ``epoch``, ``step``, ``loss``,
    ``model_state_dict`` and ``optimizer_state_dict``, in that order.
    """
    import torch

    checkpoint = {
        'epoch': epoch,
        'step': step,
        'loss': loss,
    }
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    torch.save(checkpoint, path)
def load_checkpoint(model, optimizer, path: str, *, map_location='cpu'):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    Expects a dict with the keys ``model_state_dict``,
    ``optimizer_state_dict``, ``epoch``, ``step`` and ``loss``.

    Args:
        model: Receives ``model_state_dict`` via ``load_state_dict``.
        optimizer: Receives ``optimizer_state_dict`` via ``load_state_dict``.
            Pass ``None`` to skip optimizer state, e.g. for inference-only
            loading.
        path: Checkpoint file to read.
        map_location: Forwarded to ``torch.load``; defaults to ``'cpu'``.

    Returns:
        The ``(epoch, step, loss)`` values stored in the checkpoint.
    """
    import torch

    # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects,
    # which can execute code. Only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['step'], checkpoint['loss']
class EarlyStopping:
    """Early-stopping monitor for a loss that should decrease.

    Each call counts as an improvement only if the loss beats the best value
    seen so far by more than ``min_delta``. After ``patience`` consecutive
    non-improving calls, ``should_stop`` becomes True and stays True.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001):
        self.patience = patience
        self.min_delta = min_delta
        # Consecutive calls without sufficient improvement.
        self.counter = 0
        self.best_loss = float('inf')
        self.should_stop = False

    def __call__(self, loss: float) -> bool:
        """Record ``loss`` and return whether training should stop."""
        improved = loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss = loss
            self.counter = 0
            return self.should_stop

        self.counter += 1
        self.should_stop = self.should_stop or self.counter >= self.patience
        return self.should_stop
|