|
|
import torch |
|
|
from diffusers import Flux2Pipeline, Flux2Transformer2DModel |
|
|
from diffusers.utils import load_image |
|
|
from huggingface_hub import get_token |
|
|
import requests |
|
|
import io |
|
|
import gradio as gr |
|
|
from PIL import Image |
|
|
|
|
|
# Device that hosts the pipeline and receives the prompt embeddings.
device = "cuda:0"

# Dtype handed to `from_pretrained` when the pipeline is loaded.
torch_dtype = torch.bfloat16

# Hub repository id of the checkpoint to load (FLUX.2-dev, bnb 4-bit variant
# according to its name).
repo_id = "diffusers/FLUX.2-dev-bnb-4bit"
|
|
|
|
|
def remote_text_encoder(prompts, timeout=120.0):
    """Encode prompts with the hosted FLUX.2 text-encoder endpoint.

    The prompts are POSTed as JSON to the remote service, authorized with the
    token returned by ``huggingface_hub.get_token()``, and the response body
    is deserialized with ``torch.load``.

    Args:
        prompts: Value sent under the ``"prompt"`` JSON key.
        timeout: Seconds to wait for the service. Without a timeout,
            ``requests`` can block indefinitely on a stalled connection.

    Returns:
        The deserialized prompt embeddings, moved to ``device``.

    Raises:
        requests.HTTPError: If the service answers with a 4xx/5xx status.
        requests.Timeout: If the service does not answer within ``timeout``.
    """
    response = requests.post(
        "https://remote-text-encoder-flux-2.huggingface.co/predict",
        json={"prompt": prompts},
        headers={
            "Authorization": f"Bearer {get_token()}",
            "Content-Type": "application/json",
        },
        timeout=timeout,
    )
    # Surface auth/server failures as HTTP errors instead of handing an error
    # body to torch.load, which would fail with an unrelated unpickling error.
    response.raise_for_status()

    # SECURITY: the payload comes from the network and torch.load is
    # pickle-based. weights_only=True restricts unpickling to tensors and
    # other plain data so a malicious response cannot execute code.
    prompt_embeds = torch.load(io.BytesIO(response.content), weights_only=True)

    return prompt_embeds.to(device)
