| import os |
| import random |
| import numpy as np |
| import torch |
| import math |
| import time |
| import datetime |
| import json |
| from json import encoder |
|
|
|
|
# Per-output-format configuration.
# Keys are the supported molecule-representation formats; values hold:
#   "name"      - column/field name used for this representation
#   "tokenizer" - filename of the pretrained tokenizer vocabulary
#   "max_len"   - maximum token sequence length used when decoding
# The last three formats only specify "max_len"; presumably their
# tokenizers are constructed elsewhere — confirm against the callers.
FORMAT_INFO = {
    "inchi": {
        "name": "InChI_text",
        "tokenizer": "tokenizer_inchi.json",
        "max_len": 300
    },
    "atomtok": {
        "name": "SMILES_atomtok",
        "tokenizer": "tokenizer_smiles_atomtok.json",
        "max_len": 256
    },
    "nodes": {"max_len": 384},
    "atomtok_coords": {"max_len": 480},
    "chartok_coords": {"max_len": 480}
}
|
|
|
|
def init_logger(log_file='train.log'):
    """Create a logger that writes bare messages to stderr and to *log_file*.

    Args:
        log_file: path of the log file to append to.

    Returns:
        The module-level ``logging.Logger`` at INFO level.

    Note:
        ``getLogger(__name__)`` returns the same logger object on every call,
        so the original version attached a fresh pair of handlers each time it
        was invoked, duplicating every log line. Handlers are now added only
        once.
    """
    from logging import getLogger, INFO, FileHandler, Formatter, StreamHandler
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    # Guard: only attach handlers on the first call to avoid duplicate output.
    if not logger.handlers:
        fmt = Formatter("%(message)s")
        stream_handler = StreamHandler()
        stream_handler.setFormatter(fmt)
        logger.addHandler(stream_handler)
        file_handler = FileHandler(filename=log_file)
        file_handler.setFormatter(fmt)
        logger.addHandler(file_handler)
    return logger
|
|
|
|
def init_summary_writer(save_path):
    """Return a tensorboardX ``SummaryWriter`` that logs under *save_path*."""
    from tensorboardX import SummaryWriter
    return SummaryWriter(save_path)
|
|
|
|
def save_args(args):
    """Write every attribute of *args* to a timestamped log file.

    The file is created as ``train_<yymmdd-HHMM>.log`` inside
    ``args.save_path``, one ``**** key = *value*`` line per attribute.
    """
    stamp = datetime.datetime.now().strftime("%y%m%d-%H%M")
    out_path = os.path.join(args.save_path, f'train_{stamp}.log')
    with open(out_path, 'w') as fout:
        fout.writelines(
            f"**** {key} = *{value}*\n" for key, value in vars(args).items()
        )
