Manoj Bhat commited on
Commit
79d63ff
·
1 Parent(s): 09252c5

updating quant

Browse files
Files changed (2) hide show
  1. src/loss.py +8 -0
  2. src/pipeline.py +25 -23
src/loss.py CHANGED
@@ -43,3 +43,11 @@ class SchedulerWrapper:
43
  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
44
  def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
45
  def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
 
 
 
 
 
 
 
 
 
43
  H=A[-1];I=torch.cat(C.catch_x_[H],dim=0);B.append(I);A=torch.tensor(A,dtype=torch.int32);B=torch.stack(B);D=torch.stack(D);return A,B,D
44
  def load_loss_params(A):B,C,D=torch.load(A.loss_params_path,map_location='cpu');A.loss_model=LossSchedulerModel(C,D);A.loss_scheduler=LossScheduler(B,A.loss_model)
45
  def prepare_loss(A,num_accelerate_steps=15):A.load_loss_params()
46
class LoadSDXLQuantization:
    """Load pre-quantized layer weights into a model (e.g. the SDXL UNet).

    The checkpoint at ``path`` is expected to map module names — as produced
    by ``model.named_modules()`` — to the ``state_dict`` of that sub-module.
    Modules not present in the checkpoint are left untouched.
    """

    def __init__(self, model, path='quantized_layers.pth', device='cpu'):
        # The module whose sub-modules will be overwritten in place.
        self.model = model
        # weights_only=True restricts unpickling to tensors/containers,
        # preventing arbitrary code execution from an untrusted checkpoint.
        # The file holds plain tensor state_dicts, so this is sufficient.
        self.quantized_layers_state = torch.load(path, map_location=device, weights_only=True)

    def load_model(self):
        """Copy each saved state_dict into the matching named sub-module of ``model``."""
        for name, module in self.model.named_modules():
            if name in self.quantized_layers_state:
                module.load_state_dict(self.quantized_layers_state[name])
src/pipeline.py CHANGED
@@ -1,16 +1,16 @@
1
  import torch
2
  from PIL.Image import Image
3
  from diffusers import StableDiffusionXLPipeline
4
- import torch.nn.utils.prune as prune
5
  from pipelines.models import TextToImageRequest
6
  from diffusers import DDIMScheduler
7
  from torch import Generator
8
- from loss import SchedulerWrapper
9
 
10
  from onediffx import compile_pipe, save_pipe, load_pipe
11
 
12
  def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
13
- if step_index == int(pipe.num_timesteps * 0.77):
14
  callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
15
  callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
16
  callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
@@ -23,34 +23,21 @@ def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
23
  pipeline = StableDiffusionXLPipeline.from_pretrained(
24
  "stablediffusionapi/newdream-sdxl-20",
25
  torch_dtype=torch.float16,
26
- ).to("cuda")
27
-
28
- # Prune the individual models
29
- for name, module in pipeline.text_encoder.named_modules():
30
- if isinstance(module, torch.nn.Linear):
31
- prune.l1_unstructured(module, 'weight', amount=0.2)
32
- if isinstance(module, torch.nn.Embedding):
33
- prune.l1_unstructured(module, 'weight', amount=0.2)
34
- for name, module in pipeline.unet.named_modules():
35
- if isinstance(module, torch.nn.Linear):
36
- prune.l1_unstructured(module, 'weight', amount=0.2)
37
- if isinstance(module, torch.nn.Embedding):
38
- prune.l1_unstructured(module, 'weight', amount=0.2)
39
- for name, module in pipeline.vae.named_modules():
40
- if isinstance(module, torch.nn.Linear):
41
- prune.l1_unstructured(module, 'weight', amount=0.2)
42
-
43
-
44
  pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
 
 
 
45
 
46
  pipeline = compile_pipe(pipeline)
47
  load_pipe(pipeline, dir="/home/sandbox/.cache/huggingface/hub/models--RobertML--cached-pipe-02/snapshots/58d70deae87034cce351b780b48841f9746d4ad7")
48
 
49
  for _ in range(2):
50
- deepcache_output = pipeline(prompt="telestereography, unstrengthen, preadministrator, copatroness, hyperpersonal, paramountness, paranoid, guaniferous", output_type="pil", num_inference_steps=20)
51
  pipeline.scheduler.prepare_loss()
52
  for _ in range(4):
53
- pipeline(prompt="telestereography, unstrengthen, preadministrator, copatroness, hyperpersonal, paramountness, paranoid, guaniferous", output_type="pil", num_inference_steps=20)
54
  return pipeline
55
 
56
  def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
@@ -77,3 +64,18 @@ def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> I
77
  ).images[0]
78
 
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  from PIL.Image import Image
3
  from diffusers import StableDiffusionXLPipeline
4
+
5
  from pipelines.models import TextToImageRequest
6
  from diffusers import DDIMScheduler
7
  from torch import Generator
8
+ from loss import SchedulerWrapper, LoadSDXLQuantization
9
 
10
  from onediffx import compile_pipe, save_pipe, load_pipe
11
 
12
  def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
13
+ if step_index == int(pipe.num_timesteps * 0.78):
14
  callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
15
  callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
16
  callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
 
23
  pipeline = StableDiffusionXLPipeline.from_pretrained(
24
  "stablediffusionapi/newdream-sdxl-20",
25
  torch_dtype=torch.float16,
26
+ )
27
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
29
+ quantizer = LoadSDXLQuantization(pipeline.unet)
30
+ quantizer.load_model()
31
+ pipeline.to("cuda")
32
 
33
  pipeline = compile_pipe(pipeline)
34
  load_pipe(pipeline, dir="/home/sandbox/.cache/huggingface/hub/models--RobertML--cached-pipe-02/snapshots/58d70deae87034cce351b780b48841f9746d4ad7")
35
 
36
  for _ in range(2):
37
+ deepcache_output = pipeline(prompt="polypterid, fattenable, geoparallelotropic, Galeus, galipine, peritoneum, malappropriate, Sekar", output_type="pil", num_inference_steps=20)
38
  pipeline.scheduler.prepare_loss()
39
  for _ in range(4):
40
+ pipeline(prompt="polypterid, fattenable, geoparallelotropic, Galeus, galipine, peritoneum, malappropriate, Sekar", output_type="pil", num_inference_steps=20)
41
  return pipeline
42
 
43
  def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
 
64
  ).images[0]
65
 
66
 
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+