Manoj Bhat committed on
Commit ·
79d63ff
1
Parent(s): 09252c5
updating quant
Browse files- src/loss.py +8 -0
- src/pipeline.py +25 -23
src/loss.py
CHANGED
|
@@ -43,3 +43,11 @@ class SchedulerWrapper:
|
|
| 43 |
H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
|
| 44 |
def load_loss_params(A):
    """Load the loss-scheduler checkpoint and attach model + scheduler to *A*.

    Reads a 3-tuple from ``A.loss_params_path`` (CPU-mapped), builds a
    ``LossSchedulerModel`` from the last two items and wraps it in a
    ``LossScheduler`` driven by the first item.
    NOTE(review): exact meaning of the three checkpoint entries is not
    visible here — inferred from how they are passed along; confirm
    against the saving side.
    """
    scheduler_params, model_arg_a, model_arg_b = torch.load(
        A.loss_params_path, map_location='cpu'
    )
    A.loss_model = LossSchedulerModel(model_arg_a, model_arg_b)
    A.loss_scheduler = LossScheduler(scheduler_params, A.loss_model)
|
| 45 |
def prepare_loss(A, num_accelerate_steps=15):
    """Prepare the loss machinery by loading its parameters.

    NOTE(review): ``num_accelerate_steps`` is accepted but never used in
    this body — kept only for interface compatibility with callers.
    """
    A.load_loss_params()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
|
| 44 |
def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
|
| 45 |
def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
|
| 46 |
+
class LoadSDXLQuantization:
    """Restore pre-quantized layer weights into an existing model.

    The checkpoint at *path* is expected to be a mapping from module
    names (as yielded by ``model.named_modules()``) to per-module state
    dicts saved by the quantization step.

    NOTE(review): ``torch.load`` unpickles arbitrary objects — only load
    checkpoints from a trusted source.
    """

    def __init__(self, model, path='quantized_layers.pth', device='cpu'):
        self.model = model
        # name -> state_dict for every layer that was quantized.
        self.quantized_layers_state = torch.load(path, map_location=device)

    def load_model(self):
        """Copy saved state into each module whose name appears in the checkpoint.

        Modules not present in the checkpoint are left untouched.
        """
        for module_name, module in self.model.named_modules():
            if module_name in self.quantized_layers_state:
                module.load_state_dict(self.quantized_layers_state[module_name])
|
src/pipeline.py
CHANGED
|
@@ -1,16 +1,16 @@
|
|
| 1 |
import torch
|
| 2 |
from PIL.Image import Image
|
| 3 |
from diffusers import StableDiffusionXLPipeline
|
| 4 |
-
|
| 5 |
from pipelines.models import TextToImageRequest
|
| 6 |
from diffusers import DDIMScheduler
|
| 7 |
from torch import Generator
|
| 8 |
-
from loss import SchedulerWrapper
|
| 9 |
|
| 10 |
from onediffx import compile_pipe, save_pipe, load_pipe
|
| 11 |
|
| 12 |
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
|
| 13 |
-
if step_index == int(pipe.num_timesteps * 0.
|
| 14 |
callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
|
| 15 |
callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
|
| 16 |
callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
|
|
@@ -23,34 +23,21 @@ def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
|
|
| 23 |
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
| 24 |
"stablediffusionapi/newdream-sdxl-20",
|
| 25 |
torch_dtype=torch.float16,
|
| 26 |
-
)
|
| 27 |
-
|
| 28 |
-
# Prune the individual models
|
| 29 |
-
for name, module in pipeline.text_encoder.named_modules():
|
| 30 |
-
if isinstance(module, torch.nn.Linear):
|
| 31 |
-
prune.l1_unstructured(module, 'weight', amount=0.2)
|
| 32 |
-
if isinstance(module, torch.nn.Embedding):
|
| 33 |
-
prune.l1_unstructured(module, 'weight', amount=0.2)
|
| 34 |
-
for name, module in pipeline.unet.named_modules():
|
| 35 |
-
if isinstance(module, torch.nn.Linear):
|
| 36 |
-
prune.l1_unstructured(module, 'weight', amount=0.2)
|
| 37 |
-
if isinstance(module, torch.nn.Embedding):
|
| 38 |
-
prune.l1_unstructured(module, 'weight', amount=0.2)
|
| 39 |
-
for name, module in pipeline.vae.named_modules():
|
| 40 |
-
if isinstance(module, torch.nn.Linear):
|
| 41 |
-
prune.l1_unstructured(module, 'weight', amount=0.2)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
pipeline = compile_pipe(pipeline)
|
| 47 |
load_pipe(pipeline, dir="/home/sandbox/.cache/huggingface/hub/models--RobertML--cached-pipe-02/snapshots/58d70deae87034cce351b780b48841f9746d4ad7")
|
| 48 |
|
| 49 |
for _ in range(2):
|
| 50 |
-
deepcache_output = pipeline(prompt="
|
| 51 |
pipeline.scheduler.prepare_loss()
|
| 52 |
for _ in range(4):
|
| 53 |
-
pipeline(prompt="
|
| 54 |
return pipeline
|
| 55 |
|
| 56 |
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
|
|
@@ -77,3 +64,18 @@ def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> I
|
|
| 77 |
).images[0]
|
| 78 |
|
| 79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from PIL.Image import Image
|
| 3 |
from diffusers import StableDiffusionXLPipeline
|
| 4 |
+
|
| 5 |
from pipelines.models import TextToImageRequest
|
| 6 |
from diffusers import DDIMScheduler
|
| 7 |
from torch import Generator
|
| 8 |
+
from loss import SchedulerWrapper, LoadSDXLQuantization
|
| 9 |
|
| 10 |
from onediffx import compile_pipe, save_pipe, load_pipe
|
| 11 |
|
| 12 |
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
|
| 13 |
+
if step_index == int(pipe.num_timesteps * 0.78):
|
| 14 |
callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
|
| 15 |
callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
|
| 16 |
callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
|
|
|
|
| 23 |
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
| 24 |
"stablediffusionapi/newdream-sdxl-20",
|
| 25 |
torch_dtype=torch.float16,
|
| 26 |
+
)
|
| 27 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
|
| 29 |
+
quantizer = LoadSDXLQuantization(pipeline.unet)
|
| 30 |
+
quantizer.load_model()
|
| 31 |
+
pipeline.to("cuda")
|
| 32 |
|
| 33 |
pipeline = compile_pipe(pipeline)
|
| 34 |
load_pipe(pipeline, dir="/home/sandbox/.cache/huggingface/hub/models--RobertML--cached-pipe-02/snapshots/58d70deae87034cce351b780b48841f9746d4ad7")
|
| 35 |
|
| 36 |
for _ in range(2):
|
| 37 |
+
deepcache_output = pipeline(prompt="polypterid, fattenable, geoparallelotropic, Galeus, galipine, peritoneum, malappropriate, Sekar", output_type="pil", num_inference_steps=20)
|
| 38 |
pipeline.scheduler.prepare_loss()
|
| 39 |
for _ in range(4):
|
| 40 |
+
pipeline(prompt="polypterid, fattenable, geoparallelotropic, Galeus, galipine, peritoneum, malappropriate, Sekar", output_type="pil", num_inference_steps=20)
|
| 41 |
return pipeline
|
| 42 |
|
| 43 |
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
|
|
|
|
| 64 |
).images[0]
|
| 65 |
|
| 66 |
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
|