RishubhPar commited on
Commit
1d721be
·
verified ·
1 Parent(s): 35511a3

updated the app.

Browse files
Files changed (1) hide show
  1. app.py +10 -6
app.py CHANGED
@@ -31,12 +31,16 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
  print("Using device:", DEVICE)
32
  torch.backends.cudnn.benchmark = True
33
 
 
 
34
  # -----------------------------
35
  # Model / pipeline loading
36
  # -----------------------------
37
  @torch.no_grad()
38
  @spaces.GPU
39
  def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
 
 
40
  pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
41
 
42
  n_slider_layers = 4
@@ -78,7 +82,7 @@ def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
78
  # ------------------------------- --------------------- --------------------------- #
79
 
80
  # Build full pipeline on CPU; no device_map sharding
81
- pipeline = FluxKontextSliderPipeline.from_pretrained(
82
  pretrained,
83
  transformer=transformer,
84
  slider_projector=slider_projector,
@@ -89,12 +93,12 @@ def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
89
 
90
  print("loading the pipeline lora weights from: {}".format(trained_models_path))
91
 
92
- pipeline.load_lora_weights(trained_models_path)
93
  print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
94
- return pipeline
95
-
96
 
97
- PIPELINE = load_pipeline_single_gpu()
 
 
98
  PIPELINE.to(DEVICE)
99
  print(f"[init] Pipeline loaded on {DEVICE}")
100
 
@@ -294,8 +298,8 @@ def _encode_prompt(prompt: str):
294
  # -----------------------------
295
  # Inference functions
296
  # -----------------------------
297
- @torch.no_grad()
298
  @spaces.GPU
 
299
  def generate_image_stack_edits(text_prompt, n_edits, input_image):
300
  """
301
  Compute n_edits images on a single GPU for slider values in (0,1],
 
31
  print("Using device:", DEVICE)
32
  torch.backends.cudnn.benchmark = True
33
 
34
+ PIPELINE=None
35
+
36
  # -----------------------------
37
  # Model / pipeline loading
38
  # -----------------------------
39
  @torch.no_grad()
40
  @spaces.GPU
41
  def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
42
+ global PIPELINE
43
+
44
  pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
45
 
46
  n_slider_layers = 4
 
82
  # ------------------------------- --------------------- --------------------------- #
83
 
84
  # Build full pipeline on CPU; no device_map sharding
85
+ PIPELINE = FluxKontextSliderPipeline.from_pretrained(
86
  pretrained,
87
  transformer=transformer,
88
  slider_projector=slider_projector,
 
93
 
94
  print("loading the pipeline lora weights from: {}".format(trained_models_path))
95
 
96
+ PIPELINE.load_lora_weights(trained_models_path)
97
  print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
 
 
98
 
99
+ # Initializing the pipeline with gpu
100
+ print("INIT pipeline with the gpu")
101
+ load_pipeline_single_gpu()
102
  PIPELINE.to(DEVICE)
103
  print(f"[init] Pipeline loaded on {DEVICE}")
104
 
 
298
  # -----------------------------
299
  # Inference functions
300
  # -----------------------------
 
301
  @spaces.GPU
302
+ @torch.no_grad()
303
  def generate_image_stack_edits(text_prompt, n_edits, input_image):
304
  """
305
  Compute n_edits images on a single GPU for slider values in (0,1],