Update src/loss.py
Browse files- src/loss.py +1 -1
src/loss.py
CHANGED
|
@@ -46,7 +46,7 @@ class SchedulerWrapper:
|
|
| 46 |
def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
|
| 47 |
def get_instance(device):
|
| 48 |
compress = output.Decoder().to(device).requires_grad_(False)
|
| 49 |
-
compress.load_state_dict(torch.load("
|
| 50 |
return compress
|
| 51 |
@torch.no_grad()
|
| 52 |
def hook_pipe(pipe, compress, mul, sub, scaling_factor):
|
|
|
|
| 46 |
def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
|
| 47 |
def get_instance(device):
|
| 48 |
compress = output.Decoder().to(device).requires_grad_(False)
|
| 49 |
+
compress.load_state_dict(torch.load("taesdxl_decoder.pth", map_location=device, weights_only=True)) # taesdxl_decoder.pth
|
| 50 |
return compress
|
| 51 |
@torch.no_grad()
|
| 52 |
def hook_pipe(pipe, compress, mul, sub, scaling_factor):
|