| from .. import devices | |
def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None):
    """Run one hypernetwork's key/value module pair over an attention context.

    The module pair is looked up in ``hypernetwork.layers`` using
    ``context_k.shape[2]`` as the key. When ``hypernetwork`` is ``None``, or
    it has no entry for that size, both contexts are returned unchanged and
    ``layer`` is not touched.

    Args:
        hypernetwork: Object exposing a ``layers`` mapping, or ``None``.
            NOTE(review): values are presumably (key_module, value_module)
            pairs; only items 0 and 1 are used here.
        context_k: Tensor-like passed to the key module; must expose
            ``.shape`` with at least three entries.
        context_v: Tensor-like passed to the value module.
        layer: Optional object. When given and a module pair is found, the
            modules are stored on it as ``hyper_k`` / ``hyper_v``.

    Returns:
        ``(context_k, context_v)``: transformed when a module pair matches,
        otherwise the original inputs.
    """
    available = {} if hypernetwork is None else hypernetwork.layers
    modules = available.get(context_k.shape[2], None)
    if modules is None:
        return context_k, context_v

    if layer is not None:
        # Record which modules were applied on the caller-supplied layer.
        layer.hyper_k = modules[0]
        layer.hyper_v = modules[1]

    def _run(module, ctx):
        # Each module is called on devices.cond_cast_float(ctx) and the result
        # is passed through devices.cond_cast_unet (presumably dtype casts —
        # confirm in the devices module).
        return devices.cond_cast_unet(module(devices.cond_cast_float(ctx)))

    return _run(modules[0], context_k), _run(modules[1], context_v)
def apply_hypernetworks(hypernetworks, context, layer=None):
    """Apply each hypernetwork in turn to an attention context.

    The key and value streams both start as ``context``. Each element of
    ``hypernetworks`` is applied in iteration order through
    ``apply_single_hypernetwork``, receiving the previous step's key/value
    outputs.

    Args:
        hypernetworks: Iterable of hypernetwork objects.
        context: Tensor-like used as the initial key and value context.
        layer: Optional object forwarded unchanged to every
            ``apply_single_hypernetwork`` call. If several hypernetworks
            match, each call overwrites what the previous one stored on it.

    Returns:
        ``(context_k, context_v)`` after all hypernetworks have run, or
        ``(context, context)`` when ``hypernetworks`` is empty.
    """
    keys = values = context
    for net in hypernetworks:
        keys, values = apply_single_hypernetwork(net, keys, values, layer)
    return keys, values