File size: 5,767 Bytes
dd83dfb
2385a9b
dd83dfb
 
 
2385a9b
 
dd83dfb
2385a9b
dd83dfb
2385a9b
 
dd83dfb
2385a9b
dd83dfb
 
 
 
2385a9b
dd83dfb
2385a9b
dd83dfb
2385a9b
 
 
 
dd83dfb
2385a9b
 
 
 
 
 
 
 
 
 
dd83dfb
2385a9b
 
 
 
 
 
 
 
 
 
 
 
 
 
dd83dfb
 
 
 
 
2385a9b
 
 
 
 
 
 
 
dd83dfb
 
 
2385a9b
 
 
 
dd83dfb
2385a9b
dd83dfb
 
 
 
 
 
2385a9b
 
dd83dfb
 
 
 
 
 
 
2385a9b
dd83dfb
2385a9b
 
 
 
 
 
dd83dfb
 
 
 
2385a9b
 
dd83dfb
 
 
2385a9b
 
dd83dfb
 
2385a9b
 
 
 
 
 
 
 
 
 
 
 
 
dd83dfb
 
2385a9b
 
 
 
 
 
 
 
 
 
 
 
 
 
dd83dfb
2385a9b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import gradio as gr
import pandas as pd
import numpy as np
import random
import torch
from transformers import pipeline
from diffusers import DiffusionPipeline

# Limits shared with the UI sliders: the largest 32-bit signed seed and the
# largest selectable image edge.
MAX_IMAGE_SIZE = 1024
MAX_SEED = np.iinfo(np.int32).max

# Choose the device and load the SDXL-Turbo pipeline in one pass: the CUDA
# branch loads the fp16 variant and turns on xformers attention, the CPU
# branch loads the default weights.
if torch.cuda.is_available():
    device = "cuda"
    torch.cuda.max_memory_allocated(device=device)
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    pipe.enable_xformers_memory_efficient_attention()
else:
    device = "cpu"
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        use_safetensors=True,
    )

# Move the loaded pipeline onto the selected device.
pipe = pipe.to(device)

class ImagePromptGenerator:
    """Produce short and elaborated image prompts with a text-generation model.

    The instance wraps a Transformers ``text-generation`` pipeline and exposes
    helpers that turn a theme into short prompts, expand each one, and
    serialise the pairs as CSV text.
    """

    def __init__(self, model_name="gpt2"):
        """Create the text-generation pipeline for ``model_name``.

        Args:
            model_name: Model identifier handed to ``pipeline``.
        """
        # ``use_auth_token=True`` is deliberately not passed: the argument is
        # deprecated in Transformers and forces a stored Hub token to exist,
        # which is unnecessary for a public model such as the default "gpt2".
        self.generator = pipeline("text-generation", model=model_name)

    def generate_short_prompts(self, theme, num_prompts=5):
        """Return ``num_prompts`` generated texts seeded with ``"<theme> concept"``.

        Args:
            theme: Subject used to build the seed text.
            num_prompts: Number of sequences requested from the generator.

        Returns:
            A list of the generated texts with surrounding whitespace removed.
        """
        outputs = self.generator(
            f"{theme} concept",
            max_length=50,
            num_return_sequences=num_prompts,
        )
        return [item['generated_text'].strip() for item in outputs]

    def enhance_prompt(self, short_prompt):
        """Return a longer text generated from ``"Elaborate: <short_prompt>"``.

        Args:
            short_prompt: The prompt to expand.

        Returns:
            The first generated text with surrounding whitespace removed.
        """
        outputs = self.generator(
            f"Elaborate: {short_prompt}",
            max_length=100,
            num_return_sequences=1,
        )
        return outputs[0]['generated_text'].strip()

    def generate_prompts_csv(self, theme, num_prompts=5):
        """Build CSV text with one ``short,long`` row per generated prompt.

        Args:
            theme: Subject used to generate the short prompts.
            num_prompts: How many short prompts (and therefore rows) to
                produce; defaults to 5, matching ``generate_short_prompts``.

        Returns:
            CSV text with a ``short,long`` header and no index column.
        """
        short_prompts = self.generate_short_prompts(theme, num_prompts)
        long_prompts = [self.enhance_prompt(sp) for sp in short_prompts]
        df = pd.DataFrame({"short": short_prompts, "long": long_prompts})
        return df.to_csv(index=False)

# Lazily created, shared prompt generator; avoids reloading the language
# model on every request.
_prompt_generator = None


def _get_prompt_generator():
    """Return the shared ImagePromptGenerator, creating it on first use."""
    global _prompt_generator
    if _prompt_generator is None:
        _prompt_generator = ImagePromptGenerator()
    return _prompt_generator