|
|
|
|
|
|
|
|
# Build the pipeline once at import time. `text_encoder=None` is passed to
# `from_pretrained`; prompt embeddings are supplied per call via
# `remote_text_encoder` instead.
print("Loading Flux2 pipeline...")
pipe = Flux2Pipeline.from_pretrained(
    repo_id,
    text_encoder=None,
    torch_dtype=torch_dtype,
)
pipe = pipe.to(device)
print("Pipeline loaded successfully!")
|
|
|
|
|
def generate_image(
    prompt: str,
    input_image: Image.Image = None,
    num_inference_steps: int = 28,
    guidance_scale: float = 4.0,
    seed: int = 42,
    progress=gr.Progress()
):
    """
    Generate an image using Flux2 based on text prompt and optional input image.

    Args:
        prompt: Text description of the desired image
        input_image: Optional input image for image-to-image generation
        num_inference_steps: Number of denoising steps (higher = better quality but slower)
        guidance_scale: How closely to follow the prompt (higher = more strict)
        seed: Random seed for reproducibility (-1 or None for random)
        progress: Gradio progress tracker used to report the current stage

    Returns:
        The first image of the pipeline output.

    Raises:
        gr.Error: If the prompt is empty, or if encoding or generation fails
            (the original exception is attached as ``__cause__``).
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt!")

    progress(0, desc="Encoding prompt...")

    try:
        prompt_embeds = remote_text_encoder(prompt)

        progress(0.3, desc="Generating image...")

        generator = torch.Generator(device=device)
        if seed is None or int(seed) == -1:
            # A freshly constructed Generator always starts from the same
            # default seed, so leaving it untouched would return the same
            # "random" image on every call. seed() reseeds it
            # non-deterministically.
            generator.seed()
        else:
            # UI number fields may deliver floats; manual_seed needs an int.
            generator.manual_seed(int(seed))

        pipe_kwargs = {
            "prompt_embeds": prompt_embeds,
            "generator": generator,
            "num_inference_steps": int(num_inference_steps),
            "guidance_scale": guidance_scale,
        }

        # Only pass `image` when one was supplied, so the text-to-image path
        # calls the pipeline without that keyword at all.
        if input_image is not None:
            pipe_kwargs["image"] = input_image
            progress(0.4, desc="Processing input image...")

        image = pipe(**pipe_kwargs).images[0]

        progress(1.0, desc="Done!")

        return image

    except Exception as e:
        # Top-level UI boundary: convert any failure into a user-visible
        # Gradio error while keeping the original traceback chained.
        raise gr.Error(f"Error generating image: {str(e)}") from e
|
|
|
|
|
|
|
|
# Gradio UI: a two-column layout (inputs/parameters on the left, result and
# examples on the right) wired to `generate_image`.
with gr.Blocks(title="Flux2 Image Generator", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🎨 Flux2 Image Generator
        Generate stunning images using FLUX.2-dev with 4-bit quantization.
        Supports both **text-to-image** and **image-to-image** generation.
        """
    )

    with gr.Row():
        # ---- Left column: prompt, optional image, sampling parameters ----
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Input")

            default_prompt = "Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background."
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to generate...",
                lines=4,
                value=default_prompt,
            )

            image_input = gr.Image(
                label="Input Image (Optional)",
                type="pil",
                sources=["upload", "clipboard"],
                height=300,
            )

            gr.Markdown("### ⚙️ Parameters")

            with gr.Row():
                num_steps = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=28,
                    step=1,
                    label="Inference Steps",
                    info="More steps = better quality but slower",
                )

                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=15.0,
                    value=4.0,
                    step=0.5,
                    label="Guidance Scale",
                    info="How closely to follow the prompt",
                )

            seed_input = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                info="Use -1 for random seed",
            )

            generate_btn = gr.Button(
                "🚀 Generate Image",
                variant="primary",
                size="lg",
            )

            gr.Markdown(
                """
                ### 💡 Tips
                - **Text-to-Image**: Just enter a prompt and click generate
                - **Image-to-Image**: Upload an image and describe the changes
                - Start with 28 steps for a good balance of quality and speed
                - Higher guidance scale follows your prompt more strictly
                - Use the same seed to reproduce results
                """
            )

        # ---- Right column: generated image and example prompts ----
        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Output")

            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=600,
            )

            gr.Markdown(
                """
                ### 📊 Examples
                Try these prompts for inspiration!
                """
            )

            # Components fed to `generate_image`, in its positional order.
            generation_controls = (
                prompt_input,
                image_input,
                num_steps,
                guidance,
                seed_input,
            )

            # Every example is text-to-image (no input image) with 28 steps
            # and guidance 4.0; only the prompt and the seed vary.
            example_prompts_and_seeds = [
                (
                    "A serene landscape with mountains at sunset, vibrant orange and pink sky, reflected in a calm lake, photorealistic",
                    42,
                ),
                (
                    "A futuristic cityscape at night, neon lights, flying cars, cyberpunk style, highly detailed",
                    123,
                ),
                (
                    "A cute robot reading a book in a cozy library, warm lighting, digital art style",
                    456,
                ),
                (
                    "Macro photography of a dew drop on a leaf, morning light, sharp focus, bokeh background",
                    789,
                ),
            ]

            # NOTE(review): indentation was lost in the source this was
            # restored from; the examples are placed directly under their
            # heading in the output column - confirm the intended nesting.
            gr.Examples(
                examples=[
                    [example_text, None, 28, 4.0, example_seed]
                    for example_text, example_seed in example_prompts_and_seeds
                ],
                inputs=list(generation_controls),
                outputs=output_image,
                cache_examples=False,
            )

    # Run generation when the button is clicked.
    generate_btn.click(
        fn=generate_image,
        inputs=list(generation_controls),
        outputs=output_image,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860, without a share link.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)
|
|
|