RishubhPar commited on
Commit
262c23d
·
verified ·
1 Parent(s): fabf8be

small changes.

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -31,12 +31,16 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
  print("Using device:", DEVICE)
32
  torch.backends.cudnn.benchmark = True
33
 
 
 
34
  # -----------------------------
35
  # Model / pipeline loading
36
  # -----------------------------
37
  @torch.no_grad()
38
  @spaces.GPU
39
  def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
 
 
40
  pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
41
 
42
  n_slider_layers = 4
@@ -78,7 +82,7 @@ def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
78
  # ------------------------------- --------------------- --------------------------- #
79
 
80
  # Build full pipeline on CPU; no device_map sharding
81
- pipeline = FluxKontextSliderPipeline.from_pretrained(
82
  pretrained,
83
  transformer=transformer,
84
  slider_projector=slider_projector,
@@ -89,12 +93,12 @@ def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
89
 
90
  print("loading the pipeline lora weights from: {}".format(trained_models_path))
91
 
92
- pipeline.load_lora_weights(trained_models_path)
93
  print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
94
- return pipeline
95
-
96
 
97
- PIPELINE = load_pipeline_single_gpu()
 
 
98
  PIPELINE.to(DEVICE)
99
  print(f"[init] Pipeline loaded on {DEVICE}")
100
 
 
31
  print("Using device:", DEVICE)
32
  torch.backends.cudnn.benchmark = True
33
 
34
+ PIPELINE=None
35
+
36
  # -----------------------------
37
  # Model / pipeline loading
38
  # -----------------------------
39
  @torch.no_grad()
40
  @spaces.GPU
41
  def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
42
+ global PIPELINE
43
+
44
  pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
45
 
46
  n_slider_layers = 4
 
82
  # ------------------------------- --------------------- --------------------------- #
83
 
84
  # Build full pipeline on CPU; no device_map sharding
85
+ PIPELINE = FluxKontextSliderPipeline.from_pretrained(
86
  pretrained,
87
  transformer=transformer,
88
  slider_projector=slider_projector,
 
93
 
94
  print("loading the pipeline lora weights from: {}".format(trained_models_path))
95
 
96
+ PIPELINE.load_lora_weights(trained_models_path)
97
  print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
 
 
98
 
99
+ # Initializing the pipeline with gpu
100
+ print("INIT pipeline with the gpu")
101
+ load_pipeline_single_gpu()
102
  PIPELINE.to(DEVICE)
103
  print(f"[init] Pipeline loaded on {DEVICE}")
104