RishubhPar committed
Commit b9e9d7e · verified · 1 Parent(s): 7d5d1e3

Initialize the pipeline outside the inference functions and invoke them through the @spaces.GPU decorator.
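For context, the end state is the standard ZeroGPU Space pattern: build the pipeline once at import time, then reserve a GPU per call with the decorator. A minimal sketch of that pattern, assuming a stock diffusers pipeline rather than this Space's custom FluxKontextSliderPipeline (function names and the duration value are illustrative):

import spaces
import torch
from diffusers import DiffusionPipeline

# Runs once when the Space process starts; ZeroGPU allows moving the
# pipeline to CUDA at module level, outside any @spaces.GPU function.
PIPE = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
)
PIPE.to("cuda")

@spaces.GPU(duration=80)  # seconds of GPU time requested per invocation
@torch.no_grad()
def generate(prompt, image):
    # Every call reuses the warm, already-placed pipeline.
    return PIPE(image=image, prompt=prompt).images[0]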

Files changed (1): app.py (+170, -97)
app.py CHANGED
@@ -27,7 +27,6 @@ if HF_TOKEN:
  # -----------------------------
  # Avoid meta-tensor init from environment leftovers
  os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
- PIPELINE=None
 
  # -----------------------------
  # Model / pipeline loading
@@ -35,101 +34,181 @@ PIPELINE=None
 
  def _log(msg): print(msg, flush=True)
 
- def load_pipeline_single_gpu():
-     global PIPELINE
-     if PIPELINE is not None:
-         _log("[worker] PIPELINE already initialized; skipping.")
-         return "warm"
- 
-     try:
-         os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
-         token = os.environ.get("HF_TOKEN")
-         cuda_ok = torch.cuda.is_available()
-         _log(f"[worker] cuda available: {cuda_ok}")
-         if cuda_ok:
-             torch.backends.cudnn.benchmark = True
- 
-         # ---------- config ----------
-         pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
-         trained_models_path = "./model_weights/"
-         projector_path = os.path.join(trained_models_path, "slider_projector.pth")
-         offload_dir = "/tmp/offload"; os.makedirs(offload_dir, exist_ok=True)
- 
-         if not os.path.isdir(trained_models_path):
-             return f"error: missing dir {trained_models_path}"
-         if not os.path.isfile(projector_path):
-             return f"error: missing projector weights at {projector_path}"
- 
-         # dtype selection to cut memory
-         if cuda_ok and torch.cuda.get_device_capability(0)[0] >= 8:
-             dtype = torch.bfloat16
-         elif cuda_ok:
-             dtype = torch.float16
-         else:
-             dtype = torch.float32
- 
-         max_memory = {"cuda": "80GiB", "cpu": "60GiB"}  # tune if needed
- 
-         _log("[worker] loading transformer (sharded/offloaded)…")
-         transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
-             pretrained,
-             subfolder="transformer",
-             token=token,
-             trust_remote_code=True,
-             torch_dtype=dtype,
-             low_cpu_mem_usage=True,
-             # device_map="balanced_low_0",
-             offload_folder=offload_dir,
-             offload_state_dict=True,
-             # max_memory=max_memory,
-         )
-         weight_dtype = transformer.dtype
-         _log(f"[worker] transformer loaded, dtype={weight_dtype}")
- 
-         _log("[worker] building slider projector…")
-         slider_projector = SliderProjector(out_dim=6144, pe_dim=2, n_layers=4, is_clip_input=True)
-         slider_projector.eval()
-         _log("[worker] loading projector weights…")
-         state_dict = torch.load(projector_path, map_location="cpu", weights_only=True)
-         slider_projector.load_state_dict(state_dict, strict=True)
- 
-         _log("[worker] assembling pipeline (sharded/offloaded)…")
-         pipe = FluxKontextSliderPipeline.from_pretrained(
-             pretrained,
-             token=token,
-             trust_remote_code=True,
-             transformer=transformer,
-             slider_projector=slider_projector,
-             torch_dtype=weight_dtype,
-             low_cpu_mem_usage=True,
-             # device_map="balanced_low_0",
-             offload_folder=offload_dir,
-             offload_state_dict=True,
-             # max_memory=max_memory,
-         )
-         _log("[worker] pipeline assembled.")
- 
-         _log(f"[worker] loading LoRA from: {trained_models_path}")
-         pipe.load_lora_weights(trained_models_path)
-         _log("[worker] LoRA loaded.")
- 
-         # DO NOT pipe.to("cuda") here; keep auto device_map to avoid OOM
-         PIPELINE = pipe
-         if cuda_ok:
-             free, total = torch.cuda.mem_get_info()
-             _log(f"[worker] VRAM free/total: {free/1e9:.2f}/{total/1e9:.2f} GB")
-         _log("[worker] PIPELINE ready.")
-         return "ok"
- 
-     except Exception:
-         _log("[worker] init exception:\n" + traceback.format_exc())
-         return "error"
+ # def load_pipeline_single_gpu():
+ #     global PIPELINE
+ #     if PIPELINE is not None:
+ #         _log("[worker] PIPELINE already initialized; skipping.")
+ #         return "warm"
+ 
+ #     try:
+ #         os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
+ #         token = os.environ.get("HF_TOKEN")
+ #         cuda_ok = torch.cuda.is_available()
+ #         _log(f"[worker] cuda available: {cuda_ok}")
+ #         if cuda_ok:
+ #             torch.backends.cudnn.benchmark = True
+ 
+ #         # ---------- config ----------
+ #         pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
+ #         trained_models_path = "./model_weights/"
+ #         projector_path = os.path.join(trained_models_path, "slider_projector.pth")
+ #         offload_dir = "/tmp/offload"; os.makedirs(offload_dir, exist_ok=True)
+ 
+ #         if not os.path.isdir(trained_models_path):
+ #             return f"error: missing dir {trained_models_path}"
+ #         if not os.path.isfile(projector_path):
+ #             return f"error: missing projector weights at {projector_path}"
+ 
+ #         # dtype selection to cut memory
+ #         if cuda_ok and torch.cuda.get_device_capability(0)[0] >= 8:
+ #             dtype = torch.bfloat16
+ #         elif cuda_ok:
+ #             dtype = torch.float16
+ #         else:
+ #             dtype = torch.float32
+ 
+ #         max_memory = {"cuda": "80GiB", "cpu": "60GiB"}  # tune if needed
+ 
+ #         _log("[worker] loading transformer (sharded/offloaded)…")
+ #         transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
+ #             pretrained,
+ #             subfolder="transformer",
+ #             token=token,
+ #             trust_remote_code=True,
+ #             torch_dtype=dtype,
+ #             low_cpu_mem_usage=True,
+ #             # device_map="balanced_low_0",
+ #             offload_folder=offload_dir,
+ #             offload_state_dict=True,
+ #             # max_memory=max_memory,
+ #         )
+ #         weight_dtype = transformer.dtype
+ #         _log(f"[worker] transformer loaded, dtype={weight_dtype}")
+ 
+ #         _log("[worker] building slider projector…")
+ #         slider_projector = SliderProjector(out_dim=6144, pe_dim=2, n_layers=4, is_clip_input=True)
+ #         slider_projector.eval()
+ #         _log("[worker] loading projector weights…")
+ #         state_dict = torch.load(projector_path, map_location="cpu", weights_only=True)
+ #         slider_projector.load_state_dict(state_dict, strict=True)
+ 
+ #         _log("[worker] assembling pipeline (sharded/offloaded)…")
+ #         pipe = FluxKontextSliderPipeline.from_pretrained(
+ #             pretrained,
+ #             token=token,
+ #             trust_remote_code=True,
+ #             transformer=transformer,
+ #             slider_projector=slider_projector,
+ #             torch_dtype=weight_dtype,
+ #             low_cpu_mem_usage=True,
+ #             # device_map="balanced_low_0",
+ #             offload_folder=offload_dir,
+ #             offload_state_dict=True,
+ #             # max_memory=max_memory,
+ #         )
+ #         _log("[worker] pipeline assembled.")
+ 
+ #         _log(f"[worker] loading LoRA from: {trained_models_path}")
+ #         pipe.load_lora_weights(trained_models_path)
+ #         _log("[worker] LoRA loaded.")
+ 
+ #         # DO NOT pipe.to("cuda") here; keep auto device_map to avoid OOM
+ #         PIPELINE = pipe
+ #         if cuda_ok:
+ #             free, total = torch.cuda.mem_get_info()
+ #             _log(f"[worker] VRAM free/total: {free/1e9:.2f}/{total/1e9:.2f} GB")
+ #         _log("[worker] PIPELINE ready.")
+ #         return "ok"
+ 
+ #     except Exception:
+ #         _log("[worker] init exception:\n" + traceback.format_exc())
+ #         return "error"
+ 
+ # -----------------------------
+ # Load the pipeline at module level (no wrapper function) so inference can use it directly
+ # -----------------------------
+ os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
+ token = os.environ.get("HF_TOKEN")
+ cuda_ok = torch.cuda.is_available()
+ _log(f"[worker] cuda available: {cuda_ok}")
+ if cuda_ok:
+     torch.backends.cudnn.benchmark = True
+ 
+ # ---------- config ----------
+ pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
+ trained_models_path = "./model_weights/"
+ projector_path = os.path.join(trained_models_path, "slider_projector.pth")
+ offload_dir = "/tmp/offload"; os.makedirs(offload_dir, exist_ok=True)
+ 
+ # dtype selection to cut memory
+ if cuda_ok and torch.cuda.get_device_capability(0)[0] >= 8:
+     dtype = torch.bfloat16
+ elif cuda_ok:
+     dtype = torch.float16
+ else:
+     dtype = torch.float32
+ 
+ max_memory = {"cuda": "80GiB", "cpu": "60GiB"}  # tune if needed
+ 
+ _log("[worker] loading transformer (sharded/offloaded)…")
+ transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
+     pretrained,
+     subfolder="transformer",
+     token=token,
+     trust_remote_code=True,
+     torch_dtype=dtype,
+     low_cpu_mem_usage=True,
+     # device_map="balanced_low_0",
+     offload_folder=offload_dir,
+     offload_state_dict=True,
+     # max_memory=max_memory,
+ )
+ weight_dtype = transformer.dtype
+ _log(f"[worker] transformer loaded, dtype={weight_dtype}")
+ 
+ _log("[worker] building slider projector…")
+ slider_projector = SliderProjector(out_dim=6144, pe_dim=2, n_layers=4, is_clip_input=True)
+ slider_projector.eval()
+ _log("[worker] loading projector weights…")
+ state_dict = torch.load(projector_path, map_location="cpu", weights_only=True)
+ slider_projector.load_state_dict(state_dict, strict=True)
+ 
+ _log("[worker] assembling pipeline (sharded/offloaded)…")
+ pipe = FluxKontextSliderPipeline.from_pretrained(
+     pretrained,
+     token=token,
+     trust_remote_code=True,
+     transformer=transformer,
+     slider_projector=slider_projector,
+     torch_dtype=weight_dtype,
+     low_cpu_mem_usage=True,
+     # device_map="balanced_low_0",
+     offload_folder=offload_dir,
+     offload_state_dict=True,
+     # max_memory=max_memory,
+ )
+ _log("[worker] pipeline assembled.")
+ 
+ _log(f"[worker] loading LoRA from: {trained_models_path}")
+ pipe.load_lora_weights(trained_models_path)
+ _log("[worker] LoRA loaded.")
+ 
+ PIPELINE = pipe
+ if cuda_ok:
+     free, total = torch.cuda.mem_get_info()
+     _log(f"[worker] VRAM free/total: {free/1e9:.2f}/{total/1e9:.2f} GB")
+ _log("[worker] PIPELINE ready.")
+ 
+ # move the pipeline to the GPU (ZeroGPU permits this at module level)
+ PIPELINE.to('cuda')
 
  # -----------------------------
  # Sample Images & Precomputed Results
  # -----------------------------
- 
  def create_sample_entry(name, image_filename, prompt, result_folder, num_results=5, result_pattern="image_{i}.png", precomputed_base="./sample_images/precomputed"):
      """
      Helper function to create a sample entry with subfolder organization.
@@ -314,7 +393,7 @@ def resize_image(img: Image.Image, target: int = 512) -> Image.Image:
  # -----------------------------
  # Inference functions
  # -----------------------------
- @spaces.GPU
+ @spaces.GPU(duration=500)
  @torch.no_grad()
  def generate_image_stack_edits(text_prompt, n_edits, input_image):
      """
@@ -323,13 +402,7 @@ def generate_image_stack_edits(text_prompt, n_edits, input_image):
      """
      DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
-     # if the pipeline is None, initialize it here
-     global PIPELINE
-     if PIPELINE is None:
-         status = load_pipeline_single_gpu()
- 
-         print("loaded pipeline status: {}".format(status))
- 
+     # the pipeline is already loaded in the global context and is used here
      if not input_image or not text_prompt or text_prompt.startswith("Please select"):
          return [], None
 
@@ -376,7 +449,7 @@ def generate_image_stack_edits(text_prompt, n_edits, input_image):
      first = results[0] if results else None
      return results, first
 
- @spaces.GPU
+ @spaces.GPU(duration=80)
  def generate_single_image(text_prompt, slider_value, input_image):
      if not input_image or not text_prompt or text_prompt.startswith("Please select"):
          return None
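A note on the dtype selection repeated in the new module-level block: bfloat16 requires CUDA compute capability 8.0 or higher (Ampere and newer), which is what the capability check guards. The same heuristic as a standalone sketch:

import torch

def pick_dtype():
    # bf16 needs compute capability >= 8 (Ampere+); older GPUs fall back
    # to fp16, and CPU-only environments to fp32.
    if not torch.cuda.is_available():
        return torch.float32
    major, _minor = torch.cuda.get_device_capability(0)
    return torch.bfloat16 if major >= 8 else torch.float16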
 
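One detail worth keeping as-is: the projector checkpoint is loaded with torch.load(..., weights_only=True), which restricts unpickling to tensors and primitive containers rather than arbitrary Python objects. A minimal sketch of that call, using the path from this diff:

import torch

# weights_only=True rejects pickled code objects, so a tampered checkpoint
# cannot execute arbitrary Python during loading.
state_dict = torch.load("./model_weights/slider_projector.pth",
                        map_location="cpu", weights_only=True)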