Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| from diffusers import FluxPipeline | |
| from transformers import BlipProcessor, BlipForConditionalGeneration | |
# Run models on the GPU when one is available; otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the FLUX.1-schnell text-to-image model via diffusers.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
)
if device == "cuda":
    # Model CPU offload keeps submodules on the CPU and moves each one onto the
    # accelerator only while it runs, which saves VRAM. It needs an accelerator
    # to offload *to*, so it is only enabled when CUDA is present.
    pipe.enable_model_cpu_offload()
else:
    # No accelerator available: keep the whole pipeline on the CPU
    # (slow, but it runs instead of failing).
    pipe.to(device)

# Image-captioning model (BLIP) used to "guess" a prompt from each image.
CAPTION_MODEL_ID = "Salesforce/blip-image-captioning-base"
caption_processor = BlipProcessor.from_pretrained(CAPTION_MODEL_ID)
caption_model = BlipForConditionalGeneration.from_pretrained(CAPTION_MODEL_ID)
caption_model.to(device)
def _guess_prompt(image) -> str:
    """Caption *image* with the BLIP model and return the decoded text."""
    inputs = caption_processor(images=image, return_tensors="pt").to(device)
    output = caption_model.generate(**inputs)
    return caption_processor.decode(output[0], skip_special_tokens=True)


def play_game(initial_prompt: str, rounds: int, seed: int = 42) -> list:
    """Play a "telephone" game between a text-to-image and a captioning model.

    Each round generates an image from the current prompt with the FLUX
    pipeline. Then, unless it is the last round, it captions that image with
    BLIP, and the caption becomes the prompt for the next round.

    Args:
        initial_prompt: Prompt used for the first round.
        rounds: Number of images to generate. Gradio sliders may deliver this
            as a float (e.g. ``3.0``), which ``range`` rejects, so it is
            coerced to ``int``. Values <= 0 produce no images.
        seed: Base RNG seed. Round ``i`` uses ``seed + i``, so each round is
            reproducible. Defaults to 42, the value previously hard-coded.

    Returns:
        A list with one generated image per round (``result.images[0]`` of
        each pipeline call), in generation order.
    """
    num_rounds = int(rounds)
    images = []
    current_prompt = initial_prompt
    for round_idx in range(num_rounds):
        result = pipe(
            current_prompt,
            guidance_scale=0.0,
            num_inference_steps=4,  # adjust for speed vs. quality
            generator=torch.Generator(device).manual_seed(seed + round_idx),
        )
        img = result.images[0]
        images.append(img)

        # A caption produced after the final round would never be used as a
        # prompt, so only run the (costly) captioning model when another round
        # follows.
        if round_idx + 1 < num_rounds:
            current_prompt = _guess_prompt(img)
    return images
# Build the Gradio interface.
demo = gr.Interface(
    fn=play_game,
    inputs=[
        gr.Textbox(label="Initial Prompt", placeholder="Enter your starting prompt..."),
        gr.Slider(minimum=1, maximum=10, step=1, label="Number of Rounds", value=3),
    ],
    outputs=gr.Gallery(label="Generated Images"),
    title="Flux Prompt Guessing Game",
    description=(
        "Enter an initial prompt and choose the number of rounds. "
        "The game will generate an image using FLUX.1-schnell, then the AI "
        "will guess the prompt from that image to generate the next one, and so on."
    ),
)

if __name__ == "__main__":
    # Start the web server only when this file is executed as a script
    # (e.g. `python app.py`); merely importing the module does not launch it.
    demo.launch()