RishubhPar committed on
Commit
3f41cdf
·
verified ·
1 Parent(s): 461f73f

added file with changes

Browse files
Files changed (1) hide show
  1. app.py +11 -11
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import gc
3
  from typing import List, Tuple, Dict
4
  import json
 
5
 
6
  import torch
7
  import gradio as gr
@@ -26,15 +27,16 @@ if HF_TOKEN:
26
  # Avoid meta-tensor init from environment leftovers
27
  os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
28
 
29
- DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
30
  print("Using device:", DEVICE)
31
-
32
  torch.backends.cudnn.benchmark = True
33
 
34
  # -----------------------------
35
  # Model / pipeline loading
36
  # -----------------------------
37
- def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
 
 
38
  pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
39
 
40
  n_slider_layers = 4
@@ -50,7 +52,6 @@ def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
50
  low_cpu_mem_usage=False,
51
  token=HF_TOKEN,
52
  )
53
- transformer.eval()
54
  weight_dtype = transformer.dtype # keep checkpoint dtype
55
 
56
  # Slider projector
@@ -69,7 +70,7 @@ def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
69
 
70
  # Load projector weights on CPU
71
  slider_projector_path = os.path.join(trained_models_path, "slider_projector.pth")
72
- state_dict = torch.load(slider_projector_path)
73
  print("state_dict keys: {}".format(state_dict.keys()))
74
 
75
  slider_projector.load_state_dict(state_dict)
@@ -90,13 +91,11 @@ def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
90
 
91
  pipeline.load_lora_weights(trained_models_path)
92
  print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
93
-
94
- # Move everything to the single device
95
- pipeline.to(device_str)
96
  return pipeline
97
 
98
 
99
- PIPELINE = load_pipeline_single_gpu(DEVICE)
 
100
  print(f"[init] Pipeline loaded on {DEVICE}")
101
 
102
 
@@ -285,7 +284,7 @@ def resize_image(img: Image.Image, target: int = 512) -> Image.Image:
285
  img = img.resize((new_w, new_h), resample)
286
  return img
287
 
288
-
289
  def _encode_prompt(prompt: str):
290
  with torch.no_grad():
291
  pe, ppe, _ = PIPELINE.encode_prompt(prompt, prompt_2=prompt)
@@ -295,6 +294,7 @@ def _encode_prompt(prompt: str):
295
  # -----------------------------
296
  # Inference functions
297
  # -----------------------------
 
298
  def generate_image_stack_edits(text_prompt, n_edits, input_image):
299
  """
300
  Compute n_edits images on a single GPU for slider values in (0,1],
@@ -346,7 +346,7 @@ def generate_image_stack_edits(text_prompt, n_edits, input_image):
346
  first = results[0] if results else None
347
  return results, first
348
 
349
-
350
  def generate_single_image(text_prompt, slider_value, input_image):
351
  if not input_image or not text_prompt or text_prompt.startswith("Please select"):
352
  return None
 
2
  import gc
3
  from typing import List, Tuple, Dict
4
  import json
5
+ import spaces
6
 
7
  import torch
8
  import gradio as gr
 
27
  # Avoid meta-tensor init from environment leftovers
28
  os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
29
 
30
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
  print("Using device:", DEVICE)
 
32
  torch.backends.cudnn.benchmark = True
33
 
34
  # -----------------------------
35
  # Model / pipeline loading
36
  # -----------------------------
37
+ @torch.no_grad()
38
+ @spaces.GPU
39
+ def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
40
  pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
41
 
42
  n_slider_layers = 4
 
52
  low_cpu_mem_usage=False,
53
  token=HF_TOKEN,
54
  )
 
55
  weight_dtype = transformer.dtype # keep checkpoint dtype
56
 
57
  # Slider projector
 
70
 
71
  # Load projector weights on CPU
72
  slider_projector_path = os.path.join(trained_models_path, "slider_projector.pth")
73
+ state_dict = torch.load(slider_projector_path, map_location='cpu')
74
  print("state_dict keys: {}".format(state_dict.keys()))
75
 
76
  slider_projector.load_state_dict(state_dict)
 
91
 
92
  pipeline.load_lora_weights(trained_models_path)
93
  print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
 
 
 
94
  return pipeline
95
 
96
 
97
+ PIPELINE = load_pipeline_single_gpu()
98
+ PIPELINE.to(DEVICE)
99
  print(f"[init] Pipeline loaded on {DEVICE}")
100
 
101
 
 
284
  img = img.resize((new_w, new_h), resample)
285
  return img
286
 
287
+ @spaces.GPU
288
  def _encode_prompt(prompt: str):
289
  with torch.no_grad():
290
  pe, ppe, _ = PIPELINE.encode_prompt(prompt, prompt_2=prompt)
 
294
  # -----------------------------
295
  # Inference functions
296
  # -----------------------------
297
+ @spaces.GPU
298
  def generate_image_stack_edits(text_prompt, n_edits, input_image):
299
  """
300
  Compute n_edits images on a single GPU for slider values in (0,1],
 
346
  first = results[0] if results else None
347
  return results, first
348
 
349
+ @spaces.GPU
350
  def generate_single_image(text_prompt, slider_value, input_image):
351
  if not input_image or not text_prompt or text_prompt.startswith("Please select"):
352
  return None