Spaces: Running on Zero

Commit: added file with changes

app.py (CHANGED)
@@ -2,6 +2,7 @@ import os
 import gc
 from typing import List, Tuple, Dict
 import json
+import spaces
 
 import torch
 import gradio as gr
@@ -26,15 +27,16 @@ if HF_TOKEN:
 # Avoid meta-tensor init from environment leftovers
 os.environ.pop("ACCELERATE_INIT_EMPTY_WEIGHTS", None)
 
-DEVICE = "cuda"
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 print("Using device:", DEVICE)
-
 torch.backends.cudnn.benchmark = True
 
 # -----------------------------
 # Model / pipeline loading
 # -----------------------------
-def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
+@torch.no_grad()
+@spaces.GPU
+def load_pipeline_single_gpu() -> FluxKontextSliderPipeline:
     pretrained = "black-forest-labs/FLUX.1-Kontext-dev"
 
     n_slider_layers = 4
@@ -50,7 +52,6 @@ def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
         low_cpu_mem_usage=False,
         token=HF_TOKEN,
     )
-    transformer.eval()
     weight_dtype = transformer.dtype  # keep checkpoint dtype
 
     # Slider projector
@@ -69,7 +70,7 @@ def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
 
     # Load projector weights on CPU
     slider_projector_path = os.path.join(trained_models_path, "slider_projector.pth")
-    state_dict = torch.load(slider_projector_path)
+    state_dict = torch.load(slider_projector_path, map_location='cpu')
     print("state_dict keys: {}".format(state_dict.keys()))
 
     slider_projector.load_state_dict(state_dict)
@@ -90,13 +91,11 @@ def load_pipeline_single_gpu(device_str: str) -> FluxKontextSliderPipeline:
 
     pipeline.load_lora_weights(trained_models_path)
     print("loaded the pipeline with lora weights from: {}".format(trained_models_path))
-
-    # Move everything to the single device
-    pipeline.to(device_str)
     return pipeline
 
 
-PIPELINE = load_pipeline_single_gpu(DEVICE)
+PIPELINE = load_pipeline_single_gpu()
+PIPELINE.to(DEVICE)
 print(f"[init] Pipeline loaded on {DEVICE}")
 
 
@@ -285,7 +284,7 @@ def resize_image(img: Image.Image, target: int = 512) -> Image.Image:
     img = img.resize((new_w, new_h), resample)
     return img
 
-
+@spaces.GPU
 def _encode_prompt(prompt: str):
     with torch.no_grad():
         pe, ppe, _ = PIPELINE.encode_prompt(prompt, prompt_2=prompt)
@@ -295,6 +294,7 @@ def _encode_prompt(prompt: str):
 # -----------------------------
 # Inference functions
 # -----------------------------
+@spaces.GPU
 def generate_image_stack_edits(text_prompt, n_edits, input_image):
     """
     Compute n_edits images on a single GPU for slider values in (0,1],
@@ -346,7 +346,7 @@ def generate_image_stack_edits(text_prompt, n_edits, input_image):
     first = results[0] if results else None
     return results, first
 
-
+@spaces.GPU
 def generate_single_image(text_prompt, slider_value, input_image):
     if not input_image or not text_prompt or text_prompt.startswith("Please select"):
         return None
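For reference, this commit follows the standard ZeroGPU pattern: the model is built once at startup, and every function that needs CUDA is wrapped in @spaces.GPU so a GPU is attached only for the duration of that call. A minimal self-contained sketch of the pattern (the pipeline, prompt arguments, and duration value below are illustrative assumptions, not taken from app.py):

import spaces                      # ZeroGPU helper, available on Hugging Face Spaces
import torch
import gradio as gr
from diffusers import DiffusionPipeline

# Build the pipeline at import time; ZeroGPU intercepts the .to("cuda") call
# and replays it when a GPU is actually attached inside a @spaces.GPU function.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo", torch_dtype=torch.float16  # illustrative model choice
)
pipe.to("cuda")

@spaces.GPU(duration=60)  # hold the GPU for at most 60 s per call
def generate(prompt: str):
    # CUDA is available only inside this function; elsewhere the Space has no GPU.
    return pipe(prompt, num_inference_steps=2, guidance_scale=0.0).images[0]

gr.Interface(generate, gr.Textbox(label="Prompt"), gr.Image()).launch()

This presumably also motivates the commit's move of pipeline.to(...) out of the loader and up to module level as PIPELINE.to(DEVICE), where ZeroGPU can record the device transfer once at startup.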