#!/usr/bin/env python
"""Gradio app that collects user preferences over Stable Diffusion outputs.

A prompt is sent to the ``stabilityai/stable-diffusion`` Space, the returned
images are shown in a gallery, and the user picks one and saves the choice.
Each saved preference becomes a directory under ``LOCAL_SAVE_DIR`` holding the
images plus a ``preferences.json`` file.  ``ZipScheduler`` is configured to
push ``LOCAL_SAVE_DIR`` to the dataset repo ``UPLOAD_REPO_ID`` every
``UPLOAD_FREQUENCY`` units (presumably minutes -- see the scheduler module).
"""

import datetime
import json
import os
import pathlib
import shutil
import tempfile
import uuid
from typing import Any

import gradio as gr
from gradio_client import Client

from scheduler import ZipScheduler

HF_TOKEN = os.getenv('HF_TOKEN')
UPLOAD_REPO_ID = os.getenv('UPLOAD_REPO_ID')
UPLOAD_FREQUENCY = int(os.getenv('UPLOAD_FREQUENCY', '5'))
# Only the exact string '1' makes the upload repo public.
USE_PUBLIC_REPO = os.getenv('USE_PUBLIC_REPO') == '1'
LOCAL_SAVE_DIR = pathlib.Path(os.getenv('LOCAL_SAVE_DIR', 'results'))
LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

scheduler = ZipScheduler(
    repo_id=UPLOAD_REPO_ID,
    repo_type='dataset',
    every=UPLOAD_FREQUENCY,
    private=not USE_PUBLIC_REPO,
    token=HF_TOKEN,
    folder_path=LOCAL_SAVE_DIR,
)

client = Client('stabilityai/stable-diffusion')


def generate(prompt: str) -> tuple[str, list[str]]:
    """Generate images for *prompt* through the remote Stable Diffusion Space.

    Args:
        prompt: Text prompt entered by the user.

    Returns:
        ``(config_path, paths)``: the path of a temporary JSON file recording
        the generation parameters, and the keys of ``captions.json`` in the
        Space's output directory (fed to the gallery, so presumably image
        paths).
    """
    negative_prompt = ''
    guidance_scale = 9
    out_dir = client.predict(prompt, negative_prompt, guidance_scale, fn_index=1)

    config = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'guidance_scale': guidance_scale,
    }
    # delete=False: the file must outlive this call because save_preference
    # reads it later (and deletes it).  The ``with`` block closes -- and thus
    # flushes -- the file before its path is handed to Gradio, instead of
    # relying on the garbage collector to do it.
    with tempfile.NamedTemporaryFile(
        mode='w', suffix='.json', delete=False, encoding='utf-8'
    ) as config_file:
        json.dump(config, config_file)

    with open(pathlib.Path(out_dir) / 'captions.json', encoding='utf-8') as f:
        paths = list(json.load(f).keys())
    return config_file.name, paths


def get_selected_index(evt: gr.SelectData) -> int:
    """Return the index carried by a gallery ``select`` event.

    For a Gallery this is the position of the clicked image (per Gradio's
    ``SelectData``).
    """
    return evt.index


def save_preference(config_path: str, gallery: list[dict[str, Any]], selected_index: int) -> None:
    """Persist the current gallery, its generation config and the user's pick.

    Creates ``LOCAL_SAVE_DIR/<uuid4>/`` containing the gallery images renamed
    ``000<ext>``, ``001<ext>``, ... in gallery order, plus ``preferences.json``
    holding the generation config merged with ``selected_index`` and a UTC
    timestamp.  Images are moved (not copied) into the directory, and the
    temporary config file written by ``generate`` is deleted once consumed.

    Args:
        config_path: Path of the temporary JSON file written by ``generate``.
        gallery: Gallery payload; each item's ``'name'`` is a local file path.
        selected_index: Index of the chosen image.  NOTE(review): may be
            ``None`` if the user saved without selecting -- TODO confirm.
    """
    # Read the config first so a missing/corrupt file fails before any image
    # is moved out of its original location.
    with open(config_path, encoding='utf-8') as f:
        config = json.load(f)
    preferences = config | {
        'selected_index': selected_index,
        # datetime.utcnow() is deprecated since Python 3.12; dropping tzinfo
        # keeps the exact naive-UTC ISO format previously written.
        'timestamp': datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat(),
    }
    paths = [x['name'] for x in gallery]

    save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
    # Populate the whole entry (images AND preferences.json) under the
    # scheduler's lock so an upload snapshot cannot capture a half-written
    # directory (assumes ZipScheduler holds ``lock`` while reading
    # folder_path -- TODO confirm against the scheduler module).
    with scheduler.lock:
        save_dir.mkdir(parents=True, exist_ok=True)
        for index, path in enumerate(paths):
            ext = pathlib.Path(path).suffix
            shutil.move(path, save_dir / f'{index:03d}{ext}')
        with (save_dir / 'preferences.json').open('w', encoding='utf-8') as f:
            json.dump(preferences, f)

    # generate() created this file with delete=False; after a save it is no
    # longer needed (clear() resets config_path), so remove it rather than
    # leaking one temp file per generation.
    pathlib.Path(config_path).unlink(missing_ok=True)


def clear() -> tuple[dict, dict, dict, dict]:
    """Reset the UI after a preference has been saved.

    Returns Gradio updates that empty the hidden config path, the gallery and
    the selected index, and disable the save button until the next
    successful generation.  Resetting the index prevents a stale selection
    from being recorded against a later, different gallery.
    """
    return (
        gr.update(value=None),
        gr.update(value=None),
        gr.update(value=None),
        gr.update(interactive=False),
    )


with gr.Blocks(css='style.css') as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder='Prompt')
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False).style(
            columns=2, rows=2, height='600px', object_fit='scale-down'
        )
        selected_index = gr.Number(visible=False, precision=0)
        save_preference_button = gr.Button('Save preference', interactive=False)

    # Generate images; on success, drop any selection made on the previous
    # gallery (it no longer refers to these images) and enable saving.
    prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
    ).success(
        fn=lambda: (gr.update(value=None), gr.update(interactive=True)),
        outputs=[selected_index, save_preference_button],
        queue=False,
    )
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )
    save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    ).then(
        fn=clear,
        outputs=[config_path, gallery, selected_index, save_preference_button],
        queue=False,
    )

demo.queue(concurrency_count=5).launch()