import os
import dataclasses
import functools
import json
import traceback
from pathlib import Path
from dataclasses import dataclass, asdict
from typing import Optional
import gradio as gr
from PIL import Image
import modules.scripts
from modules import script_callbacks
from modules.ui_components import FormGroup, FormRow
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, get_fixed_seed
from scripts.fabric_utils import WebUiComponents, image_hash
from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
# Compatibility with WebUI v1.3.0 and earlier versions
try:
# WebUI v1.4.0+
from modules.ui_common import create_refresh_button
except ImportError:
# Earlier versions
from modules.ui import create_refresh_button
__version__ = "0.6.6"

# Debug mode is switched on by setting the DEBUG environment variable to
# "true" or "1" (case-insensitive).
DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")

# Locations (joined onto modules.scripts.basedir() by the helpers below)
# where feedback images and presets are stored.
OUTPUT_PATH = "log/fabric/images"
PRESET_PATH = "log/fabric/presets"

print(
    f"WARNING: Loading FABRIC v{__version__} in DEBUG mode"
    if DEBUG
    else f"Loading FABRIC v{__version__}"
)

# Gradio 3.32 bug fix:
# fixes FileNotFoundError when displaying PIL images in Gradio Gallery by
# making sure the "gradio" directory inside the system temp dir exists.
import tempfile

gradio_tempfile_path = os.path.join(tempfile.gettempdir(), 'gradio')
os.makedirs(gradio_tempfile_path, exist_ok=True)
def use_feedback(params):
    """Return True when the given parameters call for feedback to be applied.

    Feedback is skipped (False) when any of the following holds:

    * ``params.enabled`` is falsy,
    * ``start >= end`` while ``min_weight <= 0``,
    * ``max_weight <= 0``,
    * ``neg_scale <= 0`` and there are no positive images,
    * there are neither positive nor negative images.
    """
    if not params.enabled:
        return False

    # Checks are evaluated lazily, in the order listed in the docstring.
    disqualified = (
        (params.start >= params.end and params.min_weight <= 0)
        or params.max_weight <= 0
        or (params.neg_scale <= 0 and len(params.pos_images) == 0)
        or (len(params.pos_images) == 0 and len(params.neg_images) == 0)
    )
    return not disqualified
def save_feedback_image(img, filename=None, base_path=OUTPUT_PATH):
    """Save *img* under ``<basedir>/<base_path>`` and return the filename used.

    When *filename* is None, the name is ``image_hash(img)`` with a ``.png``
    suffix. Missing parent directories are created.
    """
    name = filename if filename is not None else image_hash(img) + ".png"
    target = Path(modules.scripts.basedir(), base_path, name)
    target.parent.mkdir(parents=True, exist_ok=True)
    img.save(target)
    return name
@functools.lru_cache(maxsize=128)
def load_feedback_image(filename, base_path=OUTPUT_PATH):
    """Open the stored feedback image *filename* with PIL.

    Results are memoised by call arguments (up to 128 entries), so repeated
    calls with the same arguments return the same image object.
    """
    root = modules.scripts.basedir()
    return Image.open(Path(root, base_path, filename))
def full_image_path(filename, base_path=OUTPUT_PATH):
    """Return ``<basedir>/<base_path>/<filename>`` as a string."""
    return str(Path(modules.scripts.basedir(), base_path, filename))
# helper functions for loading saved params
def _load_feedback_paths(d, key):
    """Read a list of feedback image filenames stored under *key* in *d*.

    The stored value is a string; single quotes are swapped for double quotes
    before it is parsed as JSON. On any failure the traceback, the dict and
    the offending value are printed and an empty list is used instead.
    Only filenames whose file exists on disk are returned.
    """
    try:
        raw = d.get(key, "[]")
        candidates = json.loads(raw.replace("'", '"'))
    except Exception:
        traceback.print_exc()
        print(d)
        print(f"Failed to load feedback images: {d.get(key, '[]')}")
        candidates = []
    return [name for name in candidates if os.path.exists(full_image_path(name))]
