gaparmar committed
Commit bb4cdae · 1 Parent(s): 1930c69

4bit and benchmarked

Files changed (2)
  1. app.py +10 -1
  2. my_utils/group_inference.py +4 -1
app.py CHANGED
```diff
@@ -1,5 +1,6 @@
 import os
 import spaces
+import time
 import gradio as gr
 import torch
 import functools
@@ -8,11 +9,16 @@ import torch.nn.functional as F
 from diffusers import FluxPipeline, AutoencoderTiny
 from transformers import CLIPProcessor, CLIPModel, AutoModel
 from transformers.models.clip.modeling_clip import _get_vector_norm
+from nunchaku import NunchakuFluxTransformer2dModel
+from nunchaku.utils import get_precision
 from my_utils.group_inference import run_group_inference
 from my_utils.default_values import apply_defaults
+
 import argparse

-pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell").to("cuda")
+precision = get_precision()
+transformer = NunchakuFluxTransformer2dModel.from_pretrained(f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors")
+pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16).to("cuda")
 pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1").to("cuda")

 m_clip = CLIPModel.from_pretrained("multimodalart/clip-vit-base-patch32").to("cuda")
@@ -178,7 +184,10 @@ def generate_images(prompt, starting_candidates, output_group_size, pruning_rati
     }
     print(f"pruning ratio is: {pruning_ratio}")
     # Run group inference
+    t_start = time.time()
     output_group = run_group_inference(pipe, **inference_args)
+    t_end = time.time()
+    print(f"Time taken for group inference: {t_end - t_start} seconds")
     return output_group


```
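Pulled out of the diff above, a minimal standalone sketch of what the new setup amounts to: load the 4-bit (SVDQuant) Nunchaku transformer into the FLUX.1-schnell pipeline and time a single generation. The prompt, step count, and the torch.cuda.synchronize() calls around the timer are illustrative additions for benchmarking on a CUDA GPU, not part of this commit.

```python
import time
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

# Pick the 4-bit flavor supported by the current GPU and load the quantized transformer.
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")

# Illustrative benchmark (not in the commit): synchronize so the wall-clock
# measurement covers the GPU work, not just the kernel launches.
prompt = "a photo of a corgi wearing sunglasses"  # hypothetical prompt
torch.cuda.synchronize()
t_start = time.time()
image = pipe(prompt, num_inference_steps=4, guidance_scale=0.0).images[0]
torch.cuda.synchronize()
print(f"Time taken for one image: {time.time() - t_start:.2f} seconds")
```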
my_utils/group_inference.py CHANGED
```diff
@@ -1,4 +1,4 @@
-import os, sys
+import os, sys, time
 import math
 import torch
 import spaces
@@ -221,7 +221,10 @@ def run_group_inference(pipe, model_name=None, prompt=None, prompt_2=None, negat
         l_unary_scores = unary_score_fn(l_x0_decoded, target_caption=prompt)
         M_binary_scores = binary_score_fn(l_x0_decoded)  # upper triangular matrix
         # run with Quadratic Integer Programming solver
+        t_start = time.time()
         selected_indices = gurobi_solver(l_unary_scores, M_binary_scores, next_size, lam=lambda_score)
+        t_end = time.time()
+        print(f"Time taken for QIP: {t_end - t_start} seconds")
         l_latents = [next_latents[_i] for _i in selected_indices]
     else:
         l_latents = next_latents
```
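The gurobi_solver call being timed here is not shown in this diff. As a rough illustration of the quadratic integer program it refers to, a subset-selection model in gurobipy might look like the sketch below; the objective shape, the sign of the pairwise term, and all names are assumptions for illustration, not the repository's actual formulation.

```python
import gurobipy as gp
from gurobipy import GRB

def gurobi_solver_sketch(unary_scores, M_binary_scores, k, lam=1.0):
    """Hypothetical QIP: pick k candidates trading off per-image score
    against lam-weighted pairwise scores (M_binary_scores is upper triangular)."""
    n = len(unary_scores)
    model = gp.Model("group_selection")
    model.Params.OutputFlag = 0      # silence solver logging
    model.Params.NonConvex = 2       # allow a non-PSD quadratic objective on older Gurobi versions
    x = model.addVars(n, vtype=GRB.BINARY, name="x")
    model.addConstr(x.sum() == k)    # keep exactly k of the n candidates
    unary = gp.quicksum(unary_scores[i] * x[i] for i in range(n))
    pairwise = gp.quicksum(
        M_binary_scores[i][j] * x[i] * x[j]
        for i in range(n) for j in range(i + 1, n)
    )
    model.setObjective(unary + lam * pairwise, GRB.MAXIMIZE)
    model.optimize()
    return [i for i in range(n) if x[i].X > 0.5]
```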