import os
import re
import sys
import math
import random
import logging
from copy import deepcopy
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from torch.distributions.uniform import Uniform

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def randint(low, high):
    # One random integer drawn uniformly from [low, high), as a plain Python int.
    return int(torch.randint(low, high, (1,)))


def rand_uniform(low, high):
    # One random float drawn uniformly from [low, high).
    return float(Uniform(low, high).sample())


def get_logger(out_dir):
    """Build a logger that writes to both `out_dir/run.log` and stdout."""
    logger = logging.getLogger('Exp')
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter("%(asctime)s %(levelname)s %(message)s")

    file_path = os.path.join(out_dir, "run.log")
    file_hdlr = logging.FileHandler(file_path)
    file_hdlr.setFormatter(formatter)

    strm_hdlr = logging.StreamHandler(sys.stdout)
    strm_hdlr.setFormatter(formatter)

    logger.addHandler(file_hdlr)
    logger.addHandler(strm_hdlr)
    return logger


def update_lr_cos(nb_iter, warm_up_iter, total_iter, max_lr, optimizer, min_lr=1e-7):
    """Linear warm-up to max_lr, then cosine decay down to min_lr."""
    if nb_iter < warm_up_iter:
        current_lr = max_lr * (nb_iter + 1) / (warm_up_iter + 1)
    else:
        # Cosine phase: progress runs from 0 at the end of warm-up to 1 at total_iter,
        # so the schedule starts at max_lr and ends at min_lr.
        progress = (nb_iter - warm_up_iter) / (total_iter - warm_up_iter)
        current_lr = min_lr + (max_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * progress))

    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr

    return optimizer, current_lr
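
# A minimal usage sketch (the optimizer choice and iteration counts here are
# illustrative assumptions, not values fixed by this module):
#     optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#     for nb_iter in range(1, total_iter + 1):
#         optimizer, lr = update_lr_cos(nb_iter, 1000, total_iter, 1e-4, optimizer)
#         ...  # forward pass, loss.backward(), optimizer.step()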


class CTCLabelConverter(object):
    """Convert between text strings and integer index sequences for CTC loss.

    Index 0 is reserved for the CTC blank token; characters map to 1..N.
    """

    def __init__(self, character):
        dict_character = list(character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            # Index 0 is reserved for the CTC blank, so characters start at 1.
            self.dict[char] = i + 1
        if len(self.dict) == 87:
            # Dataset-specific patch: make sure '[' and ']' have indices.
            self.dict['['], self.dict[']'] = 88, 89
        self.character = ['[blank]'] + dict_character

    def encode(self, text):
        length = [len(s) for s in text]
        text = ''.join(text)
        text = [self.dict[char] for char in text]
        return (torch.IntTensor(text).to(device), torch.IntTensor(length).to(device))

    def decode(self, text_index, length):
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]
            char_list = []
            for i in range(l):
                # Standard CTC collapse: drop blanks (index 0), repeated symbols,
                # and any index outside the known character table.
                if t[i] != 0 and not (i > 0 and t[i - 1] == t[i]) and t[i] < len(self.character):
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)
            texts.append(text)
            index += l
        return texts
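
# Round-trip sketch (the charset and labels are made-up examples; preds_index
# and preds_lengths are assumed outputs of a greedy decode over model logits):
#     converter = CTCLabelConverter('abcdefghijklmnopqrstuvwxyz ')
#     targets, target_lengths = converter.encode(['hello', 'world'])
#     texts = converter.decode(preds_index, preds_lengths)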


class Averager(object):
    """Accumulate tensor values and report their running average."""

    def __init__(self):
        self.reset()

    def add(self, v):
        count = v.data.numel()
        v = v.data.sum()
        self.n_count += count
        self.sum += v

    def reset(self):
        self.n_count = 0
        self.sum = 0

    def val(self):
        res = 0
        if self.n_count != 0:
            res = self.sum / float(self.n_count)
        return res
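
# Typical use with a per-batch loss tensor (names are illustrative):
#     loss_avg = Averager()
#     loss_avg.add(loss)        # accumulates loss.data.sum() and the element count
#     current = loss_avg.val()  # running mean of everything added so far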


class Metric(object):
    """Track the running mean of a value, averaged across distributed workers.

    `update` requires torch.distributed to be initialized.
    """

    def __init__(self, name=''):
        self.name = name
        self.sum = torch.tensor(0.).double()
        self.n = torch.tensor(0.)

    def update(self, val):
        # Average the value over all ranks before accumulating it locally.
        rt = val.clone()
        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
        rt /= dist.get_world_size()
        self.sum += rt.detach().cpu().double()
        self.n += 1

    @property
    def avg(self):
        return self.sum / self.n.double()
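
# Sketch, assuming torch.distributed is already initialized
# (e.g. via dist.init_process_group) and batch_cer is a scalar tensor:
#     cer_metric = Metric(name='CER')
#     cer_metric.update(batch_cer)  # value on this rank; reduced across ranks
#     print(cer_metric.avg)         # mean over all updates so far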


class ModelEma:
    """Keep an exponential moving average (EMA) copy of a model's weights."""

    def __init__(self, model, decay=0.9999, device='', resume=''):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path, mapl=None):
        checkpoint = torch.load(checkpoint_path, map_location=mapl)
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                # Align key names with the wrapped/unwrapped state of self.ema.
                if self.ema_has_module:
                    name = 'module.' + k if not k.startswith('module') else k
                else:
                    name = k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            print("=> Loaded state_dict_ema")
        else:
            print("=> Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model, num_updates=-1):
        needs_module = hasattr(model, 'module') and not self.ema_has_module
        if num_updates >= 0:
            # Warm-up: a smaller effective decay early on lets the EMA track
            # the model more closely at the start of training.
            _cdecay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        else:
            _cdecay = self.decay

        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if needs_module:
                    k = 'module.' + k
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                # ema = decay * ema + (1 - decay) * model
                ema_v.copy_(ema_v * _cdecay + (1. - _cdecay) * model_v)
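
# Training-loop wiring sketch (model, loader, optimizer are assumed to exist):
#     model_ema = ModelEma(model, decay=0.9999)
#     for step, batch in enumerate(loader):
#         ...  # forward, backward, optimizer.step()
#         model_ema.update(model, num_updates=step)
#     # evaluate with model_ema.ema rather than model for smoother weights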


def format_string_for_wer(s):
    # Pad punctuation with spaces so each mark counts as its own WER token,
    # then collapse runs of whitespace into single spaces.
    s = re.sub(r'([\[\]{}/\\()\"\'&+*=<>?.;:,!\-—_€#%°])', r' \1 ', s)
    s = re.sub(r'([ \n])+', " ", s).strip()
    return s
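
# Example (input is illustrative):
#     format_string_for_wer("It's done,isn't it?")
#     # -> "It ' s done , isn ' t it ?"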


def load_checkpoint(model, model_ema, optimizer, checkpoint_path, logger):
    """Restore model/EMA weights, training counters, and RNG states from a checkpoint."""
    best_cer, best_wer, start_iter = 1e+6, 1e+6, 1
    train_loss, train_loss_count = 0.0, 0
    optimizer_state = None
    if checkpoint_path is not None and os.path.isfile(checkpoint_path):
        logger.info(f"Resuming from checkpoint: {checkpoint_path}")
        checkpoint = torch.load(
            checkpoint_path, map_location='cpu', weights_only=False)

        model_dict = OrderedDict()
        pattern = re.compile('module.')

        if 'model' in checkpoint:
            source_dict = checkpoint['model']
            logger.info("Loading main model from 'model' state dict")
        elif 'state_dict_ema' in checkpoint:
            source_dict = checkpoint['state_dict_ema']
            logger.info(
                "Loading main model from 'state_dict_ema' (fallback)")
        else:
            raise KeyError(
                "Neither 'model' nor 'state_dict_ema' found in checkpoint")

        # Strip any 'module.' prefix left over from DataParallel/DDP wrapping.
        for k, v in source_dict.items():
            if re.search("module", k):
                model_dict[re.sub(pattern, '', k)] = v
            else:
                model_dict[k] = v

        model.load_state_dict(model_dict, strict=True)
        logger.info("Successfully loaded main model state dict")

        if 'state_dict_ema' in checkpoint and model_ema is not None:
            ema_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                if re.search("module", k):
                    ema_dict[re.sub(pattern, '', k)] = v
                else:
                    ema_dict[k] = v
            model_ema.ema.load_state_dict(ema_dict, strict=True)
            logger.info("Successfully loaded EMA model state dict")

        # Defer loading the optimizer state; the caller applies it once the
        # optimizer has been constructed.
        if 'optimizer' in checkpoint and optimizer is not None:
            try:
                optimizer_state = checkpoint['optimizer']
                logger.info(
                    "Optimizer state will be loaded after optimizer initialization")
            except Exception as e:
                logger.warning(f"Failed to prepare optimizer state: {e}")
                optimizer_state = None

        if 'best_cer' in checkpoint:
            best_cer = checkpoint['best_cer']
        if 'best_wer' in checkpoint:
            best_wer = checkpoint['best_wer']
        if 'nb_iter' in checkpoint:
            start_iter = checkpoint['nb_iter'] + 1

        # Fall back to metrics encoded in the filename,
        # e.g. checkpoint_4.5_12.3_10000.pth.
        m = re.search(
            r'checkpoint_(?P<cer>[\d\.]+)_(?P<wer>[\d\.]+)_(?P<iter>\d+)\.pth', checkpoint_path)
        if m and 'best_cer' not in checkpoint:
            best_cer = float(m.group('cer'))
            best_wer = float(m.group('wer'))
            start_iter = int(m.group('iter')) + 1

        if 'train_loss' in checkpoint:
            train_loss = checkpoint['train_loss']
        if 'train_loss_count' in checkpoint:
            train_loss_count = checkpoint['train_loss_count']
        if 'random_state' in checkpoint:
            random.setstate(checkpoint['random_state'])
            logger.info("Restored random state")
        if 'numpy_state' in checkpoint:
            np.random.set_state(checkpoint['numpy_state'])
            logger.info("Restored numpy random state")
        if 'torch_state' in checkpoint:
            torch.set_rng_state(checkpoint['torch_state'])
            logger.info("Restored torch random state")
        if 'torch_cuda_state' in checkpoint and torch.cuda.is_available():
            torch.cuda.set_rng_state(checkpoint['torch_cuda_state'])
            logger.info("Restored torch cuda random state")

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model loaded with {total_params} total parameters")

    logger.info(
        f"Resumed best_cer={best_cer}, best_wer={best_wer}, start_iter={start_iter}")
    return best_cer, best_wer, start_iter, optimizer_state, train_loss, train_loss_count
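
# Resume-flow sketch (argument names are assumptions, not fixed by this module):
#     logger = get_logger(out_dir)
#     best_cer, best_wer, start_iter, opt_state, train_loss, train_loss_count = \
#         load_checkpoint(model, model_ema, optimizer, resume_path, logger)
#     if opt_state is not None:
#         optimizer.load_state_dict(opt_state)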