def generate_and_save_prompts(theme):
    """Generate prompts for ``theme`` and return them as CSV text.

    Despite the name, nothing is written to disk here; the CSV content is
    returned to the caller as a string.

    Args:
        theme: Subject passed to ``ImagePromptGenerator.generate_prompts_csv``.

    Returns:
        The CSV text produced by the prompt generator.
    """
    return _get_prompt_generator().generate_prompts_csv(theme)

def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    """Run the module-level diffusion pipeline once and return the first image.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator when ``randomize_seed`` is falsy.
        randomize_seed: When truthy, a random seed in ``[0, MAX_SEED]`` is
            drawn and used instead of ``seed``.
        width: Requested image width.
        height: Requested image height.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Step count forwarded to the pipeline.

    Returns:
        The first entry of the pipeline result's ``images``.
    """
    chosen_seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(chosen_seed)

    call_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
    }
    outputs = pipe(**call_kwargs)
    return outputs.images[0]

def gradio_interface(theme):
    """Generate prompts for ``theme`` and return the path of a CSV download.

    ``gr.File`` has no ``content`` or ``file_name`` parameters, so the CSV
    text is written to a temporary file and its path is returned; a file path
    is a value a ``gr.File`` output component accepts.

    Args:
        theme: Theme text entered by the user.

    Returns:
        Path of the written ``<theme>_image_prompts.csv`` file.
    """
    import os
    import re
    import tempfile

    csv_content = generate_and_save_prompts(theme)

    # The theme is free-form user input: keep only filename-safe characters
    # so it cannot inject path separators, and cap its length.
    safe_theme = re.sub(r"[^\w.-]+", "_", theme or "").strip("._")[:64] or "theme"

    # A fresh directory per request keeps the friendly file name without
    # clashing with files written by concurrent requests.
    out_dir = tempfile.mkdtemp(prefix="image_prompts_")
    csv_path = os.path.join(out_dir, f"{safe_theme}_image_prompts.csv")
    with open(csv_path, "w", encoding="utf-8", newline="") as csv_file:
        csv_file.write(csv_content)
    return csv_path

# Page-level CSS: centres the element with id "col-container" and caps its
# width at 520px.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Determine the computational power available (shown in the page header)
power_device = "GPU" if torch.cuda.is_available() else "CPU"

# Build the UI. Components are created in display order inside the Blocks
# context; the two click handlers are wired up at the end.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        # Header with the device the app is running on.
        gr.Markdown(f"""
        # Text-to-Image Gradio Template
        Currently running on {power_device}.
        """)

        # Theme input (for prompt generation), image prompt input, and the
        # button that triggers prompt generation.
        with gr.Row():
            theme = gr.Textbox(label="Theme for Image Generation", placeholder="Enter a theme to generate prompts")
            prompt = gr.Textbox(label="Prompt for Image Generation", placeholder="Enter your prompt here or select from generated prompts", show_label=False)
            generate_prompts_button = gr.Button("Generate Prompts")

        with gr.Row():
            run_button = gr.Button("Run")

        # Output slot for the image returned by infer().
        result = gr.Image(label="Result", show_label=False)

        # Optional generation parameters; collapsed by default.
        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(label="Negative prompt", placeholder="Enter a negative prompt")
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)

            # NOTE(review): the loaded model is "stabilityai/sdxl-turbo"; turbo
            # models are presumably meant to run with guidance_scale=0.0 and
            # very few steps - confirm that defaults of 7.5 / 50 are intended.
            with gr.Row():
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=7.5)
                num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=250, step=1, value=50)

        # Prompt generation: theme -> gradio_interface -> downloadable file.
        # NOTE(review): the File output is instantiated inline here, so it
        # presumably renders at this position in the column - confirm.
        generate_prompts_button.click(
            fn=gradio_interface,
            inputs=[theme],
            outputs=[gr.File(label="Download Generated Prompts CSV")]
        )

        # Image generation: the input order must match infer()'s parameters.
        run_button.click(
            fn=infer,
            inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
            outputs=[result]
        )

# Start the Gradio server as soon as the module is executed.
demo.launch()
'''
Explanation:

    Class ImagePromptGenerator: This class provides methods to generate short prompts, enhance them into longer prompts, and output both as CSV text.

    generate_and_save_prompts Function: This function generates a CSV of prompts based on the theme.

    infer Function: This function generates an image based on the provided parameters using the diffusion model.

    Gradio Interface: The interface includes:
        A textbox to input the theme for generating prompts.
        A button to generate prompts based on the theme.
        An image generation interface (prompt box, Run button, result image) with advanced settings.

    Button Actions:
        Generate Prompts Button: Generates a list of prompts as a downloadable CSV file.
        Run Button: Generates an image based on the provided prompt and settings.
'''