Spaces:
Runtime error
Runtime error
small changes
Browse files
app.py
CHANGED
|
@@ -32,74 +32,98 @@ PIPELINE=None
|
|
| 32 |
# Model / pipeline loading
|
| 33 |
# -----------------------------
|
| 34 |
@spaces.GPU
|
| 35 |
-
def load_pipeline_single_gpu()
|
|
|
|
| 36 |
global PIPELINE
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 41 |
-
print("Using device:", DEVICE)
|
| 42 |
-
torch.backends.cudnn.benchmark = True
|
| 43 |
-
|
| 44 |
-
n_slider_layers = 4
|
| 45 |
-
slider_projector_out_dim = 6144
|
| 46 |
-
trained_models_path = "./model_weights/"
|
| 47 |
-
is_clip_input = True
|
| 48 |
-
|
| 49 |
-
# Load transformer fully on CPU; avoid meta tensors
|
| 50 |
-
transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
|
| 51 |
-
pretrained,
|
| 52 |
-
subfolder="transformer",
|
| 53 |
-
device_map=None,
|
| 54 |
-
low_cpu_mem_usage=False,
|
| 55 |
-
token=HF_TOKEN,
|
| 56 |
-
)
|
| 57 |
-
weight_dtype = transformer.dtype # keep checkpoint dtype
|
| 58 |
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
)
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
# Load projector weights on CPU
|
| 74 |
-
slider_projector_path = os.path.join(trained_models_path, "slider_projector.pth")
|
| 75 |
-
state_dict = torch.load(slider_projector_path, map_location='cpu')
|
| 76 |
-
print("state_dict keys: {}".format(state_dict.keys()))
|
| 77 |
-
|
| 78 |
-
slider_projector.load_state_dict(state_dict)
|
| 79 |
-
print(f"loaded slider_projector from {slider_projector_path}")
|
| 80 |
-
# ------------------------------- --------------------- --------------------------- #
|
| 81 |
-
|
| 82 |
-
# Build full pipeline on CPU; no device_map sharding
|
| 83 |
-
pipe = FluxKontextSliderPipeline.from_pretrained(
|
| 84 |
-
pretrained,
|
| 85 |
-
transformer=transformer,
|
| 86 |
-
slider_projector=slider_projector,
|
| 87 |
-
torch_dtype=weight_dtype,
|
| 88 |
-
device_map=None,
|
| 89 |
-
low_cpu_mem_usage=False,
|
| 90 |
-
)
|
| 91 |
|
| 92 |
-
|
|
|
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
|
|
|
|
|
|
|
|
|
| 100 |
# Initializing the pipeline with gpu
|
| 101 |
print("INIT pipeline with the gpu")
|
| 102 |
-
load_pipeline_single_gpu()
|
|
|
|
| 103 |
|
| 104 |
# -----------------------------
|
| 105 |
# Sample Images & Precomputed Results
|
|
|
|
# Model / pipeline loading
# -----------------------------
@spaces.GPU
def load_pipeline_single_gpu():
    """Initialize the module-global PIPELINE inside the ZeroGPU worker.

    Builds the Flux-Kontext slider pipeline entirely on CPU (transformer,
    slider projector, LoRA weights) and only then moves it to the worker's
    device, avoiding accelerate meta-tensor initialization.

    Returns:
        str: a small status string for the Space logs —
             "warm"  if PIPELINE was already initialized,
             "ok"    on successful initialization,
             "error: ..." (with detail / traceback text) on any failure.
             Never raises; all exceptions are caught and reported as text.
    """
    global PIPELINE
    if PIPELINE is not None:
        print("[worker] PIPELINE already initialized; skipping.")
        return "warm"

    try:
        # --- worker-local env & device ---
        # Accelerate's empty-weights init would give us meta tensors; drop it.
        os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
        HF_TOKEN = os.environ.get("HF_TOKEN")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print("[worker] cuda available:", torch.cuda.is_available())
        if device == "cuda":
            torch.backends.cudnn.benchmark = True

        # --- config ---
        pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
        n_slider_layers = 4
        slider_projector_out_dim = 6144
        trained_models_path = "./model_weights/"
        is_clip_input = True

        # --- validate files before loading ---
        # Check the directory first: with the file check first, a missing
        # directory would be misreported as "missing projector weights".
        projector_path = os.path.join(trained_models_path, "slider_projector.pth")
        if not os.path.isdir(trained_models_path):
            return f"error: missing dir {trained_models_path}"
        if not os.path.isfile(projector_path):
            return f"error: missing projector weights at {projector_path}"

        # --- transformer on CPU first ---
        # device_map=None + low_cpu_mem_usage=False keeps real (non-meta)
        # tensors on CPU so the later .to(device) move is a plain copy.
        transformer = FluxTransformer2DModelwithSliderConditioning.from_pretrained(
            pretrained,
            subfolder="transformer",
            device_map=None,
            low_cpu_mem_usage=False,
            token=HF_TOKEN,  # ok if None for public repos
            # trust_remote_code=True,  # uncomment if this model requires it
        )
        weight_dtype = transformer.dtype  # keep checkpoint dtype for the pipeline

        # --- projector ---
        if is_clip_input:
            slider_projector = SliderProjector(
                out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers, is_clip_input=True
            )
        else:
            slider_projector = SliderProjector_wo_clip(
                out_dim=slider_projector_out_dim, pe_dim=2, n_layers=n_slider_layers
            )

        transformer.eval()
        slider_projector.eval()

        # --- load projector weights (CPU) ---
        # NOTE(review): torch.load unpickles arbitrary objects; this is only
        # safe because the checkpoint ships with the Space. Consider
        # weights_only=True (torch>=1.13) if the state dict is tensors-only.
        state_dict = torch.load(projector_path, map_location="cpu")
        # small print (avoid dumping huge keys)
        print("[worker] projector keys sample:", list(state_dict.keys())[:5])
        slider_projector.load_state_dict(state_dict)

        # --- build pipeline (CPU) ---
        pipe = FluxKontextSliderPipeline.from_pretrained(
            pretrained,
            transformer=transformer,
            slider_projector=slider_projector,
            torch_dtype=weight_dtype,
            device_map=None,
            low_cpu_mem_usage=False,
        )

        # --- LoRA load (still in worker) ---
        print("[worker] loading LoRA from:", trained_models_path)
        pipe.load_lora_weights(trained_models_path)

        # --- move to worker's device ---
        pipe.to(device)

        # keep in worker-global
        PIPELINE = pipe
        print("[worker] PIPELINE ready on", device)
        return "ok"

    except Exception:
        tb = traceback.format_exc()
        print("[worker] exception during init:\n", tb)
        # Return the text so you can see it in Space logs
        return "error:\n" + tb
| 123 |
# Initializing the pipeline with gpu
# Runs at import time: warms the ZeroGPU worker so the first user request
# does not pay the full model-load cost. load_pipeline_single_gpu never
# raises; it reports success/failure via its returned status string.
print("INIT pipeline with the gpu")
status = load_pipeline_single_gpu()
# Surface the worker's status ("warm" / "ok" / "error: ...") in the Space logs.
print("[main] worker init status:", status)
| 127 |
|
| 128 |
# -----------------------------
|
| 129 |
# Sample Images & Precomputed Results
|