mdm / utils /misc.py
hassanjbara's picture
update model
5007d4b
import torch
def wrapped_getattr(self, name, default=None, wrapped_member_name="model"):
    """Resolve ``name`` on a wrapper object, falling back to its wrapped member.

    Should be called from wrappers of model classes such as
    ClassifierFreeSampleModel, presumably from their ``__getattr__`` (i.e.
    after Python's normal attribute lookup on ``self`` has already failed).

    Args:
        self: The wrapper instance.
        name: Attribute name being looked up.
        default: Value returned when the attribute is found neither on the
            wrapper's nn.Module registries nor on the wrapped member.
        wrapped_member_name: Name of the attribute on ``self`` that holds the
            wrapped object.

    Returns:
        For nn.Module wrappers, the parameter/buffer/submodule ``name`` of
        ``self`` if registered. Otherwise, the attribute ``name`` of the
        wrapped member, or ``default`` if the member has no such attribute.

    Raises:
        AttributeError: If the wrapped member itself cannot be found on
            ``self``. For nn.Module wrappers, only registered parameters,
            buffers and submodules are searched for it.
    """
    if isinstance(self, torch.nn.Module):
        # For descendants of nn.Module, ``name`` may live in
        # self.__dict__["_parameters" / "_buffers" / "_modules"], which only
        # nn.Module.__getattr__ searches. Calling it directly (instead of
        # getattr) avoids re-entering our own __getattr__ and looping forever.
        try:
            return torch.nn.Module.__getattr__(self, name)
        except AttributeError:
            pass
        wrapped_member = torch.nn.Module.__getattr__(self, wrapped_member_name)
    else:
        # The easy case, where self is not derived from nn.Module. Use the
        # standard lookup (instance dict, class attributes, descriptors) but
        # NOT the instance's __getattr__. A plain getattr() here would call
        # back into this function whenever the wrapped member is not set yet
        # (e.g. on an instance created via __new__ during copy/unpickling)
        # and recurse until RecursionError. object.__getattribute__ raises a
        # clean AttributeError instead.
        wrapped_member = object.__getattribute__(self, wrapped_member_name)
    return getattr(wrapped_member, name, default)
def load_model_wo_clip(model, state_dict):
    """Load ``state_dict`` into ``model``, tolerating only missing CLIP weights.

    The state dict is loaded non-strictly and then validated:

    * it must not contain any key the model does not expect, and
    * every model key it lacks must start with ``"clip_model."``
      (presumably because CLIP weights are not stored in the checkpoint and
      are provided separately — confirm with callers).

    Note that the weights are copied into ``model`` *before* validation, so
    ``model`` is already modified when this raises.

    Args:
        model: A ``torch.nn.Module`` to load into (modified in place).
        state_dict: Mapping of parameter/buffer names to tensors.

    Raises:
        AssertionError: If ``state_dict`` has unexpected keys or lacks keys
            outside ``clip_model.``. Raised explicitly rather than via
            ``assert`` so the check still runs under ``python -O``. The
            exception type is kept for backward compatibility.
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    if unexpected_keys:
        raise AssertionError(
            f"Unexpected keys in state_dict: {list(unexpected_keys)}"
        )
    # Only weights of the CLIP submodule may be absent from the checkpoint.
    non_clip_missing = [k for k in missing_keys if not k.startswith("clip_model.")]
    if non_clip_missing:
        raise AssertionError(
            f"Missing non-CLIP keys in state_dict: {non_clip_missing}"
        )