| ''' | |
| This script extracts the spade and structcond module from the official stablesr_000117.ckpt | |
| ''' | |
| import torch | |
| stablesr_path = 'models/stablesr_000117.ckpt' | |
| with open(stablesr_path, 'rb') as f: | |
| stablesr_ckpt = torch.load(f, map_location='cpu') | |
| srmodule = {} | |
| for k, v in stablesr_ckpt['state_dict'].items(): | |
| if 'spade' in k or 'structcond' in k: | |
| srmodule[k] = v | |
| # print(k) | |
| # save | |
| torch.save(srmodule, 'models/stablesr_sd21.ckpt') |