|
|
import gradio as gr |
|
|
import os |
|
|
from src.rec_system.art_rec import ArtRecSystem |
|
|
import queue |
|
|
|
|
|
|
|
|
class UI:
    """Gradio front end for the GARS (Generative Art Recommendation System).

    ``build_ui`` lays out the setup, generation, gallery and advanced-option
    panels, wires their event handlers to an ``ArtRecSystem`` instance and
    launches the app.
    """

    # Inclusive bounds for a session's iteration count. They are shared by the
    # "Iteration Count" slider and the server-side validation so the two agree.
    _MIN_ITERATIONS = 10
    _MAX_ITERATIONS = 100

    # Model dropdown label -> diffusion steps passed to ArtRecSystem. Insertion
    # order is the order in which the dropdown lists the choices.
    _DIFFUSION_STEPS = {
        "SDXL Lightning [2 Step]": 2,
        "SDXL Lightning [4 Step]": 4,
        "SDXL Lightning [8 Step]": 8,
    }
    _DEFAULT_MODEL = "SDXL Lightning [8 Step]"
    # Used when the dropdown delivers a label not in _DIFFUSION_STEPS.
    _FALLBACK_DIFFUSION_STEPS = 8

    # "Lock Elements" checkbox label -> component key passed to ArtRecSystem as
    # ``freeze_elements``. Insertion order is the checkbox display order.
    _LOCKABLE_ELEMENTS = {
        "Subject": "subjects",
        "Medium": "art_mediums",
        "Style": "artists_movements",
        "Modifiers": "modifiers",
    }

    _SUBJECT_CHOICES = (
        "Animals",
        "Landscapes",
        "Space",
        "Oceans",
        "Forests",
        "Mountains",
        "Rivers",
        "Deserts",
        "Urban Life",
        "Fantasy Creatures",
        "Mythology",
        "Architecture",
        "Cityscapes",
        "Flowers",
        "Sunsets",
        "Underwater Scenes",
        "Winter Scenes",
        "Autumn Forests",
        "Portraits",
        "Historical Scenes",
        "Abstract Concepts",
        "Still Life",
        "Vehicles",
        "Technology",
        "Sports",
        "Music",
        "Food",
        "Fashion",
        "Travel",
    )

    _MEDIUM_CHOICES = (
        "Digital Art",
        "Painting",
        "Sculpture",
        "Photography",
        "Ceramics",
        "Woodworking",
        "Textiles",
        "Glass Art",
        "Metalwork",
        "Printmaking",
    )

    _STYLE_CHOICES = (
        "Impressionism",
        "Renaissance",
        "Baroque",
        "Modern Art",
        "Pop Art",
        "Abstract Art",
        "Surrealism",
        "Cubism",
        "Expressionism",
        "Minimalism",
    )

    def __init__(self):
        # Not referenced anywhere else in this file; kept in case external
        # code reads it.
        self.image_pipeline = None
        # ArtRecSystem for the current session; created by start_gars_session.
        self.rec_system = None
        # Images generated so far in the current session, oldest first.
        self.output_images = []
        # Outcome of the most recent iteration-count validation.
        self.validation_state = False

    @staticmethod
    def _combine_preferences(selected, custom):
        """Combine selected checkbox options and a free-text preference.

        Args:
            selected (list | None): Options ticked in a CheckboxGroup.
            custom (str | None): Free-text preference typed by the user.

        Returns:
            list: A NEW list holding ``selected`` followed by ``custom``
            (stripped). ``custom`` is omitted when empty or whitespace-only.
            The input list is never mutated.
        """
        combined = list(selected) if selected else []
        custom_text = (custom or "").strip()
        if custom_text:
            combined.append(custom_text)
        return combined

    @classmethod
    def _iteration_count_error(cls, iteration_count):
        """Return the warning text for an invalid iteration count, or None.

        Integers are accepted, and so are floats with an integral value
        (e.g. ``15.0``) in case the slider value arrives as a float.

        Args:
            iteration_count: Value received from the "Iteration Count" slider.

        Returns:
            str | None: User-facing error message, or None when valid.
        """
        is_whole_number = isinstance(iteration_count, int) or (
            isinstance(iteration_count, float) and iteration_count.is_integer()
        )
        if not is_whole_number:
            return "Number of Iterations Must Be an Integer!"
        if not cls._MIN_ITERATIONS <= iteration_count <= cls._MAX_ITERATIONS:
            return (
                "Number of Iterations Must be Between "
                f"{cls._MIN_ITERATIONS} and {cls._MAX_ITERATIONS}!"
            )
        return None

    def build_ui(self):
        """Build the Gradio Blocks app, wire its event handlers and launch it."""

        def check_preferences(element_checkboxes, element_preferences):
            """
            Combines selected checkboxes and custom preferences into a single list.

            Args:
                element_checkboxes (list): List of selected options from checkboxes.
                element_preferences (str): Custom preference entered by the user.

            Returns:
                list: New list of the checkboxes plus the custom preference, if any.
            """
            return self._combine_preferences(element_checkboxes, element_preferences)

        def gars_session_validation(iteration_count, show_warning=True):
            """
            Validates the session configuration for the GARS system.

            Args:
                iteration_count (int): Number of iterations for the session.
                show_warning (bool): Show a Gradio warning toast on failure.

            Updates:
                validation_state (bool): Indicates if the validation passed.

            Returns:
                bool: The new ``validation_state``.
            """
            error_message = self._iteration_count_error(iteration_count)
            if error_message is not None and show_warning:
                gr.Warning(error_message)
            self.validation_state = error_message is None
            return self.validation_state

        def start_gars_session(
            iteration_count,
            sdxl_dropdown,
            subjects_checkboxes,
            custom_preference_subject,
            styles_checkboxes,
            custom_preference_style,
            art_mediums_checkboxes,
            custom_preference_medium,
            progress=gr.Progress(track_tqdm=True),
        ):
            """
            Starts a new GARS (Generative Art Recommendation System) session.

            Args:
                iteration_count (int): Number of iterations for generating recommendations.
                sdxl_dropdown (str): Dropdown choice for the diffusion steps.
                subjects_checkboxes (list): Selected subjects for recommendations.
                custom_preference_subject (str): Custom subject preference.
                styles_checkboxes (list): Selected art styles.
                custom_preference_style (str): Custom style preference.
                art_mediums_checkboxes (list): Selected art mediums.
                custom_preference_medium (str): Custom art medium preference.
                progress (gr.Progress): Gradio progress tracker.

            Returns:
                dict: UI updates. On invalid input the setup panel stays visible
                and no recommender is created.
            """
            # show_progress is bound to the same click and already warned the
            # user, so validate silently here. On failure do not create a
            # session and keep the setup panel on screen.
            if not gars_session_validation(iteration_count, show_warning=False):
                return {
                    initial_setup: gr.update(visible=True),
                    progress_bar: gr.update(visible=False),
                }

            self.output_images = []
            initial_preferences = {
                "subjects": check_preferences(
                    subjects_checkboxes, custom_preference_subject
                ),
                "artists_movements": check_preferences(
                    styles_checkboxes, custom_preference_style
                ),
                "art_mediums": check_preferences(
                    art_mediums_checkboxes, custom_preference_medium
                ),
            }
            diffusion_steps = self._DIFFUSION_STEPS.get(
                sdxl_dropdown, self._FALLBACK_DIFFUSION_STEPS
            )

            progress(0, desc="Starting")
            self.rec_system = ArtRecSystem(
                # Validation allows integral floats; hand the recommender an int.
                total_iterations=int(iteration_count),
                initial_preferences=initial_preferences,
                diffusion_steps=diffusion_steps,
            )

            return {
                initial_setup: gr.update(visible=False),
                GARS: gr.update(visible=True),
                output_image: gr.update(visible=True),
                output: "",
                progress_bar: gr.update(visible=False),
                restart_row: gr.update(visible=True),
                rating_row: gr.update(visible=True),
            }

        def generate_rec(
            rating,
            subject_weight,
            medium_weight,
            style_weight,
            modifiers_weight,
            locked_elements,
        ):
            """
            Generates a new recommendation based on user feedback and preferences.

            Args:
                rating (float): User rating for the previous recommendation.
                subject_weight (float): Weight for subject preference.
                medium_weight (float): Weight for medium preference.
                style_weight (float): Weight for style preference.
                modifiers_weight (float): Weight for modifiers preference.
                locked_elements (list): Elements to lock in the recommendation.

            Returns:
                dict: Updates UI elements with the generated recommendation and gallery.
            """
            lock_element_list = (
                [self._LOCKABLE_ELEMENTS[elem] for elem in locked_elements]
                if locked_elements
                else []
            )

            # There is no previous image to rate on the first iteration.
            if self.rec_system._iteration == 0:
                rating = 0

            try:
                gen_img = self.rec_system(
                    rating=rating,
                    preference_weights=[
                        modifiers_weight,
                        subject_weight,
                        medium_weight,
                        style_weight,
                    ],
                    freeze_elements=lock_element_list,
                )
                self.output_images.append(gen_img)
            finally:
                # The integer sentinel ends the concurrent show_latent stream.
                # Post it even if generation raised; otherwise show_latent
                # would block on the queue forever.
                self.rec_system.diffusion_pipeline.latent_queue.put(0)

            row_visibility = not self.rec_system.is_done
            return {
                output_image: gen_img,
                output_gallery: self.output_images,
                advanced_checkbox_row: gr.update(visible=row_visibility),
                rating_row: gr.update(visible=row_visibility),
                rating_wrapper: gr.update(visible=True),
                iteration_display: gr.update(visible=True),
                gallery_row: gr.update(visible=not row_visibility),
            }

        def show_latent():
            """
            Streams items from the diffusion pipeline's latent queue for live preview.

            Yields each queued item (presumably intermediate images) until the
            integer sentinel posted by generate_rec arrives. It then yields the
            newest generated image, if there is one, and stops.
            """
            latent_queue = self.rec_system.diffusion_pipeline.latent_queue
            while True:
                # Blocking get: generate_rec always posts the sentinel, even
                # when generation fails, so this loop terminates.
                item = latent_queue.get()
                if isinstance(item, int):
                    if self.output_images:
                        yield self.output_images[-1]
                    return
                yield item

        def update_iteration():
            """
            Provides current iteration status in the GARS session.

            Returns:
                str: Formatted iteration status, or "" when no session exists yet.
            """
            if self.rec_system is None:
                return ""
            return f"## Iteration: {self.rec_system._iteration} / {self.rec_system._total_iterations}"

        def show_advanced(status):
            """
            Toggles the advanced options tab visibility.

            Args:
                status (bool): Desired visibility status of the advanced tab.

            Returns:
                dict: Updates UI visibility of the advanced tab.
            """
            return {advanced_tab: gr.update(visible=status)}

        def restart_session():
            """
            Restarts the GARS session, resetting output images and UI elements.

            Returns:
                dict: UI reset to initial setup visibility.
            """
            self.output_images = []
            return {
                initial_setup: gr.update(visible=True),
                GARS: gr.update(visible=False),
                advanced_tab: gr.update(visible=False),
                # Uncheck the toggle so it matches the hidden advanced tab.
                # Make it visible again because show_gallery hides it.
                advanced_checkbox: gr.update(value=False, visible=True),
                advanced_checkbox_row: gr.update(visible=False),
                rating_wrapper: gr.update(visible=False),
                output_image: None,
                output_gallery: gr.update(visible=False),
                gallery_row: gr.update(visible=False),
                restart_row: gr.update(visible=False),
                iteration_display: gr.update(visible=False),
            }

        def show_progress(iteration_count):
            """
            Validates iteration count and toggles progress bar visibility accordingly.

            Args:
                iteration_count (int): Number of iterations.

            Returns:
                dict: UI updates based on validation outcome.
            """
            is_valid = gars_session_validation(iteration_count)
            return {
                progress_bar: gr.update(visible=is_valid),
                initial_setup: gr.update(visible=not is_valid),
            }

        def show_gallery():
            """
            Displays the output gallery and hides other UI elements.

            Returns:
                dict: UI updates to show gallery view.
            """
            return {
                output_gallery: gr.update(visible=True),
                gallery_row: gr.update(visible=False),
                output_image: gr.update(visible=False),
                advanced_checkbox: gr.update(visible=False),
                advanced_tab: gr.update(visible=False),
            }

        def weight_slider(label):
            """Create one element-weight slider ranging 0..1, defaulting to 1."""
            return gr.Slider(
                minimum=0, maximum=1, value=1, label=label, interactive=True
            )

        green_custom = gr.themes.utils.colors.Color(
            name="green_custom",
            c50="#e0ff00",
            c100="#c8ff00",
            c200="#b1ff00",
            c300="#99f000",
            c400="#81cb00",
            c500="#76b900",
            c600="#6aa600",
            c700="#528100",
            c800="#3b5c00",
            c900="#233700",
            c950="#0b1200",
        )
        # Relative path: resolved against the current working directory.
        with open("ui.css", "r", encoding="utf-8") as f:
            css = f.read()
        theme = gr.themes.Base(
            primary_hue=green_custom,
            secondary_hue=green_custom,
            neutral_hue="stone",
        )

        with gr.Blocks(theme=theme, css=css) as demo:
            with gr.Row():
                with gr.Column("initial setup wrapper") as initial_setup:
                    with gr.Tab("Initial Setup", visible=True):
                        iteration_count = gr.Slider(
                            label="Iteration Count",
                            value=15,
                            minimum=self._MIN_ITERATIONS,
                            maximum=self._MAX_ITERATIONS,
                            step=1,
                        )
                        with gr.Accordion(
                            "Advanced Preferences (optional)", open=False
                        ):
                            gr.Markdown("Selection Preferences")
                            subjects_checkboxes = gr.CheckboxGroup(
                                list(self._SUBJECT_CHOICES),
                                label="Subjects",
                            )
                            custom_preference_subject = gr.Textbox(
                                show_label=False, placeholder="Custom Subject"
                            )
                            art_mediums_checkboxes = gr.CheckboxGroup(
                                list(self._MEDIUM_CHOICES),
                                label="Mediums",
                            )
                            custom_preference_medium = gr.Textbox(
                                show_label=False, placeholder="Custom Medium"
                            )
                            styles_checkboxes = gr.CheckboxGroup(
                                list(self._STYLE_CHOICES),
                                label="Styles",
                            )
                            custom_preference_style = gr.Textbox(
                                show_label=False, placeholder="Custom Style"
                            )
                            sdxl_dropdown = gr.Dropdown(
                                list(self._DIFFUSION_STEPS),
                                label="Model",
                                value=self._DEFAULT_MODEL,
                                interactive=True,
                            )
                        submit_btn = gr.Button("Submit", elem_id="submit-button")

                with gr.Column("Out", visible=False) as progress_bar:
                    with gr.Tab("Loading Model...", visible=True):
                        output = gr.Textbox(
                            label="Loading Model...",
                            placeholder="Waiting on preference",
                            visible=True,
                        )

                with gr.Column("GARS", visible=False) as GARS:
                    with gr.Tab("GARS"):
                        iteration_display = gr.Markdown("", visible=False)
                        with gr.Row(
                            visible=True, elem_id="start-over-button"
                        ) as restart_row:
                            restart_btn = gr.Button("Start Over", scale=0)
                        output_image = gr.Image(
                            streaming=True, label="Output Image", visible=True
                        )
                        output_gallery = gr.Gallery(
                            label="Generated images",
                            show_label=False,
                            elem_id="gallery",
                            object_fit="contain",
                            height="auto",
                            visible=False,
                        )
                        with gr.Row(visible=True) as rating_row:
                            with gr.Column(visible=False) as rating_wrapper:
                                rating = gr.Slider(
                                    minimum=-1,
                                    maximum=1,
                                    value=0,
                                    label="Rating",
                                    scale=3,
                                )
                            with gr.Column(visible=True):
                                generate_btn = gr.Button(
                                    "Generate", scale=1, elem_id="generate-button"
                                )
                        with gr.Row(visible=False) as gallery_row:
                            gallery_submit = gr.Button(
                                "Show Gallery", elem_id="show-gallery-button"
                            )

                with gr.Column("Settings", visible=False) as advanced_tab:
                    with gr.Tab("Advanced Options"):
                        gr.Markdown(" Lock Elements")
                        locked_elements = gr.CheckboxGroup(
                            list(self._LOCKABLE_ELEMENTS),
                            show_label=False,
                        )
                        gr.Markdown(" Adjust Element Weights")
                        with gr.Group():
                            subject_weight = weight_slider("Subject")
                            medium_weight = weight_slider("Medium")
                            style_weight = weight_slider("Style")
                            modifiers_weight = weight_slider("Modifiers")

            with gr.Row(visible=False) as advanced_checkbox_row:
                advanced_checkbox = gr.Checkbox(
                    label="Advanced", interactive=True, container=False
                )

            advanced_checkbox.change(
                fn=show_advanced, inputs=advanced_checkbox, outputs=advanced_tab
            )

            # Both handlers fire on the same click: show_progress validates and
            # toggles the loading panel; start_gars_session builds the session.
            submit_btn.click(
                fn=show_progress,
                inputs=[iteration_count],
                outputs=[progress_bar, initial_setup],
            )
            submit_btn.click(
                fn=start_gars_session,
                inputs=[
                    iteration_count,
                    sdxl_dropdown,
                    subjects_checkboxes,
                    custom_preference_subject,
                    styles_checkboxes,
                    custom_preference_style,
                    art_mediums_checkboxes,
                    custom_preference_medium,
                ],
                outputs=[
                    initial_setup,
                    GARS,
                    output_image,
                    iteration_display,
                    output,
                    progress_bar,
                    restart_row,
                    rating_row,
                ],
            )

            # show_latent streams previews while generate_rec produces the image;
            # they are coupled through the pipeline's latent queue.
            generate_btn.click(fn=show_latent, outputs=[output_image])
            generate_btn.click(
                fn=generate_rec,
                inputs=[
                    rating,
                    subject_weight,
                    medium_weight,
                    style_weight,
                    modifiers_weight,
                    locked_elements,
                ],
                outputs=[
                    output_image,
                    output_gallery,
                    advanced_checkbox_row,
                    rating_row,
                    rating_wrapper,
                    gallery_row,
                    iteration_display,
                ],
            )

            restart_btn.click(
                fn=restart_session,
                outputs=[
                    initial_setup,
                    GARS,
                    advanced_checkbox,
                    advanced_tab,
                    output_image,
                    output_gallery,
                    gallery_row,
                    rating_wrapper,
                    advanced_checkbox_row,
                    iteration_display,
                    restart_row,
                ],
            )

            output_image.change(fn=update_iteration, outputs=[iteration_display])

            gallery_submit.click(
                fn=show_gallery,
                outputs=[
                    gallery_row,
                    output_gallery,
                    output_image,
                    advanced_checkbox,
                    advanced_tab,
                ],
            )

        demo.launch()
|
|
|
|
|
|
|
|
# Build the interface and launch the Gradio app whenever this module is executed.
# NOTE(review): this also runs on a plain import; consider wrapping it in an
# ``if __name__ == "__main__":`` guard.
app = UI()
app.build_ui()
|
|
|