def _load_gallery(d, key):
    """Return full paths for the existing feedback images listed under *key* in *d*."""
    return list(map(full_image_path, _load_feedback_paths(d, key)))
def _save_preset(preset_name, liked_paths, disliked_paths, base_path=PRESET_PATH):
    """Write the liked/disliked path lists to ``<preset_name>.json``.

    The file lives under ``<basedir>/<base_path>``; missing parent
    directories are created. The JSON is indented with 4 spaces.
    """
    target = Path(modules.scripts.basedir(), base_path, f"{preset_name}.json")
    target.parent.mkdir(parents=True, exist_ok=True)
    payload = dict(liked_paths=liked_paths, disliked_paths=disliked_paths)
    with target.open("w") as fh:
        json.dump(payload, fh, indent=4)
def _load_presets(base_path=PRESET_PATH):
    """List the names (file stems) of all ``.json`` presets in the preset directory.

    The directory is created first if it does not exist.
    """
    folder = Path(modules.scripts.basedir(), base_path)
    folder.mkdir(parents=True, exist_ok=True)
    names = []
    for entry in folder.iterdir():
        if entry.is_file() and entry.suffix == ".json":
            names.append(entry.stem)
    return names
@dataclass
class FabricParams:
    """Bundle of settings handed to the U-Net patching code for one generation."""

    enabled: bool = True

    # Feedback window and strengths (mirrors the sliders in the UI).
    start: float = 0.0
    end: float = 0.8
    min_weight: float = 0.0
    max_weight: float = 0.8
    neg_scale: float = 0.5

    # Liked / disliked images; each instance gets its own fresh list.
    pos_images: list = dataclasses.field(default_factory=list)
    neg_images: list = dataclasses.field(default_factory=list)

    # Not populated anywhere in this module; left as None by default.
    pos_latents: Optional[list] = None
    neg_latents: Optional[list] = None
    pos_latent_cache: Optional[dict] = None
    neg_latent_cache: Optional[dict] = None

    feedback_during_high_res_fix: bool = False

    # Token Merging (ToMe) options.
    tome_enabled: bool = False
    tome_ratio: float = 0.5
    tome_max_tokens: int = 4 * 4096
    tome_seed: int = -1

    burnout_protection: bool = False
# TODO: replace global state with Gradio state
class FabricState:
    """Process-wide holder for the images of the most recent batch per tab.

    The attributes are class-level, so they are shared by everything that
    reads or assigns them.
    """

    # Last batch produced in the txt2img tab.
    txt2img_images = []
    # Last batch produced in the img2img tab.
    img2img_images = []
