AmirMoris's picture
fix
147c862
raw
history blame
5.31 kB
import os
import math
import random

import gradio as gr

from Helper_functions import *
from Kaggle_API import API_Connection
from GoogleDrive_API import GoogleDrive_API
# Fallback/reset value for every UI component, keyed by role.
# Key order is significant: generate_button_clicked zips its positional
# inputs against these keys, and reset_button_clicked returns the values in
# this exact order.
DEFAULT_VALUES = dict(
    input_image=None,
    edit_instruction="",
    steps=100,
    randomize_seed="Fix Seed",
    seed=1371,
    randomize_cfg="Fix CFG",
    text_cfg_scale=7.5,
    image_cfg_scale=1.5,
    resolution=512,
    edited_image=None,
)
def _is_randomize_choice(choice, randomize_label):
    """Return True when a "Fix/Randomize" radio selection means "randomize".

    With ``type="index"`` Gradio delivers ``0`` (fix) or ``1`` (randomize),
    but the ``None`` fallback from DEFAULT_VALUES is a label string such as
    ``"Fix Seed"``. A non-empty string is truthy, so plain truthiness would
    wrongly mean "randomize"; accept the index or the randomize label.
    """
    return choice == 1 or choice == randomize_label


def _floor_power_of_two(value):
    """Return the largest power of two that is <= ``value``.

    Raises:
        ValueError: if ``value`` is below 1 (no positive power of two fits).
        TypeError / OverflowError: if ``value`` is not a finite number.
    """
    as_int = int(value)
    if as_int < 1:
        raise ValueError(f"resolution must be >= 1, got {value!r}")
    # Exact integer arithmetic, avoiding float rounding in math.log2.
    return 1 << (as_int.bit_length() - 1)


def generate_button_clicked(*args):
    """Handle the "Generate" button: validate inputs and run the remote edit.

    ``args`` are the values of the input components, in the same order as the
    keys of DEFAULT_VALUES. The trailing ``edited_image`` key has no input, so
    zip stops before it. Any ``None`` value falls back to its default.

    Returns:
        The image returned by ``API_Connection.generate_image`` on success.

    Raises:
        gr.Error: when the input image or edit instruction is missing, the
            resolution is invalid, or the remote generation reports failure.
    """
    values = {
        key: DEFAULT_VALUES[key] if value is None else value
        for key, value in zip(DEFAULT_VALUES, args)
    }

    # Draw random parameters before reading them, so the drawn values are
    # the ones actually sent to the model.
    if _is_randomize_choice(values["randomize_seed"], "Randomize Seed"):
        values["seed"] = random.randint(1, 100000)
    if _is_randomize_choice(values["randomize_cfg"], "Randomize CFG"):
        values["text_cfg_scale"] = round(random.uniform(6.0, 9.0), ndigits=2)
        values["image_cfg_scale"] = round(random.uniform(1.2, 1.8), ndigits=2)

    # Parameters for the model.
    input_image = values["input_image"]
    edit_instruction = values["edit_instruction"]
    steps = values["steps"]
    seed = values["seed"]
    cfgtext = values["text_cfg_scale"]
    cfgimage = values["image_cfg_scale"]

    if input_image is None:
        raise gr.Error("Missing Input: input_image")
    if not edit_instruction.strip():
        raise gr.Error("Missing Input: edit_instruction")
    try:
        resolution = _floor_power_of_two(values["resolution"])
    except (TypeError, ValueError, OverflowError):
        raise gr.Error("Invalid Input: resolution must be a number >= 1") from None

    # Kaggle credentials are only read once the inputs are known to be valid.
    kaggle_username = os.environ["kaggle_username"]
    kaggle_key = os.environ["kaggle_key"]
    GoogleDrive_connection = GoogleDrive_API("service_account.json")
    api_connection = API_Connection(GoogleDrive_connection, kaggle_username, kaggle_key)

    create_folder("local_dataset")
    image_ID = get_random_str(4)
    input_image_name = f"input_image_{image_ID}.png"
    output_image_name = f"output_image_{image_ID}.png"
    # os.path.join rather than a hard-coded "\": on POSIX hosts a literal
    # backslash is part of the file name, so the image would land in the CWD
    # instead of inside local_dataset.
    # NOTE(review): confirm API_Connection reads the upload from this folder.
    input_image.save(os.path.join("local_dataset", input_image_name))

    status, img = api_connection.generate_image(
        input_image_name, edit_instruction, output_image_name,
        steps, seed, cfgtext, cfgimage, resolution,
    )
    print(f"End Time : {get_current_time()}")
    if not status:
        # On failure, the second element is used as the error message.
        raise gr.Error(img)
    return img
