Spaces:
Running on Zero
Running on Zero
changes
Browse files
app.py
CHANGED
|
@@ -9,7 +9,6 @@ import logging
|
|
| 9 |
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
|
| 10 |
from huggingface_hub import login
|
| 11 |
from diffusers.utils import load_image
|
| 12 |
-
#from lora_loading_patch import load_lora_into_transformer
|
| 13 |
import time
|
| 14 |
from datetime import datetime
|
| 15 |
from io import BytesIO
|
|
@@ -27,28 +26,28 @@ from huggingface_hub import hf_hub_download
|
|
| 27 |
from diffusers.quantizers import PipelineQuantizationConfig
|
| 28 |
from diffusers import (FluxPriorReduxPipeline, FluxInpaintPipeline, FluxFillPipeline, FluxKontextPipeline, FluxPipeline)
|
| 29 |
|
|
|
|
| 30 |
# Login Hugging Face Hub
|
| 31 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 32 |
login(token=HF_TOKEN)
|
| 33 |
import diffusers
|
| 34 |
|
|
|
|
| 35 |
# init
|
| 36 |
dtype = torch.bfloat16
|
| 37 |
device = "cuda:0"
|
| 38 |
|
|
|
|
| 39 |
base_model = "black-forest-labs/FLUX.1-Krea-dev"
|
| 40 |
|
| 41 |
-
# pipeline_quant_config = PipelineQuantizationConfig(
|
| 42 |
-
# quant_backend="bitsandbytes_4bit",
|
| 43 |
-
# quant_kwargs={"load_in_4bit": True, "bnb_4bit_quant_type": "nf4", "bnb_4bit_compute_dtype": torch.bfloat16},
|
| 44 |
-
# components_to_quantize=["transformer", "text_encoder_2"],
|
| 45 |
-
# )
|
| 46 |
|
| 47 |
txt2img_pipe = FluxKontextPipeline.from_pretrained(base_model, torch_dtype=dtype)
|
| 48 |
txt2img_pipe = txt2img_pipe.to(device)
|
| 49 |
|
|
|
|
| 50 |
MAX_SEED = 2**32 - 1
|
| 51 |
|
|
|
|
| 52 |
class calculateDuration:
|
| 53 |
def __init__(self, activity_name=""):
|
| 54 |
self.activity_name = activity_name
|
|
@@ -69,13 +68,14 @@ class calculateDuration:
|
|
| 69 |
else:
|
| 70 |
print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
|
| 71 |
|
|
|
|
| 72 |
def safe_trim_for_clip(text: str, max_words: int = 77) -> str:
|
| 73 |
-
# 简单按词裁,不破坏主 prompt。你也可以做更智能的关键词抽取。
|
| 74 |
tokens = re.split(r"\s+", text.strip())
|
| 75 |
if len(tokens) <= max_words:
|
| 76 |
return text
|
| 77 |
return " ".join(tokens[:max_words])
|
| 78 |
|
|
|
|
| 79 |
def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
|
| 80 |
with calculateDuration("Upload images"):
|
| 81 |
connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
|
|
@@ -94,17 +94,16 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
|
|
| 94 |
buffer.seek(0)
|
| 95 |
s3.upload_fileobj(buffer, bucket_name, image_file)
|
| 96 |
print("upload finish", image_file)
|
| 97 |
-
|
|
|
|
| 98 |
thumbnail = image.copy()
|
| 99 |
thumbnail_width = 256
|
| 100 |
aspect_ratio = image.height / image.width
|
| 101 |
thumbnail_height = int(thumbnail_width * aspect_ratio)
|
| 102 |
thumbnail = thumbnail.resize((thumbnail_width, thumbnail_height), Image.LANCZOS)
|
| 103 |
|
| 104 |
-
# Generate the thumbnail image filename
|
| 105 |
thumbnail_file = image_file.replace(".png", "_thumbnail.png")
|
| 106 |
|
| 107 |
-
# Save thumbnail to buffer and upload
|
| 108 |
thumbnail_buffer = BytesIO()
|
| 109 |
thumbnail.save(thumbnail_buffer, "PNG")
|
| 110 |
thumbnail_buffer.seek(0)
|
|
@@ -112,9 +111,11 @@ def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
|
|
| 112 |
print("upload thumbnail finish", thumbnail_file)
|
| 113 |
return image_file
|
| 114 |
|
|
|
|
| 115 |
def generate_random_4_digit_string():
|
| 116 |
return ''.join(random.choices(string.digits, k=4))
|
| 117 |
|
|
|
|
| 118 |
@spaces.GPU(duration=120)
|
| 119 |
def run_lora(
|
| 120 |
prompt,
|
|
@@ -152,7 +153,6 @@ def run_lora(
|
|
| 152 |
try:
|
| 153 |
pipe.unload_lora_weights()
|
| 154 |
except Exception as _:
|
| 155 |
-
# 某些版本上未加载时调用可能抛异常,忽略
|
| 156 |
pass
|
| 157 |
|
| 158 |
adapter_names = []
|
|
@@ -170,7 +170,6 @@ def run_lora(
|
|
| 170 |
for lora_info in lora_configs:
|
| 171 |
repo = lora_info.get("repo")
|
| 172 |
weights = lora_info.get("weights")
|
| 173 |
-
# 优先使用用户提供的 adapter_name;没有则随机
|
| 174 |
adapter_name = lora_info.get("adapter_name") or f"adp_{generate_random_4_digit_string()}"
|
| 175 |
weight = float(lora_info.get("adapter_weight", 1.0))
|
| 176 |
if not (repo and weights):
|
|
@@ -178,7 +177,6 @@ def run_lora(
|
|
| 178 |
continue
|
| 179 |
try:
|
| 180 |
weight_path = hf_hub_download(repo_id=repo, filename=weights)
|
| 181 |
-
# 关键修复:prefix=None,避免仅在 text_encoder 查找
|
| 182 |
pipe.load_lora_weights(weight_path, adapter_name=adapter_name, prefix=None)
|
| 183 |
adapter_names.append(adapter_name)
|
| 184 |
adapter_weights.append(weight)
|
|
@@ -187,8 +185,6 @@ def run_lora(
|
|
| 187 |
|
| 188 |
if adapter_names:
|
| 189 |
pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
|
| 190 |
-
# 可选:融合后推理更快,但无法动态调整权重
|
| 191 |
-
# pipe.fuse_lora(adapter_names=adapter_names)
|
| 192 |
try:
|
| 193 |
active = pipe.get_active_adapters() if hasattr(pipe, "get_active_adapters") else []
|
| 194 |
print("Active adapters:", active)
|
|
@@ -202,15 +198,15 @@ def run_lora(
|
|
| 202 |
lora_layer_count += 1
|
| 203 |
print(f"[DEBUG] transformer LoRA layers: {lora_layer_count}")
|
| 204 |
|
| 205 |
-
# 若层数为 0,给出直观警告
|
| 206 |
if lora_layer_count == 0:
|
| 207 |
gr.Warning("LoRA seems not injected (0 layers on transformer). Check whether the LoRA is trained for FLUX and `prefix=None` is set.")
|
| 208 |
|
| 209 |
-
|
| 210 |
pipe.enable_vae_slicing()
|
| 211 |
clip_side_prompt = safe_trim_for_clip(prompt, max_words=77)
|
| 212 |
init_image = None
|
| 213 |
error_message = ""
|
|
|
|
|
|
|
| 214 |
try:
|
| 215 |
gr.Info("Start to generate images ...")
|
| 216 |
joint_attention_kwargs = {"scale": 1}
|
|
@@ -246,25 +242,25 @@ def run_lora(
|
|
| 246 |
progress(100, "Completed!")
|
| 247 |
|
| 248 |
# CHANGED: Return both image AND json
|
| 249 |
-
return image, json.dumps(result)
|
| 250 |
-
|
| 251 |
-
# Gradio interface
|
| 252 |
-
|
| 253 |
|
| 254 |
|
|
|
|
| 255 |
with gr.Blocks() as demo:
|
| 256 |
-
gr.Markdown("flux-dev-multi-lora")
|
| 257 |
with gr.Row():
|
| 258 |
|
| 259 |
with gr.Column():
|
| 260 |
-
|
| 261 |
prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=10)
|
| 262 |
-
lora_strings_json = gr.Text(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
image_url = gr.Text(label="Image url", placeholder="Enter image url to enable image to image model", lines=1)
|
| 264 |
run_button = gr.Button("Run", scale=0)
|
| 265 |
|
| 266 |
with gr.Accordion("Advanced Settings", open=False):
|
| 267 |
-
|
| 268 |
with gr.Row():
|
| 269 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
|
| 270 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
|
@@ -284,14 +280,16 @@ with gr.Blocks() as demo:
|
|
| 284 |
secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
|
| 285 |
bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
|
| 286 |
|
| 287 |
-
|
| 288 |
with gr.Column():
|
|
|
|
|
|
|
| 289 |
json_text = gr.Text(label="Result JSON")
|
| 290 |
|
| 291 |
gr.Markdown("**Disclaimer:**")
|
| 292 |
gr.Markdown(
|
| 293 |
"This demo is only for research purpose. This space owner cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. This space owner provides the tools, but the responsibility for their use lies with the individual user."
|
| 294 |
)
|
|
|
|
| 295 |
inputs = [
|
| 296 |
prompt,
|
| 297 |
image_url,
|
|
@@ -310,7 +308,8 @@ with gr.Blocks() as demo:
|
|
| 310 |
bucket
|
| 311 |
]
|
| 312 |
|
| 313 |
-
|
|
|
|
| 314 |
|
| 315 |
run_button.click(
|
| 316 |
fn=run_lora,
|
|
@@ -321,4 +320,4 @@ with gr.Blocks() as demo:
|
|
| 321 |
try:
|
| 322 |
demo.queue().launch()
|
| 323 |
except:
|
| 324 |
-
|
|
|
|
| 9 |
from diffusers import DiffusionPipeline, AutoencoderTiny, AutoencoderKL, AutoPipelineForImage2Image
|
| 10 |
from huggingface_hub import login
|
| 11 |
from diffusers.utils import load_image
|
|
|
|
| 12 |
import time
|
| 13 |
from datetime import datetime
|
| 14 |
from io import BytesIO
|
|
|
|
| 26 |
from diffusers.quantizers import PipelineQuantizationConfig
|
| 27 |
from diffusers import (FluxPriorReduxPipeline, FluxInpaintPipeline, FluxFillPipeline, FluxKontextPipeline, FluxPipeline)
|
| 28 |
|
| 29 |
+
|
| 30 |
# Login Hugging Face Hub
|
| 31 |
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 32 |
login(token=HF_TOKEN)
|
| 33 |
import diffusers
|
| 34 |
|
| 35 |
+
|
| 36 |
# init
|
| 37 |
dtype = torch.bfloat16
|
| 38 |
device = "cuda:0"
|
| 39 |
|
| 40 |
+
|
| 41 |
base_model = "black-forest-labs/FLUX.1-Krea-dev"
|
| 42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
txt2img_pipe = FluxKontextPipeline.from_pretrained(base_model, torch_dtype=dtype)
|
| 45 |
txt2img_pipe = txt2img_pipe.to(device)
|
| 46 |
|
| 47 |
+
|
| 48 |
MAX_SEED = 2**32 - 1
|
| 49 |
|
| 50 |
+
|
| 51 |
class calculateDuration:
|
| 52 |
def __init__(self, activity_name=""):
|
| 53 |
self.activity_name = activity_name
|
|
|
|
| 68 |
else:
|
| 69 |
print(f"Elapsed time: {self.elapsed_time:.6f} seconds")
|
| 70 |
|
| 71 |
+
|
| 72 |
def safe_trim_for_clip(text: str, max_words: int = 77) -> str:
|
|
|
|
| 73 |
tokens = re.split(r"\s+", text.strip())
|
| 74 |
if len(tokens) <= max_words:
|
| 75 |
return text
|
| 76 |
return " ".join(tokens[:max_words])
|
| 77 |
|
| 78 |
+
|
| 79 |
def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
|
| 80 |
with calculateDuration("Upload images"):
|
| 81 |
connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
|
|
|
|
| 94 |
buffer.seek(0)
|
| 95 |
s3.upload_fileobj(buffer, bucket_name, image_file)
|
| 96 |
print("upload finish", image_file)
|
| 97 |
+
|
| 98 |
+
# Generate thumbnail
|
| 99 |
thumbnail = image.copy()
|
| 100 |
thumbnail_width = 256
|
| 101 |
aspect_ratio = image.height / image.width
|
| 102 |
thumbnail_height = int(thumbnail_width * aspect_ratio)
|
| 103 |
thumbnail = thumbnail.resize((thumbnail_width, thumbnail_height), Image.LANCZOS)
|
| 104 |
|
|
|
|
| 105 |
thumbnail_file = image_file.replace(".png", "_thumbnail.png")
|
| 106 |
|
|
|
|
| 107 |
thumbnail_buffer = BytesIO()
|
| 108 |
thumbnail.save(thumbnail_buffer, "PNG")
|
| 109 |
thumbnail_buffer.seek(0)
|
|
|
|
| 111 |
print("upload thumbnail finish", thumbnail_file)
|
| 112 |
return image_file
|
| 113 |
|
| 114 |
+
|
| 115 |
def generate_random_4_digit_string():
|
| 116 |
return ''.join(random.choices(string.digits, k=4))
|
| 117 |
|
| 118 |
+
|
| 119 |
@spaces.GPU(duration=120)
|
| 120 |
def run_lora(
|
| 121 |
prompt,
|
|
|
|
| 153 |
try:
|
| 154 |
pipe.unload_lora_weights()
|
| 155 |
except Exception as _:
|
|
|
|
| 156 |
pass
|
| 157 |
|
| 158 |
adapter_names = []
|
|
|
|
| 170 |
for lora_info in lora_configs:
|
| 171 |
repo = lora_info.get("repo")
|
| 172 |
weights = lora_info.get("weights")
|
|
|
|
| 173 |
adapter_name = lora_info.get("adapter_name") or f"adp_{generate_random_4_digit_string()}"
|
| 174 |
weight = float(lora_info.get("adapter_weight", 1.0))
|
| 175 |
if not (repo and weights):
|
|
|
|
| 177 |
continue
|
| 178 |
try:
|
| 179 |
weight_path = hf_hub_download(repo_id=repo, filename=weights)
|
|
|
|
| 180 |
pipe.load_lora_weights(weight_path, adapter_name=adapter_name, prefix=None)
|
| 181 |
adapter_names.append(adapter_name)
|
| 182 |
adapter_weights.append(weight)
|
|
|
|
| 185 |
|
| 186 |
if adapter_names:
|
| 187 |
pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
|
|
|
|
|
|
|
| 188 |
try:
|
| 189 |
active = pipe.get_active_adapters() if hasattr(pipe, "get_active_adapters") else []
|
| 190 |
print("Active adapters:", active)
|
|
|
|
| 198 |
lora_layer_count += 1
|
| 199 |
print(f"[DEBUG] transformer LoRA layers: {lora_layer_count}")
|
| 200 |
|
|
|
|
| 201 |
if lora_layer_count == 0:
|
| 202 |
gr.Warning("LoRA seems not injected (0 layers on transformer). Check whether the LoRA is trained for FLUX and `prefix=None` is set.")
|
| 203 |
|
|
|
|
| 204 |
pipe.enable_vae_slicing()
|
| 205 |
clip_side_prompt = safe_trim_for_clip(prompt, max_words=77)
|
| 206 |
init_image = None
|
| 207 |
error_message = ""
|
| 208 |
+
image = None
|
| 209 |
+
|
| 210 |
try:
|
| 211 |
gr.Info("Start to generate images ...")
|
| 212 |
joint_attention_kwargs = {"scale": 1}
|
|
|
|
| 242 |
progress(100, "Completed!")
|
| 243 |
|
| 244 |
# CHANGED: Return both image AND json
|
| 245 |
+
return image, json.dumps(result)
|
|
|
|
|
|
|
|
|
|
| 246 |
|
| 247 |
|
| 248 |
+
# Gradio interface
|
| 249 |
with gr.Blocks() as demo:
|
| 250 |
+
gr.Markdown("# flux-dev-multi-lora")
|
| 251 |
with gr.Row():
|
| 252 |
|
| 253 |
with gr.Column():
|
|
|
|
| 254 |
prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=10)
|
| 255 |
+
lora_strings_json = gr.Text(
|
| 256 |
+
label="LoRA Configs (JSON List String)",
|
| 257 |
+
placeholder='[{"repo": "lora_repo1", "weights": "weights1.safetensors", "adapter_name": "adapter1", "adapter_weight": 1.0}]',
|
| 258 |
+
lines=5
|
| 259 |
+
)
|
| 260 |
image_url = gr.Text(label="Image url", placeholder="Enter image url to enable image to image model", lines=1)
|
| 261 |
run_button = gr.Button("Run", scale=0)
|
| 262 |
|
| 263 |
with gr.Accordion("Advanced Settings", open=False):
|
|
|
|
| 264 |
with gr.Row():
|
| 265 |
seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0, randomize=True)
|
| 266 |
randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
|
|
|
|
| 280 |
secret_key = gr.Textbox(label="Secret Key", placeholder="Enter R2 secret key here")
|
| 281 |
bucket = gr.Textbox(label="Bucket Name", placeholder="Enter R2 bucket name here")
|
| 282 |
|
|
|
|
| 283 |
with gr.Column():
|
| 284 |
+
# CHANGED: Add image output
|
| 285 |
+
output_image = gr.Image(label="Generated Image", type="pil")
|
| 286 |
json_text = gr.Text(label="Result JSON")
|
| 287 |
|
| 288 |
gr.Markdown("**Disclaimer:**")
|
| 289 |
gr.Markdown(
|
| 290 |
"This demo is only for research purpose. This space owner cannot be held responsible for the generation of NSFW (Not Safe For Work) content through the use of this demo. Users are solely responsible for any content they create, and it is their obligation to ensure that it adheres to appropriate and ethical standards. This space owner provides the tools, but the responsibility for their use lies with the individual user."
|
| 291 |
)
|
| 292 |
+
|
| 293 |
inputs = [
|
| 294 |
prompt,
|
| 295 |
image_url,
|
|
|
|
| 308 |
bucket
|
| 309 |
]
|
| 310 |
|
| 311 |
+
# CHANGED: Two outputs now
|
| 312 |
+
outputs = [output_image, json_text]
|
| 313 |
|
| 314 |
run_button.click(
|
| 315 |
fn=run_lora,
|
|
|
|
| 320 |
try:
|
| 321 |
demo.queue().launch()
|
| 322 |
except:
|
| 323 |
+
print("demo exception ...")
|