class FabricScript(modules.scripts.Script):
    """WebUI script that builds the FABRIC feedback UI and patches the U-Net.

    ``ui()`` creates the Gradio components and event wiring, ``process()``
    patches the U-Net forward pass before generation when feedback applies,
    and ``postprocess()`` restores it and remembers the generated images.
    """

    def __init__(self) -> None:
        super().__init__()

    def title(self):
        """Return the display name of the script."""
        return "FABRIC"

    def show(self, is_img2img):
        """Make the script always visible, regardless of the tab."""
        return modules.scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the FABRIC panel and wire its events.

        Returns the list of components whose values are passed, in this
        exact order, as ``*args`` to ``process()``.

        NOTE(review): the nesting of the layout containers below was
        reconstructed from a source whose indentation had been lost; verify
        the visual layout against the intended design.
        """
        self.txt2img_selected_image = gr.State(None)
        self.img2img_selected_image = gr.State(None)

        selected_like = gr.State(None)
        selected_dislike = gr.State(None)
        # need to use JSON over State to make it compatible with gr.update
        liked_paths = gr.JSON(value=[], visible=False)
        disliked_paths = gr.JSON(value=[], visible=False)

        with gr.Accordion(f"{self.title()} v{__version__}", open=DEBUG, elem_id="fabric"):
            with FormGroup():
                with FormRow():
                    feedback_enabled = gr.Checkbox(label="Enable", value=False)
                    feedback_during_high_res_fix = gr.Checkbox(label="Enable during hires. fix", value=False)

                with gr.Row():
                    presets_list = gr.Dropdown(label="Presets", choices=_load_presets(), default=None, live=False)
                    create_refresh_button(presets_list, lambda: None, lambda: {"choices": _load_presets()}, "fabric_reload_presets_btn")

            with gr.Tabs():
                with gr.Tab("Current batch"):
                    # TODO: figure out why the display is shared between tabs
                    self.img2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=is_img2img, height=256)
                    self.txt2img_selected_display = gr.Image(value=None, type="pil", label="Selected image", visible=not is_img2img, height=256)

                    with gr.Row():
                        like_btn_selected = gr.Button("👍 Like")
                        dislike_btn_selected = gr.Button("👎 Dislike")

                with gr.Tab("Upload image"):
                    upload_img_input = gr.Image(type="pil", label="Upload image", height=256)

                    with gr.Row():
                        like_btn_uploaded = gr.Button("👍 Like")
                        dislike_btn_uploaded = gr.Button("👎 Dislike")

            with gr.Tabs(initial_tab="👍 Likes"):
                with gr.Tab("👍 Likes"):
                    with gr.Row():
                        remove_selected_like_btn = gr.Button("Remove selected", interactive=False)
                        clear_liked_btn = gr.Button("Clear")
                    like_gallery = gr.Gallery(label="Liked images", elem_id="fabric_like_gallery", columns=4, height=192)

                with gr.Tab("👎 Dislikes"):
                    with gr.Row():
                        remove_selected_dislike_btn = gr.Button("Remove selected", interactive=False)
                        clear_disliked_btn = gr.Button("Clear")
                    dislike_gallery = gr.Gallery(label="Disliked images", elem_id="fabric_dislike_gallery", columns=4, height=192)

            save_preset_btn = gr.Button("Save as preset")

            # NOTE(review): the original literal was split across two lines
            # (a syntax error) and its markup appears to have been lost;
            # presumably a visual separator belongs here - TODO confirm.
            gr.HTML("\n")

            with FormGroup():
                with FormRow():
                    feedback_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.0, label="Feedback start")
                    feedback_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.8, label="Feedback end")
                feedback_max_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.8, label="Feedback Strength", elem_id="fabric_max_weight")
                tome_enabled = gr.Checkbox(label="Enable Token Merging (faster, less VRAM, less accurate)", value=False)
                burnout_protection = gr.Checkbox(label="Burnout protection (enable if results contain artifacts or are especially dark)", value=False)

            with gr.Accordion("Advanced options", open=DEBUG):
                with FormGroup():
                    feedback_min_weight = gr.Slider(minimum=0, maximum=1, step=0.05, value=0.0, label="Min. strength", info="Minimum feedback strength at every diffusion step.", elem_id="fabric_min_weight")
                    feedback_neg_scale = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Negative weight", info="Strength of negative feedback relative to positive feedback.", elem_id="fabric_neg_scale")
                    tome_ratio = gr.Slider(minimum=0.0, maximum=0.75, step=0.125, value=0.5, label="ToMe merge ratio", info="Percentage of tokens to be merged (higher improves speed)", elem_id="fabric_tome_ratio")
                    tome_max_tokens = gr.Slider(minimum=4096, maximum=16*4096, step=4096, value=2*4096, label="ToMe max. tokens", info="Maximum number of tokens after merging (lower improves VRAM usage)", elem_id="fabric_tome_max_tokens")
                    tome_seed = gr.Number(label="ToMe seed", value=-1, step=1, info="Random seed for ToMe partition", elem_id="fabric_tome_seed")

        WebUiComponents.on_txt2img_gallery(self.register_txt2img_gallery_select)
        WebUiComponents.on_img2img_gallery(self.register_img2img_gallery_select)

        # The like/dislike buttons of the "Current batch" tab act on the
        # selection state that belongs to the tab this UI is built for.
        selected_image = self.img2img_selected_image if is_img2img else self.txt2img_selected_image
        like_btn_selected.click(self.add_image_to_state, inputs=[selected_image, liked_paths], outputs=[like_gallery, liked_paths])
        dislike_btn_selected.click(self.add_image_to_state, inputs=[selected_image, disliked_paths], outputs=[dislike_gallery, disliked_paths])

        like_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, liked_paths], outputs=[like_gallery, liked_paths])
        dislike_btn_uploaded.click(self.add_image_to_state, inputs=[upload_img_input, disliked_paths], outputs=[dislike_gallery, disliked_paths])

        # No inputs are wired, so the handlers take no arguments and return
        # exactly one value per output (gallery, paths).
        clear_liked_btn.click(lambda: ([], []), inputs=[], outputs=[like_gallery, liked_paths])
        clear_disliked_btn.click(lambda: ([], []), inputs=[], outputs=[dislike_gallery, disliked_paths])

        like_gallery.select(
            self.select_for_removal,
            _js="(a, b) => [a, fabric_selected_gallery_index('fabric_like_gallery')]",
            inputs=[like_gallery, like_gallery],
            outputs=[selected_like, remove_selected_like_btn],
        )

        dislike_gallery.select(
            self.select_for_removal,
            _js="(a, b) => [a, fabric_selected_gallery_index('fabric_dislike_gallery')]",
            inputs=[dislike_gallery, dislike_gallery],
            outputs=[selected_dislike, remove_selected_dislike_btn],
        )

        remove_selected_like_btn.click(
            self.remove_selected,
            inputs=[liked_paths, selected_like],
            outputs=[like_gallery, liked_paths, selected_like, remove_selected_like_btn],
        )

        remove_selected_dislike_btn.click(
            self.remove_selected,
            inputs=[disliked_paths, selected_dislike],
            outputs=[dislike_gallery, disliked_paths, selected_dislike, remove_selected_dislike_btn],
        )

        save_preset_btn.click(
            self.save_preset,
            _js="(a, b, c, d) => [a, b, c, prompt('Enter a name for your preset:')]",
            inputs=[presets_list, liked_paths, disliked_paths, disliked_paths],  # last input is a dummy
            outputs=[presets_list],
        )

        presets_list.input(
            self.on_preset_selected,
            inputs=[presets_list, liked_paths, disliked_paths],
            outputs=[
                liked_paths,
                disliked_paths,
                like_gallery,
                dislike_gallery,
            ],
        )

        # sets FABRIC params when "send to txt2img/img2img" is clicked
        self.infotext_fields = [
            (feedback_enabled, lambda d: gr.Checkbox.update(value="fabric_start" in d)),
            (feedback_start, "fabric_start"),
            (feedback_end, "fabric_end"),
            (feedback_min_weight, "fabric_min_weight"),
            (feedback_max_weight, "fabric_max_weight"),
            (feedback_neg_scale, "fabric_neg_scale"),
            (tome_enabled, "fabric_tome_enabled"),
            (tome_ratio, "fabric_tome_ratio"),
            (tome_max_tokens, "fabric_tome_max_tokens"),
            (tome_seed, "fabric_tome_seed"),
            (burnout_protection, "fabric_burnout_protection"),
            (feedback_during_high_res_fix, "fabric_feedback_during_high_res_fix"),
            (liked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_pos_images")) if "fabric_pos_images" in d else None),
            (disliked_paths, lambda d: gr.update(value=_load_feedback_paths(d, "fabric_neg_images")) if "fabric_neg_images" in d else None),
            (like_gallery, lambda d: gr.Gallery.update(value=_load_gallery(d, "fabric_pos_images")) if "fabric_pos_images" in d else None),
            (dislike_gallery, lambda d: gr.Gallery.update(value=_load_gallery(d, "fabric_neg_images")) if "fabric_neg_images" in d else None),
        ]

        # Order must match the tuple unpacking at the top of process().
        return [
            liked_paths,
            disliked_paths,
            feedback_enabled,
            feedback_start,
            feedback_end,
            feedback_min_weight,
            feedback_max_weight,
            feedback_neg_scale,
            feedback_during_high_res_fix,
            tome_enabled,
            tome_ratio,
            tome_max_tokens,
            tome_seed,
            burnout_protection,
        ]

    def select_for_removal(self, gallery, selected_idx):
        """Store the selected gallery index and enable the matching remove button."""
        return [
            selected_idx,
            gr.update(interactive=True),
        ]

    def remove_selected(self, paths, idx):
        """Drop the entry at *idx* from *paths* and refresh the gallery.

        A missing (None) or out-of-range index leaves *paths* untouched.
        Returns the new gallery, the paths, a reset of the stored selection
        and an update that disables the remove button again.
        """
        if idx is not None and 0 <= idx < len(paths):
            paths.pop(idx)
        gallery = [full_image_path(path) for path in paths]
        return [
            gallery,
            paths,
            gr.update(value=None),
            gr.update(interactive=False),
        ]

    def add_image_to_state(self, img, paths):
        """Save *img* (if any) as a feedback image and append its filename to *paths*.

        Returns the gallery (full file paths) and the updated filename list.
        """
        if img is not None:
            path = save_feedback_image(img)
            paths.append(path)
        gallery = [full_image_path(path) for path in paths]
        return gallery, paths

    def save_preset(self, presets, liked_paths, disliked_paths, preset_name):
        """Store the current likes/dislikes as a preset and refresh the dropdown choices.

        Nothing is written when *preset_name* is None or empty (e.g. the
        browser prompt was cancelled).
        """
        if preset_name is not None and preset_name != "":
            _save_preset(preset_name, liked_paths, disliked_paths)
        return gr.update(choices=_load_presets())

    def on_preset_selected(self, preset_name, liked_paths, disliked_paths):
        """Load the chosen preset and return paths plus galleries for likes/dislikes.

        If the preset file does not exist or cannot be loaded, the lists that
        were passed in are returned unchanged (a load failure is printed).
        """
        preset_path = Path(modules.scripts.basedir(), PRESET_PATH, f"{preset_name}.json")
        if preset_path.exists():
            try:
                with open(preset_path, "r") as f:
                    preset = json.load(f)
                # Validate both keys before touching either list so a broken
                # preset never results in a half-applied update.
                for required in ("liked_paths", "disliked_paths"):
                    if required not in preset:
                        raise ValueError(f"Missing '{required}' in preset")
                liked_paths = preset["liked_paths"]
                disliked_paths = preset["disliked_paths"]
            except Exception:
                traceback.print_exc()
                print(f"Failed to load preset: {preset_path}")
        like_gallery = [full_image_path(path) for path in liked_paths]
        dislike_gallery = [full_image_path(path) for path in disliked_paths]
        return liked_paths, disliked_paths, like_gallery, dislike_gallery

    def register_txt2img_gallery_select(self, gallery):
        """Hook the txt2img result gallery up to this script's selection state."""
        self.register_gallery_select(
            gallery,
            listener=self.on_txt2img_gallery_select,
            selected=self.txt2img_selected_image,
            display=self.txt2img_selected_display,
        )

    def register_img2img_gallery_select(self, gallery):
        """Hook the img2img result gallery up to this script's selection state."""
        self.register_gallery_select(
            gallery,
            listener=self.on_img2img_gallery_select,
            selected=self.img2img_selected_image,
            display=self.img2img_selected_display,
        )

    def register_gallery_select(self, gallery, listener=None, selected=None, display=None):
        """Attach *listener* to the gallery's select event.

        The JS shim replaces the second input with the selected gallery
        index; the listener's results go to *selected* and *display*.
        """
        gallery.select(
            listener,
            _js="(a, b) => [a, selected_gallery_index()]",
            inputs=[
                gallery,
                gallery,  # can be any Gradio component (but not None), will be overwritten with selected gallery index
            ],
            outputs=[selected, display],
        )

    def on_txt2img_gallery_select(self, gallery, selected_idx):
        """Resolve a txt2img gallery selection against the last txt2img batch."""
        return self.on_gallery_select(gallery, selected_idx, FabricState.txt2img_images)

    def on_img2img_gallery_select(self, gallery, selected_idx):
        """Resolve an img2img gallery selection against the last img2img batch."""
        return self.on_gallery_select(gallery, selected_idx, FabricState.img2img_images)

    def on_gallery_select(self, gallery, selected_idx, images):
        """Map a gallery index onto *images* and return (image, display update).

        The gallery may hold more entries than *images*; the surplus is
        treated as leading entries and subtracted from the index. Returns
        ``(None, None)`` when the resulting index falls outside *images*.
        """
        idx = selected_idx - (len(gallery) - len(images))
        if 0 <= idx < len(images):
            return images[idx], gr.update(value=images[idx])
        return None, None

    def process(self, p, *args):
        """Patch the U-Net forward pass for this generation when feedback applies.

        *args* are the values of the components returned by ``ui()``, in the
        same order. The parameters used are also recorded in
        ``p.extra_generation_params`` under ``fabric_``-prefixed keys.
        """
        (
            liked_paths,
            disliked_paths,
            feedback_enabled,
            feedback_start,
            feedback_end,
            feedback_min_weight,
            feedback_max_weight,
            feedback_neg_scale,
            feedback_during_high_res_fix,
            tome_enabled,
            tome_ratio,
            tome_max_tokens,
            tome_seed,
            burnout_protection,
        ) = args

        # restore original U-Net forward pass in case previous batch errored out
        unpatch_unet_forward_pass(p.sd_model.model.diffusion_model)

        if not feedback_enabled:
            return

        likes = [load_feedback_image(path) for path in liked_paths]
        dislikes = [load_feedback_image(path) for path in disliked_paths]

        params = FabricParams(
            enabled=feedback_enabled,
            start=feedback_start,
            end=feedback_end,
            min_weight=feedback_min_weight,
            max_weight=feedback_max_weight,
            neg_scale=feedback_neg_scale,
            pos_images=likes,
            neg_images=dislikes,
            feedback_during_high_res_fix=feedback_during_high_res_fix,
            tome_enabled=tome_enabled,
            # snap the ratio to a multiple of 1/16
            tome_ratio=(round(tome_ratio * 16) / 16),
            tome_max_tokens=tome_max_tokens,
            tome_seed=get_fixed_seed(int(tome_seed)),
            burnout_protection=burnout_protection,
        )

        if use_feedback(params) or (DEBUG and feedback_enabled):
            print(f"[FABRIC] Patching U-Net forward pass... ({len(likes)} likes, {len(dislikes)} dislikes)")

            # log the generation params to be displayed/stored as metadata
            log_params = asdict(params)
            log_params["pos_images"] = json.dumps(liked_paths)
            log_params["neg_images"] = json.dumps(disliked_paths)
            del log_params["enabled"]
            if not params.tome_enabled:
                del log_params["tome_ratio"]
                del log_params["tome_max_tokens"]
                del log_params["tome_seed"]
            log_params = {f"fabric_{k}": v for k, v in log_params.items()}
            p.extra_generation_params.update(log_params)

            unet = p.sd_model.model.diffusion_model
            patch_unet_forward_pass(p, unet, params)
        else:
            print("[FABRIC] Skipping U-Net forward pass patching")

    def postprocess(self, p, processed, *args):
        """Restore the U-Net and remember the generated images for the active tab.

        Raises:
            RuntimeError: if *p* is neither a txt2img nor an img2img
                processing object.
        """
        unpatch_unet_forward_pass(p.sd_model.model.diffusion_model)

        # keep only the entries from index_of_first_image onwards
        images = processed.images[processed.index_of_first_image:]
        if isinstance(p, StableDiffusionProcessingTxt2Img):
            FabricState.txt2img_images = images
        elif isinstance(p, StableDiffusionProcessingImg2Img):
            FabricState.img2img_images = images
        else:
            raise RuntimeError(f"Unsupported processing type: {type(p)}")
# Register WebUiComponents.register_component as an "after component" callback.
# Presumably this is how WebUiComponents finds the txt2img/img2img galleries
# that ui() subscribes to - TODO confirm against scripts.fabric_utils.
script_callbacks.on_after_component(WebUiComponents.register_component)