def reset_button_clicked():
    """Return every component's default value as a fresh list.

    The list follows DEFAULT_VALUES key order, which must match the
    ``outputs`` wired to the Reset button.
    """
    return [*DEFAULT_VALUES.values()]
def main():
    """Build the Gradio UI, wire its event handlers, and launch the app.

    Creates the image/instruction/parameter components. It connects
    "Generate" to generate_button_clicked, "Reset" to reset_button_clicked,
    and "Toggle Theme" to a client-side JS snippet that toggles the ``dark``
    class on ``document.body``. It then enables a queue of size 1 and
    launches with ``share=True``.
    """
    with gr.Blocks(theme="AmirMoris/GP_Themes") as demo:
        toggle_theme = gr.Button(value="Toggle Theme")
        with gr.Row():
            input_image = gr.Image(label="Input Image", type="pil", interactive=True)
            edited_image = gr.Image(label="Edited Image", type="pil", interactive=False)
        with gr.Row():
            with gr.Column(scale=3):
                instruction = gr.Textbox(lines=1, label="Edit Instruction", interactive=True)
            with gr.Column(scale=1, min_width=100):
                with gr.Row():
                    generate_button = gr.Button("Generate")
                with gr.Row():
                    reset_button = gr.Button("Reset")
        with gr.Row():
            steps = gr.Number(value=DEFAULT_VALUES["steps"], precision=0, label="Steps", interactive=True)
            # type="index" makes Gradio pass 0 ("Fix ...") or 1 ("Randomize ...").
            randomize_seed = gr.Radio(
                ["Fix Seed", "Randomize Seed"],
                value=DEFAULT_VALUES["randomize_seed"],
                type="index",
                show_label=False,
                interactive=True,
            )
            seed = gr.Number(value=DEFAULT_VALUES["seed"], precision=0, label="Seed", interactive=True)
            randomize_cfg = gr.Radio(
                ["Fix CFG", "Randomize CFG"],
                value=DEFAULT_VALUES["randomize_cfg"],
                type="index",
                show_label=False,
                interactive=True,
            )
            text_cfg_scale = gr.Number(value=DEFAULT_VALUES["text_cfg_scale"], label="Text CFG", interactive=True)
            image_cfg_scale = gr.Number(value=DEFAULT_VALUES["image_cfg_scale"], label="Image CFG", interactive=True)
            resolution = gr.Number(value=DEFAULT_VALUES["resolution"], label="Resolution", interactive=True)

        # Parameter components in the same order as their DEFAULT_VALUES keys.
        # Generate inputs and Reset outputs are both derived from this single
        # list so they cannot drift out of sync with each other.
        parameter_components = [
            steps,
            randomize_seed,
            seed,
            randomize_cfg,
            text_cfg_scale,
            image_cfg_scale,
            resolution,
        ]

        generate_button.click(
            fn=generate_button_clicked,
            inputs=[input_image, instruction, *parameter_components],
            outputs=edited_image,
        )
        reset_button.click(
            fn=reset_button_clicked,
            outputs=[input_image, instruction, *parameter_components, edited_image],
        )
        # Pure client-side handler: no Python function, only the JS snippet.
        toggle_theme.click(
            None,
            js="""
            () => {
                document.body.classList.toggle('dark');
            }
            """,
        )

    # Launch Gradio interface
    demo.queue(max_size=1)
    demo.launch(share=True)


if __name__ == "__main__":
    main()