import torch


def _get_wrapped_member(obj, member_name):
    """Fetch ``obj.<member_name>`` without re-entering ``type(obj).__getattr__``.

    A plain ``getattr(obj, member_name)`` falls back to the wrapper's own
    ``__getattr__`` (i.e. ``wrapped_getattr``) whenever the member is not set,
    e.g. on an instance created with ``cls.__new__`` that never ran
    ``__init__``; that would recurse until ``RecursionError``.

    Raises:
        AttributeError: if ``obj`` has no attribute ``member_name``.
    """
    try:
        # Normal lookup (instance __dict__, class attributes, descriptors)
        # that never invokes a user-defined __getattr__.
        return object.__getattribute__(obj, member_name)
    except AttributeError:
        if isinstance(obj, torch.nn.Module):
            # Registered submodules/parameters/buffers live in
            # _modules/_parameters/_buffers, which only
            # nn.Module.__getattr__ searches.
            return torch.nn.Module.__getattr__(obj, member_name)
        raise


def wrapped_getattr(self, name, default=None, wrapped_member_name="model"):
    """Resolve ``name`` on a wrapper, forwarding to its wrapped member.

    Should be called from the ``__getattr__`` of wrappers of model classes
    such as ClassifierFreeSampleModel, so attributes the wrapper lacks are
    looked up on the wrapped object instead.

    Args:
        self: The wrapper instance.
        name: Attribute being looked up.
        default: Value returned when the wrapped member has no ``name``.
        wrapped_member_name: Attribute of ``self`` holding the wrapped object.

    Returns:
        For nn.Module wrappers, a parameter/buffer/submodule of ``self`` named
        ``name`` if registered; otherwise ``getattr(wrapped, name, default)``.

    Raises:
        AttributeError: if ``self`` has no ``wrapped_member_name`` attribute.
    """
    if isinstance(self, torch.nn.Module):
        # For nn.Module descendants, ``name`` may live in
        # self.__dict__['_parameters' / '_buffers' / '_modules'], so ask
        # nn.Module.__getattr__ first; a plain getattr(self, name) would
        # re-enter our own __getattr__ and loop forever.
        try:
            return torch.nn.Module.__getattr__(self, name)
        except AttributeError:
            pass
    wrapped_member = _get_wrapped_member(self, wrapped_member_name)
    return getattr(wrapped_member, name, default)


def load_model_wo_clip(model, state_dict):
    """Load ``state_dict`` into ``model``, tolerating absent CLIP weights.

    The state dict is loaded non-strictly, then validated: it must contain no
    keys the model does not expect, and every key it lacks must belong to the
    ``clip_model.`` submodule (presumably restored separately, e.g. from a
    pretrained CLIP checkpoint -- TODO confirm against callers).

    Note that the weights are already copied into ``model`` by the time a
    validation error is raised.

    Raises:
        AssertionError: on unexpected keys, or on missing keys outside
            ``clip_model.``. Raised explicitly rather than via ``assert`` so
            the check is not stripped under ``python -O``; the exception type
            is kept for compatibility with existing callers.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if unexpected_keys:
        raise AssertionError(f"Unexpected keys in state_dict: {list(unexpected_keys)}")
    non_clip_missing = [k for k in missing_keys if not k.startswith("clip_model.")]
    if non_clip_missing:
        raise AssertionError(f"Missing non-CLIP keys in state_dict: {non_clip_missing}")