import os
import random
import re
import tempfile
import warnings

import gradio as gr
import numpy as np
import pandas as pd
import torch
from diffusers import DiffusionPipeline
from transformers import pipeline
| |
|
| | |
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Upper bound for the seed slider / random seed (largest signed 32-bit int).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 1024

if device == "cuda":
    # Half precision weights keep GPU memory usage down.
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    # xformers attention is an optional optimisation: if the package is not
    # installed or its kernels are unusable, keep the default attention
    # implementation instead of aborting start-up.
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except (ImportError, ValueError, RuntimeError) as exc:
        warnings.warn(
            f"xformers memory-efficient attention is unavailable ({exc}); "
            "continuing with the default attention implementation."
        )
else:
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        use_safetensors=True,
    )

# Move the pipeline to the selected device on both code paths.
pipe = pipe.to(device)
| |
|
class ImagePromptGenerator:
    """Generate short and elaborated image prompts with a text-generation model.

    The instance holds a ``transformers`` text-generation pipeline in
    ``self.generator``; every method calls it and reads the
    ``"generated_text"`` field of the returned items.
    """

    # Instruction placed in front of a short prompt when asking for a longer one.
    _ELABORATE_PREFIX = "Elaborate:"

    def __init__(self, model_name="gpt2"):
        """Load the text-generation pipeline for *model_name*."""
        # No ``use_auth_token=True`` here: that argument is deprecated and
        # demands a locally stored Hub token even for public models.
        self.generator = pipeline("text-generation", model=model_name)

    def generate_short_prompts(self, theme, num_prompts=5):
        """Return *num_prompts* short prompts generated from *theme*.

        Each prompt is the stripped ``"generated_text"`` of one returned
        sequence for the input ``"<theme> concept"``.
        """
        outputs = self.generator(
            f"{theme} concept",
            max_length=50,
            num_return_sequences=num_prompts,
            # Several distinct sequences require sampling; request it
            # explicitly instead of relying on the model's default config.
            do_sample=True,
        )
        return [item["generated_text"].strip() for item in outputs]

    def enhance_prompt(self, short_prompt):
        """Return a longer prompt elaborating on *short_prompt*.

        The instruction prefix sent to the model is removed from the result
        so that it does not leak into the final image prompt.
        """
        outputs = self.generator(
            f"{self._ELABORATE_PREFIX} {short_prompt}",
            max_length=100,
            num_return_sequences=1,
        )
        text = outputs[0]["generated_text"].strip()
        if text.startswith(self._ELABORATE_PREFIX):
            text = text[len(self._ELABORATE_PREFIX):].lstrip()
        return text

    def generate_prompts_csv(self, theme):
        """Return CSV text with ``short`` and ``long`` prompt columns for *theme*."""
        short_prompts = self.generate_short_prompts(theme)
        long_prompts = [self.enhance_prompt(sp) for sp in short_prompts]
        frame = pd.DataFrame({"short": short_prompts, "long": long_prompts})
        return frame.to_csv(index=False)
| |
|
# Shared prompt generator, created on first use so the language model is
# loaded once instead of on every request.
_prompt_generator = None


def _get_prompt_generator():
    """Return the shared ImagePromptGenerator, creating it on first call."""
    global _prompt_generator
    if _prompt_generator is None:
        _prompt_generator = ImagePromptGenerator()
    return _prompt_generator


def generate_and_save_prompts(theme):
    """Return CSV text of short/long prompts generated for *theme*.

    Despite the name, nothing is written to disk here; the CSV content is
    returned as a string for the caller to persist.
    """
    return _get_prompt_generator().generate_prompts_csv(theme)
