import torch


def wrapped_getattr(self, name, default=None, wrapped_member_name="model"):
    """Should be called from wrappers of model classes such as ClassifierFreeSampleModel."""
    if isinstance(self, torch.nn.Module):
        # For descendants of nn.Module, name may live in
        # self.__dict__["_parameters"/"_buffers"/"_modules"], so we call
        # nn.Module.__getattr__ first; otherwise we might enter an infinite loop.
        try:
            attr = torch.nn.Module.__getattr__(self, name)
        except AttributeError:
            wrapped_member = torch.nn.Module.__getattr__(self, wrapped_member_name)
            attr = getattr(wrapped_member, name, default)
    else:
        # The easy case, where self is not derived from nn.Module.
        wrapped_member = getattr(self, wrapped_member_name)
        attr = getattr(wrapped_member, name, default)
    return attr
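
# Usage sketch (an assumption, not from the source): a minimal wrapper class
# whose __getattr__ delegates to the wrapped model via wrapped_getattr.
# ClassifierFreeSampleModel is expected to follow this pattern; the name
# ModelWrapper is hypothetical. wrapped_getattr consults nn.Module.__getattr__
# directly, so the lookup of self.model does not re-enter this __getattr__.
class ModelWrapper(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model  # stored under the default wrapped_member_name="model"

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails; falls back to self.model.
        return wrapped_getattr(self, name)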

def load_model_wo_clip(model, state_dict):
    # Load everything except the CLIP weights, which are absent from the
    # checkpoint; only clip_model.* keys may be missing, and no key in the
    # checkpoint may be unexpected.
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    assert len(unexpected_keys) == 0
    assert all(k.startswith("clip_model.") for k in missing_keys)
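
# Usage sketch (the helper name and checkpoint path are hypothetical, not from
# the source): checkpoints saved without the frozen CLIP weights would fail a
# strict load, so load_model_wo_clip is used to tolerate exactly those keys.
def load_checkpoint(model, path="path/to/checkpoint.pt"):
    state_dict = torch.load(path, map_location="cpu")
    load_model_wo_clip(model, state_dict)
    return model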