Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| import datetime | |
| import json | |
| import os | |
| import pathlib | |
| import shutil | |
| import tempfile | |
| import uuid | |
| from typing import Any | |
| import gradio as gr | |
| from gradio_client import Client | |
| from scheduler import ZipScheduler | |
# --- Runtime configuration, read once from the environment at import time ---
HF_TOKEN = os.environ.get('HF_TOKEN')
UPLOAD_REPO_ID = os.environ.get('UPLOAD_REPO_ID')
UPLOAD_FREQUENCY = int(os.environ.get('UPLOAD_FREQUENCY', '5'))
# The upload repo is private unless USE_PUBLIC_REPO is exactly '1'.
USE_PUBLIC_REPO = os.environ.get('USE_PUBLIC_REPO') == '1'

# Local folder where saved preferences accumulate; created eagerly so the
# scheduler below is handed an existing directory.
LOCAL_SAVE_DIR = pathlib.Path(os.environ.get('LOCAL_SAVE_DIR', 'results'))
LOCAL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Scheduler that is pointed at LOCAL_SAVE_DIR and the dataset repo.
scheduler = ZipScheduler(
    repo_id=UPLOAD_REPO_ID,
    repo_type='dataset',
    every=UPLOAD_FREQUENCY,
    private=not USE_PUBLIC_REPO,
    token=HF_TOKEN,
    folder_path=LOCAL_SAVE_DIR,
)

# Client for the remote Space that performs the actual image generation.
client = Client('stabilityai/stable-diffusion')
def generate(prompt: str) -> tuple[str, list[str]]:
    """Request images for ``prompt`` from the remote Space and record the settings.

    Args:
        prompt: Text prompt forwarded to the remote endpoint.

    Returns:
        A pair ``(config_path, paths)``. ``config_path`` is the name of a
        temporary JSON file holding the prompt, negative prompt and guidance
        scale used for this request. ``paths`` is the list of keys of the
        ``captions.json`` file found in the directory returned by the remote
        call (presumably the generated image paths -- they are fed to the
        gallery output).
    """
    negative_prompt = ''
    guidance_scale = 9
    out_dir = client.predict(prompt,
                             negative_prompt,
                             guidance_scale,
                             fn_index=1)

    config = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'guidance_scale': guidance_scale,
    }
    # delete=False: the file must outlive this function because
    # save_preference() reopens it by path later. The ``with`` block closes
    # the handle deterministically, so the JSON is guaranteed to be flushed
    # to disk before the path is handed out (previously this relied on the
    # handle being garbage-collected).
    with tempfile.NamedTemporaryFile(mode='w',
                                     suffix='.json',
                                     delete=False) as config_file:
        json.dump(config, config_file)

    with open(pathlib.Path(out_dir) / 'captions.json') as f:
        paths = list(json.load(f).keys())
    return config_file.name, paths
def get_selected_index(evt: gr.SelectData) -> int:
    """Return the ``index`` attribute carried by a selection event.

    Args:
        evt: Selection event passed in by the gallery's ``select`` listener.

    Returns:
        The value of ``evt.index``.
    """
    chosen = evt.index
    return chosen
def save_preference(config_path: str, gallery: list[dict[str, Any]],
                    selected_index: int) -> None:
    """Persist the gallery images and the user's choice into a fresh directory.

    A new ``<uuid4>`` sub-directory of ``LOCAL_SAVE_DIR`` receives the gallery
    images (moved and renamed ``000<ext>``, ``001<ext>``, ...) plus a
    ``preferences.json`` containing the generation config merged with
    ``selected_index`` and a UTC ``timestamp``.

    Args:
        config_path: Path of the JSON config file describing the generation.
        gallery: Gallery items; each item's ``'name'`` entry is used as the
            path of an image file.
        selected_index: Index of the image the user picked.

    Raises:
        OSError: If the config file cannot be read or a file cannot be moved.
        json.JSONDecodeError: If the config file is not valid JSON.
        KeyError: If a gallery item has no ``'name'`` entry.
    """
    # Validate and load every input *before* touching the filesystem, so a
    # missing/corrupt config or malformed gallery item cannot leave behind a
    # directory with moved images but no preferences.json.
    with open(config_path) as f:
        config = json.load(f)
    paths = [x['name'] for x in gallery]

    save_dir = LOCAL_SAVE_DIR / f'{uuid.uuid4()}'
    save_dir.mkdir(parents=True, exist_ok=True)

    # All writes happen under the scheduler's lock (presumably so the
    # scheduler never picks up a half-written directory -- confirm against
    # ZipScheduler).
    with scheduler.lock:
        for index, path in enumerate(paths):
            ext = pathlib.Path(path).suffix
            shutil.move(path, save_dir / f'{index:03d}{ext}')

        # Same naive-UTC ISO-8601 string that the deprecated
        # datetime.utcnow() produced, so the stored format is unchanged.
        now_utc = datetime.datetime.now(datetime.timezone.utc)
        timestamp = now_utc.replace(tzinfo=None).isoformat()

        preferences = config | {
            'selected_index': selected_index,
            'timestamp': timestamp,
        }
        json_path = save_dir / 'preferences.json'
        with json_path.open('w') as f:
            json.dump(preferences, f)
def clear() -> tuple[dict, dict, dict]:
    """Build the three component updates applied after a preference is saved.

    Returns:
        Three updates, in order: ``value=None``, ``value=None`` and
        ``interactive=False``. The click handler routes them to the config
        path, the gallery and the save button respectively.
    """
    reset_config_path = gr.update(value=None)
    reset_gallery = gr.update(value=None)
    disable_save_button = gr.update(interactive=False)
    return reset_config_path, reset_gallery, disable_save_button
# NOTE(review): the source had lost its indentation; the nesting below is
# reconstructed. In particular, the save button is assumed to sit inside the
# Blocks but outside the Group -- confirm against the deployed layout.
with gr.Blocks(css='style.css') as demo:
    with gr.Group():
        prompt = gr.Text(show_label=False, placeholder='Prompt')
        # Hidden field carrying the path of the config file from generate().
        config_path = gr.Text(visible=False)
        gallery = gr.Gallery(show_label=False)
        gallery = gallery.style(columns=2,
                                rows=2,
                                height='600px',
                                object_fit='scale-down')
        # Hidden field remembering which gallery image was clicked.
        selected_index = gr.Number(visible=False, precision=0)
    save_preference_button = gr.Button('Save preference', interactive=False)

    # Submitting a prompt generates images; only on success is the save
    # button made interactive.
    generation_event = prompt.submit(
        fn=generate,
        inputs=prompt,
        outputs=[config_path, gallery],
    )
    generation_event.success(
        fn=lambda: gr.update(interactive=True),
        outputs=save_preference_button,
        queue=False,
    )

    # Clicking an image records its index in the hidden number field.
    gallery.select(
        fn=get_selected_index,
        outputs=selected_index,
        queue=False,
    )

    # Saving stores the preference, then resets the UI for the next prompt.
    save_event = save_preference_button.click(
        fn=save_preference,
        inputs=[config_path, gallery, selected_index],
        queue=False,
    )
    save_event.then(
        fn=clear,
        outputs=[config_path, gallery, save_preference_button],
        queue=False,
    )

queued_demo = demo.queue(concurrency_count=5)
queued_demo.launch()