| |
|
def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    """Generate one image with the module-level diffusion pipeline.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative text prompt passed to the pipeline.
        seed: Seed for the random generator; ignored when *randomize_seed*
            is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width passed to the pipeline.
        height: Output height passed to the pipeline.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        The first image of the pipeline output.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # UI sliders may deliver floats; manual_seed and the pipeline's
    # size/step arguments need integers.
    generator = torch.Generator().manual_seed(int(seed))

    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        width=int(width),
        height=int(height),
        generator=generator,
    ).images[0]

    return image
| |
|
def gradio_interface(theme):
    """Generate prompts for *theme* and return the path of a CSV file.

    The CSV text is written to a fresh temporary directory and the file path
    is returned, which is what a ``gr.File`` output component expects.
    (``gr.File`` has no ``content``/``file_name`` constructor arguments.)

    Raises:
        gr.Error: If *theme* is empty or only whitespace.
    """
    theme = (theme or "").strip()
    if not theme:
        raise gr.Error("Please enter a theme before generating prompts.")

    csv_content = generate_and_save_prompts(theme)

    # Keep only filename-safe characters and cap the length so an arbitrary
    # theme cannot produce an invalid or path-traversing file name.
    safe_theme = re.sub(r"[^A-Za-z0-9._-]+", "_", theme).strip("._")[:50] or "theme"
    out_dir = tempfile.mkdtemp(prefix="image_prompts_")
    csv_path = os.path.join(out_dir, f"{safe_theme}_image_prompts.csv")

    # newline="" writes the CSV text exactly as produced, without a second
    # round of line-ending translation.
    with open(csv_path, "w", encoding="utf-8", newline="") as fh:
        fh.write(csv_content)

    return csv_path
| |
|
# Centre the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Human-readable device name shown in the page header.
power_device = "GPU" if torch.cuda.is_available() else "CPU"

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(f"""
        # Text-to-Image Gradio Template
        Currently running on {power_device}.
        """)

        with gr.Row():
            theme = gr.Textbox(
                label="Theme for Image Generation",
                placeholder="Enter a theme to generate prompts",
            )
            prompt = gr.Textbox(
                label="Prompt for Image Generation",
                placeholder="Enter your prompt here or select from generated prompts",
                show_label=False,
            )
            generate_prompts_button = gr.Button("Generate Prompts")

        with gr.Row():
            run_button = gr.Button("Run")

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Textbox(
                label="Negative prompt",
                placeholder="Enter a negative prompt",
            )
            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():
                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)
                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=512)

            with gr.Row():
                # SDXL-Turbo is meant to run without classifier-free guidance
                # and with very few steps, so the defaults reflect that; the
                # slider ranges are left as they were.
                guidance_scale = gr.Slider(label="Guidance scale", minimum=0.0, maximum=10.0, step=0.1, value=0.0)
                num_inference_steps = gr.Slider(label="Number of inference steps", minimum=1, maximum=250, step=1, value=2)

        # Named output component instead of one created inline in .click().
        prompts_file = gr.File(label="Download Generated Prompts CSV")

    generate_prompts_button.click(
        fn=gradio_interface,
        inputs=[theme],
        outputs=[prompts_file],
    )

    run_button.click(
        fn=infer,
        inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result],
    )

# Only start the server when run as a script, so importing this module
# (e.g. for reload tooling or tests) does not launch the app.
if __name__ == "__main__":
    demo.launch()
| | ''' |
| | Explanation: |
| | |
| | Class ImagePromptGenerator: This class now includes methods to generate short prompts, enhance them, and output a CSV. |
| | |
| | generate_and_save_prompts Function: This function generates a CSV of prompts based on the theme. |
| | |
| | infer Function: This function generates an image based on the provided parameters using the diffusion model. |
| | |
| | Gradio Interface: The interface now includes: |
| | A textbox to input the theme for generating prompts. |
| | A button to generate prompts based on the theme. |
| | The original image generation interface with advanced settings. |
| | |
| | Button Actions: |
| | Generate Prompts Button: Generates a list of prompts as a downloadable CSV file. |
| | Run Button: Generates an image based on the provided prompt and settings. |
| | ''' |