import matplotlib.pyplot as plt
import numpy as np
import safetensors
import torch
from safetensors.torch import save_file
# Checkpoint to read, and where to write the perturbed copy.
SRC_PATH = 'sd3_medium_incl_clips_t5xxlfp16.safetensors'
DST_PATH = 'sd3_medium_incl_clips_t5xxlfp16.safetensors_perturbed3.safetensors'
# A tensor is perturbed only if its key contains every one of these substrings.
PARTS = ['diffusion_model']
# Noise standard deviation, as a fraction of each tensor's own std.
NOISE_SCALE = 0.02


def perturb_tensor(v, scale=NOISE_SCALE, generator=None):
    """Return ``v`` plus zero-mean Gaussian noise scaled to ``v``'s spread.

    The noise has standard deviation ``scale * v.std()``. The statistic and the
    addition are computed in float32 and the result is cast back to
    ``v.dtype``. This gives better accuracy for half-precision weights and
    avoids relying on half-precision CPU kernels.

    Tensors that cannot be perturbed meaningfully are returned unchanged:
    non-floating-point tensors, where ``std`` is undefined, and tensors with
    fewer than two elements, where the unbiased ``std`` is NaN and would
    otherwise turn the weight into NaN.

    Args:
        v: Tensor to perturb. It is not modified in place.
        scale: Noise std as a fraction of ``v.std()``.
        generator: Optional ``torch.Generator`` for reproducible noise.
            ``None`` uses torch's global RNG.

    Returns:
        A new tensor with ``v``'s shape and dtype, or ``v`` itself if skipped.
    """
    if not v.is_floating_point() or v.numel() < 2:
        return v
    v32 = v.float()
    noise = torch.randn(v32.shape, generator=generator, dtype=torch.float32,
                        device=v32.device)
    return (v32 + noise * (v32.std() * scale)).to(v.dtype)


def main(src=SRC_PATH, dst=DST_PATH, parts=PARTS, scale=NOISE_SCALE,
         seed=None):
    """Perturb the selected tensors of ``src`` and save the result to ``dst``.

    Every tensor whose key contains all substrings in ``parts`` receives
    Gaussian noise via ``perturb_tensor``. All other tensors, and the file's
    metadata, are written out unchanged. For each perturbed key the function
    prints the key and its std; at the end it prints how many tensors were
    perturbed.

    Args:
        src: Path of the input ``.safetensors`` file.
        dst: Path for the output ``.safetensors`` file.
        parts: Substrings that must all appear in a key for it to be perturbed.
        scale: Noise std as a fraction of each tensor's std.
        seed: If given, seeds a dedicated generator so the run is
            reproducible. ``None`` keeps the previous non-deterministic
            behaviour.
    """
    generator = None if seed is None else torch.Generator().manual_seed(seed)

    # The context manager closes the file as soon as everything has been read.
    with safetensors.safe_open(src, framework='pt') as f:
        metadata = f.metadata()
        tensors = {key: f.get_tensor(key) for key in f.keys()}

    count = 0
    for key, tensor in tensors.items():
        if not all(part in key for part in parts):
            continue
        if not tensor.is_floating_point() or tensor.numel() < 2:
            print(f'{key}: skipped (non-float or fewer than 2 elements)')
            continue
        print(f'{key}: {tensor.float().std().item()}')
        # Reassigning an existing key does not change the dict's size, so
        # it is safe to do while iterating over items().
        tensors[key] = perturb_tensor(tensor, scale, generator)
        count += 1
    print(count)
    save_file(tensors, dst, metadata)


if __name__ == '__main__':
    main()