Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,38 +18,26 @@ MAX_SEED = np.iinfo(np.int32).max
|
|
| 18 |
|
| 19 |
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
|
| 20 |
|
| 21 |
-
# Load LoRA data
|
| 22 |
-
|
| 23 |
-
with open("flux_loras.json", "r") as file:
|
| 24 |
data = json.load(file)
|
|
|
|
| 25 |
flux_loras_raw = [
|
| 26 |
{
|
| 27 |
"image": item["image"],
|
| 28 |
"title": item["title"],
|
| 29 |
"repo": item["repo"],
|
| 30 |
-
"trigger_word": item.get("trigger_word", ""),
|
| 31 |
-
"trigger_position": item.get("trigger_position", "prepend"),
|
| 32 |
"weights": item.get("weights", "pytorch_lora_weights.safetensors"),
|
|
|
|
|
|
|
|
|
|
| 33 |
"lora_type": item.get("lora_type", "flux"),
|
| 34 |
-
"lora_scale_config": item.get("lora_scale", 1.
|
| 35 |
-
"prompt_placeholder": item.get("prompt_placeholder", ""),
|
| 36 |
}
|
| 37 |
for item in data
|
| 38 |
]
|
| 39 |
-
print(f"Loaded {len(flux_loras_raw)} LoRAs from
|
| 40 |
-
# Global variables for LoRA management
|
| 41 |
-
lora_cache = {}
|
| 42 |
-
|
| 43 |
-
def load_lora_weights(repo_id, weights_filename):
|
| 44 |
-
"""Load LoRA weights from HuggingFace"""
|
| 45 |
-
try:
|
| 46 |
-
if repo_id not in lora_cache:
|
| 47 |
-
lora_path = hf_hub_download(repo_id=repo_id, filename=weights_filename)
|
| 48 |
-
lora_cache[repo_id] = lora_path
|
| 49 |
-
return lora_cache[repo_id]
|
| 50 |
-
except Exception as e:
|
| 51 |
-
print(f"Error loading LoRA from {repo_id}: {e}")
|
| 52 |
-
return None
|
| 53 |
|
| 54 |
def update_selection(selected_state: gr.SelectData, flux_loras):
|
| 55 |
"""Update UI when a LoRA is selected"""
|
|
@@ -57,148 +45,73 @@ def update_selection(selected_state: gr.SelectData, flux_loras):
|
|
| 57 |
return "### No LoRA selected", gr.update(), None, gr.update()
|
| 58 |
|
| 59 |
lora_repo = flux_loras[selected_state.index]["repo"]
|
| 60 |
-
trigger_word = flux_loras[selected_state.index]["trigger_word"]
|
| 61 |
|
| 62 |
updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
|
| 63 |
config_placeholder = flux_loras[selected_state.index]["prompt_placeholder"]
|
| 64 |
-
if config_placeholder:
|
| 65 |
-
new_placeholder = config_placeholder
|
| 66 |
-
else:
|
| 67 |
-
new_placeholder = f"opt - describe the person/subject, e.g. 'a man with glasses and a beard'"
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
print("Optimal Scale: ", optimal_scale)
|
| 72 |
-
return updated_text, gr.update(placeholder=
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
def get_huggingface_lora(link):
|
| 76 |
-
"""Download LoRA from HuggingFace link"""
|
| 77 |
-
split_link = link.split("/")
|
| 78 |
-
if len(split_link) == 2:
|
| 79 |
-
try:
|
| 80 |
-
model_card = ModelCard.load(link)
|
| 81 |
-
trigger_word = model_card.data.get("instance_prompt", "")
|
| 82 |
-
|
| 83 |
-
fs = HfFileSystem()
|
| 84 |
-
list_of_files = fs.ls(link, detail=False)
|
| 85 |
-
safetensors_file = None
|
| 86 |
-
|
| 87 |
-
for file in list_of_files:
|
| 88 |
-
if file.endswith(".safetensors") and "lora" in file.lower():
|
| 89 |
-
safetensors_file = file.split("/")[-1]
|
| 90 |
-
break
|
| 91 |
-
|
| 92 |
-
if not safetensors_file:
|
| 93 |
-
safetensors_file = "pytorch_lora_weights.safetensors"
|
| 94 |
-
|
| 95 |
-
return split_link[1], safetensors_file, trigger_word
|
| 96 |
-
except Exception as e:
|
| 97 |
-
raise Exception(f"Error loading LoRA: {e}")
|
| 98 |
-
else:
|
| 99 |
-
raise Exception("Invalid HuggingFace repository format")
|
| 100 |
-
|
| 101 |
-
def load_custom_lora(link):
|
| 102 |
-
"""Load custom LoRA from user input"""
|
| 103 |
-
if not link:
|
| 104 |
-
return gr.update(visible=False), "", gr.update(visible=False), None, gr.Gallery(selected_index=None), "### Click on a LoRA in the gallery to select it", None
|
| 105 |
-
|
| 106 |
-
try:
|
| 107 |
-
repo_name, weights_file, trigger_word = get_huggingface_lora(link)
|
| 108 |
-
|
| 109 |
-
card = f'''
|
| 110 |
-
<div style="border: 1px solid #ddd; padding: 10px; border-radius: 8px; margin: 10px 0;">
|
| 111 |
-
<span><strong>Loaded custom LoRA:</strong></span>
|
| 112 |
-
<div style="margin-top: 8px;">
|
| 113 |
-
<h4>{repo_name}</h4>
|
| 114 |
-
<small>{"Using: <code><b>"+trigger_word+"</b></code> as trigger word" if trigger_word else "No trigger word found"}</small>
|
| 115 |
-
</div>
|
| 116 |
-
</div>
|
| 117 |
-
'''
|
| 118 |
-
|
| 119 |
-
custom_lora_data = {
|
| 120 |
-
"repo": link,
|
| 121 |
-
"weights": weights_file,
|
| 122 |
-
"trigger_word": trigger_word
|
| 123 |
-
}
|
| 124 |
-
|
| 125 |
-
return gr.update(visible=True), card, gr.update(visible=True), custom_lora_data, gr.Gallery(selected_index=None), f"Custom: {repo_name}", None
|
| 126 |
-
|
| 127 |
-
except Exception as e:
|
| 128 |
-
return gr.update(visible=True), f"Error: {str(e)}", gr.update(visible=False), None, gr.update(), "### Click on a LoRA in the gallery to select it", None
|
| 129 |
-
|
| 130 |
-
def remove_custom_lora():
|
| 131 |
-
"""Remove custom LoRA"""
|
| 132 |
-
return "", gr.update(visible=False), gr.update(visible=False), None, None
|
| 133 |
-
|
| 134 |
-
def classify_gallery(flux_loras):
|
| 135 |
-
"""Sort gallery by likes"""
|
| 136 |
-
sorted_gallery = sorted(flux_loras, key=lambda x: x.get("likes", 0), reverse=True)
|
| 137 |
-
return [(item["image"], item["title"]) for item in sorted_gallery], sorted_gallery
|
| 138 |
|
|
|
|
| 139 |
def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75,portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
|
| 140 |
"""Wrapper function to handle state serialization"""
|
| 141 |
-
|
|
|
|
| 142 |
|
| 143 |
@spaces.GPU
|
| 144 |
-
def infer_with_lora(input_image, prompt, selected_index,
|
| 145 |
"""Generate image with selected LoRA"""
|
| 146 |
global pipe
|
| 147 |
|
| 148 |
if randomize_seed:
|
| 149 |
seed = random.randint(0, MAX_SEED)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
-
# Determine which LoRA to use
|
| 152 |
lora_to_use = None
|
| 153 |
-
if
|
| 154 |
-
lora_to_use = custom_lora
|
| 155 |
-
elif selected_index is not None and flux_loras and selected_index < len(flux_loras):
|
| 156 |
lora_to_use = flux_loras[selected_index]
|
| 157 |
-
|
| 158 |
-
# Load LoRA if needed
|
| 159 |
-
print(f"LoRA to use: {lora_to_use}")
|
| 160 |
if lora_to_use:
|
|
|
|
| 161 |
try:
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
except Exception as e:
|
| 172 |
print(f"Error loading LoRA: {e}")
|
| 173 |
-
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
| 175 |
input_image = input_image.convert("RGB")
|
| 176 |
-
|
| 177 |
-
trigger_word = lora_to_use["trigger_word"]
|
| 178 |
-
is_kontext_lora = lora_to_use["lora_type"] == "kontext"
|
| 179 |
-
if not is_kontext_lora:
|
| 180 |
-
if portrait_mode:
|
| 181 |
-
if trigger_word == ", How2Draw":
|
| 182 |
-
prompt = f"create a How2Draw sketch of the person of the photo {prompt}, maintain the facial identity of the person and general features"
|
| 183 |
-
elif trigger_word == ", video game screenshot in the style of THSMS":
|
| 184 |
-
prompt = f"create a video game screenshot in the style of THSMS with the person from the photo, {prompt}. maintain the facial identity of the person and general features"
|
| 185 |
-
else:
|
| 186 |
-
prompt = f"convert the style of this portrait photo to {trigger_word} while maintaining the identity of the person. {prompt}. Make sure to maintain the person's facial identity and features, while still changing the overall style to {trigger_word}."
|
| 187 |
-
else:
|
| 188 |
-
if trigger_word == ", How2Draw":
|
| 189 |
-
prompt = f"create a How2Draw sketch of the photo {prompt}"
|
| 190 |
-
elif trigger_word == ", video game screenshot in the style of THSMS":
|
| 191 |
-
prompt = f"create a video game screenshot in the style of THSMS of the photo, {prompt}."
|
| 192 |
-
else:
|
| 193 |
-
prompt = f"convert the style of this photo {prompt} to {trigger_word}."
|
| 194 |
-
else:
|
| 195 |
-
prompt = f"{trigger_word}. {prompt}."
|
| 196 |
try:
|
| 197 |
image = pipe(
|
| 198 |
image=input_image,
|
| 199 |
width=input_image.size[0],
|
| 200 |
height=input_image.size[1],
|
| 201 |
-
prompt=
|
| 202 |
guidance_scale=guidance_scale,
|
| 203 |
generator=torch.Generator().manual_seed(seed)
|
| 204 |
).images[0]
|
|
@@ -253,12 +166,13 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
|
|
| 253 |
gr_flux_loras = gr.State(value=flux_loras_raw)
|
| 254 |
|
| 255 |
title = gr.HTML(
|
| 256 |
-
"""<h1><img src="https://huggingface.co/spaces/kontext-community/FLUX.1-Kontext-portrait/resolve/main/dora_kontext.png" alt="LoRA">
|
| 257 |
elem_id="title",
|
| 258 |
)
|
| 259 |
-
gr.Markdown("
|
| 260 |
|
| 261 |
selected_state = gr.State(value=None)
|
|
|
|
| 262 |
custom_loaded_lora = gr.State(value=None)
|
| 263 |
lora_state = gr.State(value=1.0)
|
| 264 |
|
|
@@ -331,16 +245,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
|
|
| 331 |
)
|
| 332 |
|
| 333 |
# Event handlers
|
| 334 |
-
|
| 335 |
-
fn=load_custom_lora,
|
| 336 |
-
inputs=[custom_model],
|
| 337 |
-
outputs=[custom_model_card, custom_model_card, custom_model_button, custom_loaded_lora, gallery, prompt_title, selected_state],
|
| 338 |
-
)
|
| 339 |
-
|
| 340 |
-
custom_model_button.click(
|
| 341 |
-
fn=remove_custom_lora,
|
| 342 |
-
outputs=[custom_model, custom_model_button, custom_model_card, custom_loaded_lora, selected_state]
|
| 343 |
-
)
|
| 344 |
|
| 345 |
gallery.select(
|
| 346 |
fn=update_selection,
|
|
@@ -364,8 +269,7 @@ with gr.Blocks(css=css, theme=gr.themes.Ocean(font=[gr.themes.GoogleFont("Lexend
|
|
| 364 |
|
| 365 |
# Initialize gallery
|
| 366 |
demo.load(
|
| 367 |
-
fn=
|
| 368 |
-
inputs=[gr_flux_loras],
|
| 369 |
outputs=[gallery, gr_flux_loras]
|
| 370 |
)
|
| 371 |
|
|
|
|
| 18 |
|
| 19 |
pipe = FluxKontextPipeline.from_pretrained("black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16).to("cuda")
|
| 20 |
|
| 21 |
+
# Load LoRA data from our custom JSON file
|
| 22 |
+
with open("kontext_loras.json", "r") as file:
|
|
|
|
| 23 |
data = json.load(file)
|
| 24 |
+
# Add default values for keys that might be missing, to prevent errors
|
| 25 |
flux_loras_raw = [
|
| 26 |
{
|
| 27 |
"image": item["image"],
|
| 28 |
"title": item["title"],
|
| 29 |
"repo": item["repo"],
|
|
|
|
|
|
|
| 30 |
"weights": item.get("weights", "pytorch_lora_weights.safetensors"),
|
| 31 |
+
# The following keys are kept for compatibility with the original demo structure,
|
| 32 |
+
# but our simplified logic doesn't heavily rely on them.
|
| 33 |
+
"trigger_word": item.get("trigger_word", ""),
|
| 34 |
"lora_type": item.get("lora_type", "flux"),
|
| 35 |
+
"lora_scale_config": item.get("lora_scale", 1.0), # Default scale set to 1.0
|
| 36 |
+
"prompt_placeholder": item.get("prompt_placeholder", "Describe the subject..."),
|
| 37 |
}
|
| 38 |
for item in data
|
| 39 |
]
|
| 40 |
+
print(f"Loaded {len(flux_loras_raw)} LoRAs from kontext_loras.json")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def update_selection(selected_state: gr.SelectData, flux_loras):
|
| 43 |
"""Update UI when a LoRA is selected"""
|
|
|
|
| 45 |
return "### No LoRA selected", gr.update(), None, gr.update()
|
| 46 |
|
| 47 |
lora_repo = flux_loras[selected_state.index]["repo"]
|
|
|
|
| 48 |
|
| 49 |
updated_text = f"### Selected: [{lora_repo}](https://huggingface.co/{lora_repo})"
|
| 50 |
config_placeholder = flux_loras[selected_state.index]["prompt_placeholder"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
+
optimal_scale = flux_loras[selected_state.index].get("lora_scale_config", 1.0)
|
| 53 |
+
print("Selected Style: ", flux_loras[selected_state.index]['title'])
|
| 54 |
print("Optimal Scale: ", optimal_scale)
|
| 55 |
+
return updated_text, gr.update(placeholder=config_placeholder), selected_state.index, optimal_scale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
+
# This wrapper is kept for compatibility with the Gradio event triggers
|
| 58 |
def infer_with_lora_wrapper(input_image, prompt, selected_index, lora_state, custom_lora, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.75,portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
|
| 59 |
"""Wrapper function to handle state serialization"""
|
| 60 |
+
# The 'custom_lora' and 'lora_state' arguments are no longer used but kept in the signature
|
| 61 |
+
return infer_with_lora(input_image, prompt, selected_index, seed, randomize_seed, guidance_scale, lora_scale, portrait_mode, flux_loras, progress)
|
| 62 |
|
| 63 |
@spaces.GPU
|
| 64 |
+
def infer_with_lora(input_image, prompt, selected_index, seed=42, randomize_seed=False, guidance_scale=2.5, lora_scale=1.0, portrait_mode=False, flux_loras=None, progress=gr.Progress(track_tqdm=True)):
|
| 65 |
"""Generate image with selected LoRA"""
|
| 66 |
global pipe
|
| 67 |
|
| 68 |
if randomize_seed:
|
| 69 |
seed = random.randint(0, MAX_SEED)
|
| 70 |
+
|
| 71 |
+
# Unload any previous LoRA to ensure a clean state
|
| 72 |
+
if "selected_lora" in pipe.get_active_adapters():
|
| 73 |
+
pipe.unload_lora_weights()
|
| 74 |
|
| 75 |
+
# Determine which LoRA to use from our gallery
|
| 76 |
lora_to_use = None
|
| 77 |
+
if selected_index is not None and flux_loras and selected_index < len(flux_loras):
|
|
|
|
|
|
|
| 78 |
lora_to_use = flux_loras[selected_index]
|
| 79 |
+
|
|
|
|
|
|
|
| 80 |
if lora_to_use:
|
| 81 |
+
print(f"Applying LoRA: {lora_to_use['title']}")
|
| 82 |
try:
|
| 83 |
+
# Load LoRA directly from the Hugging Face Hub
|
| 84 |
+
pipe.load_lora_weights(
|
| 85 |
+
lora_to_use["repo"],
|
| 86 |
+
weight_name=lora_to_use["weights"],
|
| 87 |
+
adapter_name="selected_lora"
|
| 88 |
+
)
|
| 89 |
+
pipe.set_adapters(["selected_lora"], adapter_weights=[lora_scale])
|
| 90 |
+
print(f"Loaded {lora_to_use['repo']} with scale {lora_scale}")
|
| 91 |
|
| 92 |
+
# Simplified and direct prompt construction
|
| 93 |
+
style_name = lora_to_use['title']
|
| 94 |
+
if prompt:
|
| 95 |
+
final_prompt = f"Turn this image of {prompt} into {style_name} style."
|
| 96 |
+
else:
|
| 97 |
+
final_prompt = f"Turn this image into {style_name} style."
|
| 98 |
+
print(f"Using prompt: {final_prompt}")
|
| 99 |
+
|
| 100 |
except Exception as e:
|
| 101 |
print(f"Error loading LoRA: {e}")
|
| 102 |
+
final_prompt = prompt # Fallback to user prompt if LoRA fails
|
| 103 |
+
else:
|
| 104 |
+
# No LoRA selected, just use the original prompt
|
| 105 |
+
final_prompt = prompt
|
| 106 |
+
|
| 107 |
input_image = input_image.convert("RGB")
|
| 108 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
try:
|
| 110 |
image = pipe(
|
| 111 |
image=input_image,
|
| 112 |
width=input_image.size[0],
|
| 113 |
height=input_image.size[1],
|
| 114 |
+
prompt=final_prompt,
|
| 115 |
guidance_scale=guidance_scale,
|
| 116 |
generator=torch.Generator().manual_seed(seed)
|
| 117 |
).images[0]
|
|
|
|
| 166 |
gr_flux_loras = gr.State(value=flux_loras_raw)
|
| 167 |
|
| 168 |
title = gr.HTML(
|
| 169 |
+
"""<h1><img src="https://huggingface.co/spaces/kontext-community/FLUX.1-Kontext-portrait/resolve/main/dora_kontext.png" alt="LoRA"> Kontext-Style LoRA Explorer</h1>""",
|
| 170 |
elem_id="title",
|
| 171 |
)
|
| 172 |
+
gr.Markdown("A demo for the style LoRAs from the [Kontext-Style Collection](https://huggingface.co/Kontext-Style) 🤗")
|
| 173 |
|
| 174 |
selected_state = gr.State(value=None)
|
| 175 |
+
# The following states are no longer used by the simplified logic but kept for component structure
|
| 176 |
custom_loaded_lora = gr.State(value=None)
|
| 177 |
lora_state = gr.State(value=1.0)
|
| 178 |
|
|
|
|
| 245 |
)
|
| 246 |
|
| 247 |
# Event handlers
|
| 248 |
+
# The custom model inputs are no longer needed as we've hidden them.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
|
| 250 |
gallery.select(
|
| 251 |
fn=update_selection,
|
|
|
|
| 269 |
|
| 270 |
# Initialize gallery
|
| 271 |
demo.load(
|
| 272 |
+
fn=lambda: (flux_loras_raw, flux_loras_raw),
|
|
|
|
| 273 |
outputs=[gallery, gr_flux_loras]
|
| 274 |
)
|
| 275 |
|