RishubhPar committed on
Commit
c1eeec0
·
verified ·
1 Parent(s): f2201b8

small changes

Browse files
Files changed (1) hide show
  1. app.py +82 -58
app.py CHANGED
@@ -32,74 +32,98 @@ PIPELINE=None
32
  # Model / pipeline loading
33
  # -----------------------------
34
  @spaces.GPU
35
- def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
 
36
  global PIPELINE
37
-
38
- pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
39
-
40
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
- print("Using device:", DEVICE)
42
- torch.backends.cudnn.benchmark = True
43
-
44
- n_slider_layers = 4
45
- slider_projector_out_dim = 6144
46
- trained_models_path = "./model_weights/"
47
- is_clip_input = True
48
-
49
- # Load transformer fully on CPU; avoid meta tensors
50
- transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
51
- pretrained,
52
- subfolder="transformer",
53
- device_map=None,
54
- low_cpu_mem_usage=False,
55
- token=HF_TOKEN,
56
- )
57
- weight_dtype = transformer.dtype # keep checkpoint dtype
58
 
59
- # Slider projector
60
- if is_clip_input:
61
- slider_projector = SliderProjector(
62
- out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers, is_clip_input=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
- else:
65
- slider_projector = SliderProjector_wo_clip(
66
- out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  )
68
 
69
- # putting both the models to infer
70
- transformer.eval()
71
- slider_projector.eval()
72
-
73
- # Load projector weights on CPU
74
- slider_projector_path = os.path.join(trained_models_path, "slider_projector.pth")
75
- state_dict = torch.load(slider_projector_path, map_location='cpu')
76
- print("state_dict keys: {}".format(state_dict.keys()))
77
-
78
- slider_projector.load_state_dict(state_dict)
79
- print(f"loaded slider_projector from {slider_projector_path}")
80
- # ------------------------------- --------------------- --------------------------- #
81
-
82
- # Build full pipeline on CPU; no device_map sharding
83
- pipe = FluxKontextSliderPipeline.from_pretrained(
84
- pretrained,
85
- transformer=transformer,
86
- slider_projector=slider_projector,
87
- torch_dtype=weight_dtype,
88
- device_map=None,
89
- low_cpu_mem_usage=False,
90
- )
91
 
92
- print("loading the pipeline lora weights from: {}".format(trained_models_path))
 
93
 
94
- pipe.load_lora_weights(trained_models_path)
95
- print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
 
 
96
 
97
- pipe.to(DEVICE)
98
- PIPELINE = pipe
99
-
 
 
 
100
  # Initializing the pipeline with gpu
101
  print("INIT pipeline with the gpu")
102
- load_pipeline_single_gpu()
 
103
 
104
  # -----------------------------
105
  # Sample Images & Precomputed Results
 
32
  # Model / pipeline loading
33
  # -----------------------------
34
@spaces.GPU
def load_pipeline_single_gpu():
    """Initialize the global PIPELINE inside the ZeroGPU worker.

    Builds the Flux-Kontext slider pipeline entirely on CPU (transformer,
    slider projector, LoRA weights) to avoid meta tensors, then moves it to
    the worker's device.

    Returns:
        A short status string for the Space logs: "warm" if the pipeline was
        already initialized, "ok" on success, or "error: ..." (including a
        traceback) on failure. Errors are returned rather than raised so the
        caller can surface them from the ZeroGPU worker.
    """
    global PIPELINE
    if PIPELINE is not None:
        print("[worker] PIPELINE already initialized; skipping.")
        return "warm"

    try:
        # --- worker-local env & device ---
        # Make sure accelerate does not initialize empty (meta) weights.
        os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
        HF_TOKEN = os.environ.get("HF_TOKEN")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print("[worker] cuda available:", torch.cuda.is_available())
        if device == "cuda":
            torch.backends.cudnn.benchmark = True

        # --- config ---
        pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
        n_slider_layers = 4
        slider_projector_out_dim = 6144
        trained_models_path = "./model_weights/"
        is_clip_input = True

        # --- validate files before loading ---
        # Check the directory first: if the whole weights checkout is missing,
        # reporting "missing projector weights" would be misleading (and the
        # directory check would otherwise be unreachable).
        projector_path = os.path.join(trained_models_path, "slider_projector.pth")
        if not os.path.isdir(trained_models_path):
            return f"error: missing dir {trained_models_path}"
        if not os.path.isfile(projector_path):
            return f"error: missing projector weights at {projector_path}"

        # --- transformer on CPU first ---
        transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
            pretrained,
            subfolder="transformer",
            device_map=None,
            low_cpu_mem_usage=False,
            token=HF_TOKEN,  # ok if None for public repos
            # trust_remote_code=True,  # uncomment if this model requires it
        )
        weight_dtype = transformer.dtype  # keep the checkpoint's dtype for the pipeline

        # --- projector ---
        if is_clip_input:
            slider_projector = SliderProjector(
                out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers, is_clip_input=True
            )
        else:
            slider_projector = SliderProjector_wo_clip(
                out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers
            )

        # Both models are inference-only in this app.
        transformer.eval()
        slider_projector.eval()

        # --- load projector weights (CPU) ---
        # NOTE(review): torch.load unpickles; fine for our own checkpoint, but
        # never point this at untrusted files.
        state_dict = torch.load(projector_path, map_location="cpu")
        # small print (avoid dumping huge keys)
        print("[worker] projector keys sample:", list(state_dict.keys())[:5])
        slider_projector.load_state_dict(state_dict)

        # --- build pipeline (CPU) ---
        pipe = FluxKontextSliderPipeline.from_pretrained(
            pretrained,
            transformer=transformer,
            slider_projector=slider_projector,
            torch_dtype=weight_dtype,
            device_map=None,
            low_cpu_mem_usage=False,
        )

        # --- LoRA load (still in worker) ---
        print("[worker] loading LoRA from:", trained_models_path)
        pipe.load_lora_weights(trained_models_path)

        # --- move to worker's device ---
        pipe.to(device)

        # keep in worker-global
        PIPELINE = pipe
        print("[worker] PIPELINE ready on", device)
        return "ok"

    except Exception:
        tb = traceback.format_exc()
        print("[worker] exception during init:\n", tb)
        # Return the text so you can see it in Space logs
        return "error:\n" + tb
123
# Warm up the pipeline in the GPU worker at startup and surface the result.
print("INIT pipeline with the gpu")
status = load_pipeline_single_gpu()
print("[main] worker init status:", status)
127
 
128
  # -----------------------------
129
  # Sample Images & Precomputed Results