Plat
feat: update deps, add zero gpu progress
31c0e08
raw
history blame
6.56 kB
import spaces
import os
import random
from PIL import Image
import torch
import gradio as gr
import dotenv
from adapter import load_ip_adapter_model, get_file_path
from example import EXAMPLES
# Load variables from ".env.local" before the environment is read below.
dotenv.load_dotenv(".env.local")


def _require_env(name: str) -> str:
    """Return the value of environment variable *name*.

    Raises:
        RuntimeError: If the variable is not set. An explicit raise is used
            instead of ``assert`` so the check still runs under ``python -O``
            and the message names the missing variable.
    """
    value = os.environ.get(name)
    if value is None:
        raise RuntimeError(f"Required environment variable {name} is not set")
    return value


# The adapter location has no defaults and must be configured explicitly.
ADAPTER_REPO_ID = _require_env("ADAPTER_REPO_ID")
ADAPTER_MODEL_PATH = _require_env("ADAPTER_MODEL_PATH")
ADAPTER_CONFIG_PATH = _require_env("ADAPTER_CONFIG_PATH")

# The base model falls back to these defaults when not overridden.
BASE_MODEL_REPO_ID = os.environ.get(
    "BASE_MODEL_REPO_ID", "p1atdev/animagine-xl-4.0-bnb-nf4"
)
BASE_MODEL_PATH = os.environ.get(
    "BASE_MODEL_PATH", "animagine-xl-4.0-opt.bnb_nf4.safetensors"
)

# Initial value of the "Number of images to generate" slider in the UI.
INITIAL_BATCH_SIZE = int(os.environ.get("INITIAL_BATCH_SIZE", 1))

# Resolve each (repo id, file path) pair through the adapter module's helper.
adapter_model_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_MODEL_PATH)
adapter_config_path = get_file_path(ADAPTER_REPO_ID, ADAPTER_CONFIG_PATH)
base_model_path = get_file_path(BASE_MODEL_REPO_ID, BASE_MODEL_PATH)

