# Process-wide runtime configuration for training.
#
# Importing this module only produces side effects:
#   * quiets noisy native libraries (NCCL, TensorFlow C++),
#   * installs global warning filters,
#   * tunes PyTorch matmul precision and torch.compile / dynamo behaviour,
#   * forces Lightning's W&B logger module to treat wandb as available,
#   * raises the "wandb" logger threshold to WARNING.
import logging
import os
import warnings

# Environment knobs go in before the heavy imports below.
# TF_CPP_MIN_LOG_LEVEL is only honoured if it is already set when TensorFlow's
# native library loads, which some logger backends may trigger transitively.
# NCCL reads NCCL_DEBUG when it initialises, so setting it early is harmless.
os.environ["NCCL_DEBUG"] = "WARN"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Install the warning filters before importing torch/lightning so that warnings
# emitted at import time are suppressed too. simplefilter() prepends, so keep
# this call order to produce the same filter list as before.
# NOTE(review): ignoring *every* UserWarning process-wide also hides genuinely
# useful PyTorch/Lightning diagnostics; consider narrowing with `module=`.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

import torch
# Import explicitly: `import torch` alone is not guaranteed to expose the
# `torch._dynamo` submodule as an attribute on every torch release.
import torch._dynamo
import lightning.pytorch.loggers.wandb as wandb

# Override Lightning's availability flag for wandb so that the W&B logger does
# not refuse to construct.
# NOTE(review): `_WANDB_AVAILABLE` is a private Lightning attribute (presumably
# a version/requirement check); confirm it still exists after Lightning upgrades.
setattr(wandb, '_WANDB_AVAILABLE', True)

# Allow lower-precision (e.g. TF32 / bf16-based) float32 matmuls for speed.
torch.set_float32_matmul_precision('medium')

# torch.compile / dynamo fallback: don't kill training when dynamo fails to
# trace a graph (e.g. symbolic shape mismatches in compiled HF text encoders);
# just run that node in eager mode instead.
torch._dynamo.config.suppress_errors = True
# Also avoid recompiling forever when input shapes keep changing.
torch._dynamo.config.cache_size_limit = 64

# Only surface warnings and errors from the wandb library's logger.
logger = logging.getLogger("wandb")
logger.setLevel(logging.WARNING)