|
|
|
|
def seed_torch(seed=42):
    """Seed all random-number generators used in training for reproducibility.

    Covers Python's ``random``, hash seeding, NumPy, and PyTorch (CPU and all
    CUDA devices), and forces deterministic cuDNN kernels.

    Args:
        seed: the seed value applied to every RNG (default 42).

    Note:
        ``PYTHONHASHSEED`` only affects processes launched after this call —
        the current interpreter's hash seed is already fixed.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Was manual_seed(seed): seed every visible GPU, not just the current one,
    # so multi-GPU runs are reproducible too. No-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
|
|
|
|
class AverageMeter(object):
    """Tracks the latest value plus a running sum, count, and mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record *val* observed *n* times and refresh the running mean."""
        self.val = val
        self.count = self.count + n
        self.sum = self.sum + val * n
        self.avg = self.sum / self.count
|
|
|
|
class EpochMeter(AverageMeter):
    """AverageMeter that also accumulates into a separate epoch-level meter.

    ``reset()`` (inherited) clears only the batch-level statistics; the
    nested ``self.epoch`` meter keeps accumulating until reset explicitly.
    """

    def __init__(self):
        super().__init__()
        self.epoch = AverageMeter()

    def update(self, val, n=1):
        """Record *val* in both the batch-level and epoch-level meters."""
        self.epoch.update(val, n)
        super().update(val, n)
|
|
|
|
class LossMeter(EpochMeter):
    """EpochMeter for a total loss plus named sub-losses.

    ``subs`` maps each sub-loss name to its own ``EpochMeter``, created
    lazily on first update.
    """

    def __init__(self):
        # subs must exist before super().__init__() since that calls reset().
        self.subs = {}
        super().__init__()

    def reset(self):
        """Reset the total-loss statistics and every sub-loss meter."""
        super().reset()
        for meter in self.subs.values():
            meter.reset()

    def update(self, loss, losses, n=1):
        """Record a total *loss* tensor and a dict of named sub-loss tensors."""
        super().update(loss.item(), n)
        for name, value in losses.items():
            if name not in self.subs:
                self.subs[name] = EpochMeter()
            self.subs[name].update(value.item(), n)
|
|
|
|
def asMinutes(s):
    """Format a duration of *s* seconds as a ``'Xm Ys'`` string."""
    minutes, seconds = divmod(s, 60)
    return '%dm %ds' % (minutes, seconds)
|
|
|
|
def timeSince(since, percent):
    """Report elapsed time and an ETA based on the fraction completed.

    Args:
        since: start timestamp from ``time.time()``.
        percent: fraction of the work done so far (0 < percent <= 1).

    Returns:
        ``'<elapsed> (remain <estimated-remaining>)'``.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return '%s (remain %s)' % (asMinutes(elapsed), asMinutes(remaining))
|
|
|
|
def print_rank_0(message):
    """Print *message* (flushed); in distributed runs, only from rank 0."""
    # Guard clause: non-zero ranks stay silent when torch.distributed is up.
    if torch.distributed.is_initialized() and torch.distributed.get_rank() != 0:
        return
    print(message, flush=True)
|
|
|
|
def to_device(data, device):
    """Recursively move tensors nested in lists/tuples/dicts onto *device*.

    Args:
        data: a tensor, or an arbitrarily nested list/tuple/dict of tensors
            (possibly mixed with non-tensor values).
        device: target device accepted by ``Tensor.to``.

    Returns:
        The same structure with every tensor moved to *device*. Non-tensor
        leaves are returned unchanged.

    Note:
        The original fell through and returned ``None`` for any type other
        than tensor/list/dict (e.g. int, str, tuple), silently dropping data.
    """
    if torch.is_tensor(data):
        return data.to(device)
    if isinstance(data, list):
        return [to_device(v, device) for v in data]
    if isinstance(data, tuple):
        return tuple(to_device(v, device) for v in data)
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    # Scalars, strings, None, etc. pass through untouched.
    return data
|
|
|
|
def round_floats(o):
    """Recursively round every float in *o* to 3 decimal places.

    Dicts and lists keep their structure; tuples come back as lists
    (mirroring the list branch). All other values pass through unchanged.
    """
    if isinstance(o, float):
        return round(o, 3)
    if isinstance(o, dict):
        return {key: round_floats(val) for key, val in o.items()}
    if isinstance(o, (list, tuple)):
        return list(map(round_floats, o))
    return o
|
|
|
|
def format_df(df):
    """Serialize the structured prediction columns of *df* to compact JSON.

    For each of the columns 'node_coords', 'node_symbols', and 'edges' that
    is present, every cell is rounded (floats to 3 decimals) and dumped as a
    JSON string with all spaces removed; ``None`` cells are left as ``None``.
    Mutates *df* in place and returns it.
    """
    def _encode(value):
        if value is None:
            return None
        compact = json.dumps(round_floats(value))
        return compact.replace(" ", "")

    for column in ('node_coords', 'node_symbols', 'edges'):
        if column in df.columns:
            df[column] = [_encode(cell) for cell in df[column]]
    return df
|
|