# Built once at import time; on_generate() reuses this module-level instance
# for every request.
model = load_ip_adapter_model(
    model_path=base_model_path,
    config_path=adapter_config_path,
    adapter_path=adapter_model_path,
)
@spaces.GPU
def on_generate(
    prompt: str,
    negative_prompt: str,
    image: Image.Image | None,
    width: int,
    height: int,
    steps: int,
    cfg_scale: float,
    seed: int,
    randomize_seed: bool = True,
    num_images: int = 4,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a batch of images from a prompt and an optional reference image.

    Args:
        prompt: Positive prompt; repeated once per requested image.
        negative_prompt: Negative prompt, passed to the model as a single string.
        image: Optional reference image; converted to RGB when provided.
        width: Output width, forwarded to ``model.generate``.
        height: Output height, forwarded to ``model.generate``.
        steps: Number of inference steps.
        cfg_scale: CFG scale forwarded to ``model.generate``.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, ``seed`` is replaced by a random value.
        num_images: Number of images to generate in one batch.
        progress: Gradio progress tracker. Not used directly in the body; it
            is declared with ``track_tqdm=True`` so Gradio can surface tqdm
            progress in the UI.

    Returns:
        A tuple ``(images, seed)``: the value returned by ``model.generate``
        and the seed that was actually used.
    """
    if image is not None:
        image = image.convert("RGB")

    if randomize_seed:
        # Upper bound is the largest signed 32-bit integer, which is also the
        # maximum of the seed slider in the UI.
        seed = random.randint(0, 2147483647)

    # List repetition needs an int; coerce in case the UI delivers a float.
    batch_size = int(num_images)

    try:
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
            model.to("cuda:0")
            images = model.generate(
                prompt=[prompt] * batch_size,  # one prompt copy per image
                negative_prompt=negative_prompt,
                reference_image=image,
                num_inference_steps=steps,
                cfg_scale=cfg_scale,
                width=width,
                height=height,
                seed=seed,
                do_offloading=False,
                device="cuda:0",
                max_token_length=225,
                execution_dtype=torch.bfloat16,
            )
    finally:
        # Release cached GPU memory even when generation fails (e.g. on OOM),
        # so a failed request does not leave the cache filled.
        torch.cuda.empty_cache()

    return images, seed
def main() -> None:
    """Build the Gradio UI, wire the generate button to on_generate, and launch it."""
    # NOTE(review): block nesting below was reconstructed from an unindented
    # copy of this file -- confirm the layout against the running app.
    with gr.Blocks() as demo:
        with gr.Row():
            # Left column: all generation inputs.
            with gr.Column():
                prompt = gr.TextArea(
                    label="Prompt",
                    value="masterpiece, best quality",
                    placeholder="masterpiece, best quality",
                    interactive=True,
                )
                # type="pil" so on_generate receives a PIL image (or None).
                input_image = gr.Image(
                    label="Reference Image",
                    type="pil",
                    height=600,
                )
                with gr.Accordion("Negative Prompt", open=False):
                    negative_prompt = gr.TextArea(
                        label="Negative Prompt",
                        show_label=False,
                        value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                        interactive=True,
                    )
                # Output resolution, adjustable in steps of 128.
                with gr.Row():
                    width = gr.Slider(
                        label="Width",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=896,
                        interactive=True,
                    )
                    height = gr.Slider(
                        label="Height",
                        minimum=256,
                        maximum=2048,
                        step=128,
                        value=1152,
                        interactive=True,
                    )
                with gr.Accordion("Advanced options", open=False):
                    # Initial batch size comes from INITIAL_BATCH_SIZE.
                    num_images = gr.Slider(
                        label="Number of images to generate",
                        minimum=1,
                        maximum=8,
                        step=1,
                        value=INITIAL_BATCH_SIZE,
                        interactive=True,
                    )
                    with gr.Row():
                        # The seed slider is both an input and an output of
                        # on_generate (see the gr.on wiring below).
                        seed = gr.Slider(
                            label="Seed",
                            minimum=0,
                            maximum=2147483647,
                            step=1,
                            value=0,
                        )
                        randomize_seed = gr.Checkbox(
                            label="Randomize seed",
                            value=True,
                            interactive=True,
                            scale=1,
                        )
                    steps = gr.Slider(
                        label="Inference steps",
                        minimum=10,
                        maximum=50,
                        step=1,
                        value=25,
                        interactive=True,
                    )
                    cfg_scale = gr.Slider(
                        label="CFG scale",
                        minimum=3.0,
                        maximum=8.0,
                        step=0.5,
                        value=5.0,
                        interactive=True,
                    )
            # Right column: trigger button and results.
            with gr.Column():
                generate_button = gr.Button(
                    "Generate",
                    variant="primary",
                )
                output_image = gr.Gallery(
                    label="Generated images",
                    type="pil",
                    rows=2,
                    height="768px",
                    preview=True,
                    show_label=True,
                )
                # Hidden component; it is only listed among the Examples
                # inputs below.
                comment = gr.Markdown(
                    label="Comment",
                    visible=False,
                )
        # Selecting an example fills the listed components.
        # presumably each EXAMPLES row has one value per listed input, in
        # this order -- verify against example.py.
        gr.Examples(
            examples=EXAMPLES,
            inputs=[input_image, prompt, width, height, comment],
            cache_examples=False,
        )
        # Input order must match on_generate's positional parameters.
        gr.on(
            triggers=[generate_button.click],
            fn=on_generate,
            inputs=[
                prompt,
                negative_prompt,
                input_image,
                width,
                height,
                steps,
                cfg_scale,
                seed,
                randomize_seed,
                num_images,
            ],
            # on_generate returns (images, seed): the gallery shows the images
            # and the seed slider is updated with the seed actually used.
            outputs=[output_image, seed],
        )
    demo.launch()
# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    main()