File size: 4,903 Bytes
75c1e5a
 
 
 
 
 
 
 
 
 
 
 
 
a93f005
75c1e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a93f005
75c1e5a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import spaces
import argparse
import os
import time
from os import path
from safetensors.torch import load_file
from huggingface_hub import hf_hub_download

# Local model cache: a "models" directory next to this script.
cache_path = path.join(path.dirname(path.abspath(__file__)), "models")
# Point the Hugging Face / transformers cache settings at the local cache.
# NOTE(review): huggingface_hub is imported above before these are assigned and
# may already have read its cache settings at import time; confirm they take
# effect, or pass cache_dir explicitly to download calls.
os.environ.update(
    dict.fromkeys(("TRANSFORMERS_CACHE", "HF_HUB_CACHE", "HF_HOME"), cache_path)
)

import gradio as gr
import torch
from diffusers import FluxPipeline

# Allow TensorFloat-32 for CUDA float32 matrix multiplications: faster on GPUs
# that support it, at slightly reduced precision.
torch.backends.cuda.matmul.allow_tf32 = True

class timer:
    """Context manager that logs when a named section starts and how long it took.

    Usage::

        with timer("inference"):
            run()

    prints ``inference starts`` on entry and ``inference took <seconds>s`` on
    exit, with elapsed seconds rounded to two decimals. Exceptions raised in
    the ``with`` body are not suppressed; the timing line is still printed.

    Args:
        method_name: Label used in both printed lines.
    """

    def __init__(self, method_name="timed process"):
        self.method = method_name

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the wall clock is adjusted during the measurement.
        self.start = time.perf_counter()
        print(f"{self.method} starts")
        # Return self so `with timer(...) as t:` gives a usable object.
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.start
        print(f"{self.method} took {round(elapsed, 2)}s")
        # Implicitly returns None (falsy), so any exception propagates.

# Create the local model cache directory (no-op if it already exists).
os.makedirs(cache_path, exist_ok=True)

# cache_dir is passed explicitly: huggingface_hub is imported at the top of this
# file before the HF_* cache environment variables are assigned, so those
# variables may not be picked up. An explicit cache_dir guarantees downloads
# land in cache_path.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
    cache_dir=cache_path,
)
# Download the orca LoRA weights and load them into the pipeline.
pipe.load_lora_weights(
    hf_hub_download(
        "MatthiasBachfischer/open-engineering-orcas",
        "open-engineering-orcas.safetensors",
        cache_dir=cache_path,
    )
)
# Merge the LoRA into the base weights at full strength.
pipe.fuse_lora(lora_scale=1.0)
pipe.to(device="cuda", dtype=torch.bfloat16)

# Gradio theme: custom red primary palette, Arial-first sans-serif font stack,
# primary buttons filled with the palette's 500 shade and light text.
theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c50="#fef2f2",
        c100="#f4e5dc",
        c200="#f6c1b0",
        c300="#f59a86",
        c400="#f05b48",
        c500="#ea1b0a",
        c600="#c41708",
        c700="#9d1207",
        c800="#991b1b",
        c900="#7f1d1d",
        c950="#6c1e1e",
    ),
    # Arial is a system font, not a Google Font. Wrapping it in
    # gr.themes.GoogleFont makes the page request a non-existent stylesheet
    # from fonts.googleapis.com. A plain family name keeps the same CSS
    # font stack without the failing request.
    font=['Arial', 'ui-sans-serif', 'system-ui', 'sans-serif'],
).set(
    button_primary_background_fill='*primary_500',
    button_primary_text_color='*neutral_50',
)

with gr.Blocks(theme=theme) as demo:
    # Header: title and links to the base model and the mascot project.
    gr.Markdown(
        """
        <div style="text-align: center; max-width: 900px; margin: 0 auto;">
            <h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">E.ON Open Engineering Orcas</h1>
            <p style="font-size: 1rem; margin-bottom: 1.5rem;">This space hosts a fine-tuned <a href="https://huggingface.co/black-forest-labs/FLUX.1-dev">FLUX.1 dev</a> LoRA model to create <a href="https://github.com/jansche/open-engineering-orcas">Open Engineering Orca mascots</a>.</p>
        </div>
        """
    )

    with gr.Row():
        # Left column: prompt, advanced generation settings and the button.
        with gr.Column(scale=3):
            with gr.Group():
                prompt = gr.Textbox(
                    label="Your orca description",
                    placeholder="E.g., orca with a backpack",
                    lines=3
                )

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Group():
                        with gr.Row():
                            height = gr.Slider(label="Height", minimum=256, maximum=1152, step=64, value=1024)
                            width = gr.Slider(label="Width", minimum=256, maximum=1152, step=64, value=1024)
                        
                        with gr.Row():
                            steps = gr.Slider(label="Inference Steps", minimum=6, maximum=25, step=1, value=8)
                            scales = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=5.0, step=0.1, value=3.5)
                        
                        seed = gr.Number(label="Seed (for reproducibility)", value=3413, precision=0)
                
                generate_btn = gr.Button("Generate Orca", variant="primary", scale=1)

        # Right column: generated result.
        with gr.Column(scale=4):
            output = gr.Image(label="Your Generated Image")
    
    # Usage instructions. The button name here must match generate_btn's label.
    gr.Markdown(
        """
        <div style="max-width: 650px; margin: 2rem auto; padding: 1rem; border-radius: 10px; background-color: #f0f0f0;">
            <h2 style="font-size: 1.5rem; margin-bottom: 1rem;">How to Use</h2>
            <ol style="padding-left: 1.5rem;">
                <li>Enter a detailed description of the orca you want to create.</li>
                <li>Adjust advanced settings if desired (tap to expand).</li>
                <li>Tap "Generate Orca" and wait for your creation!</li>
            </ol>
            <p style="margin-top: 1rem; font-style: italic;">Tip: Be specific in your description for best results!</p>
        </div>
        """
    )

    @spaces.GPU
    def process_image(height, width, steps, scales, prompt, seed):
        """Generate one orca image from the UI inputs.

        Args:
            height: Output height in pixels (slider, 256-1152 in steps of 64).
            width: Output width in pixels (slider, 256-1152 in steps of 64).
            steps: Number of inference steps.
            scales: Guidance scale passed to the pipeline.
            prompt: Text description of the orca.
            seed: Seed for the random generator, so results are reproducible.

        Returns:
            The first entry of the pipeline result's ``.images``.
        """
        # `pipe` is only read here, so no `global` declaration is needed.
        with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16), timer("inference"):
            # Gradio sliders/numbers may deliver floats; cast to the types the
            # pipeline arguments expect.
            return pipe(
                prompt=[prompt],
                generator=torch.Generator().manual_seed(int(seed)),
                num_inference_steps=int(steps),
                guidance_scale=float(scales),
                height=int(height),
                width=int(width),
                max_sequence_length=256
            ).images[0]

    # Inputs are passed positionally in process_image's parameter order.
    generate_btn.click(
        process_image,
        inputs=[height, width, steps, scales, prompt, seed],
        outputs=output
    )

if __name__ == "__main__":
    # Start the Gradio server for the `demo` UI when run as a script.
    demo.launch()