|
|
import spaces  # NOTE: must be imported before torch/diffusers so ZeroGPU can patch CUDA init

import io
import os
import random
import traceback

import gradio as gr
import requests
import torch
from diffusers import Flux2Pipeline
from huggingface_hub import get_token
from PIL import Image
|
|
|
|
|
|
|
|
# Pre-quantized FLUX.2-dev checkpoint; the repo name indicates bitsandbytes
# 4-bit quantization for lower VRAM use — TODO confirm against the model card.
repo_id = "diffusers/FLUX.2-dev-bnb-4bit"

# Compute dtype for the non-quantized layers and activations.
torch_dtype = torch.bfloat16

print("Starting Flux2 Image Generator...")

# Global pipeline cache; populated lazily by load_pipeline() on first call.
pipe = None
|
|
|
|
|
def load_pipeline():
    """Lazily load and cache the global Flux2 pipeline.

    The pipeline is created once and stored in the module-level ``pipe``;
    subsequent calls return the cached instance.

    Returns:
        Flux2Pipeline: the loaded (or cached) pipeline.

    Raises:
        Exception: re-raised unchanged if ``from_pretrained`` fails.
    """
    global pipe
    if pipe is None:
        print("Loading Flux2 pipeline...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        try:
            pipe = Flux2Pipeline.from_pretrained(
                repo_id,
                # Prompts are embedded by the remote text-encoder service,
                # so the local text encoder is skipped to save memory.
                text_encoder=None,
                torch_dtype=torch_dtype,
                # Bug fix: was hard-coded to "cuda", which crashed on
                # CPU-only hosts even though `device` was computed above.
                device_map=device,
            )
            print("Pipeline loaded successfully!")
        except Exception as e:
            print(f"Error loading pipeline: {e}")
            raise
    return pipe
|
|
|
|
|
def remote_text_encoder(prompts):
    """Encode prompts via the hosted FLUX.2 text-encoder endpoint.

    Args:
        prompts: prompt string (or batch of strings) to embed.

    Returns:
        torch.Tensor: prompt embeddings moved to the local device
        ("cuda" when available, else "cpu").

    Raises:
        Exception: wraps any auth, network, or deserialization failure,
        with the original exception chained as ``__cause__``.
    """
    try:
        token = get_token()
        if not token:
            raise ValueError("HuggingFace token not found. Please login using 'huggingface-cli login'")

        response = requests.post(
            "https://remote-text-encoder-flux-2.huggingface.co/predict",
            json={"prompt": prompts},
            headers={
                "Authorization": f"Bearer {token}",
                "Content-Type": "application/json",
            },
            timeout=60,
        )
        response.raise_for_status()
        # NOTE(security): torch.load unpickles the response body. The endpoint
        # is trusted (Hugging Face), but revisit this if the URL ever changes.
        prompt_embeds = torch.load(io.BytesIO(response.content))

        device = "cuda" if torch.cuda.is_available() else "cpu"
        return prompt_embeds.to(device)
    except Exception as e:
        # Fix: chain the original exception so the full traceback survives.
        raise Exception(f"Failed to encode prompt: {str(e)}") from e
|
|
|
|
|
def get_duration(prompt: str, input_image: Image.Image = None, num_inference_steps: int = 28, guidance_scale: float = 4.0, seed: int = 42, progress=None):
    """Estimate the GPU-seconds budget for one generation call.

    Each denoising step is budgeted ~1s, rising to ~1.7s when a reference
    image is supplied; a 10s overhead is added and the result is floored
    at 65s. The signature mirrors generate_image() because spaces.GPU
    invokes this with the same arguments — unused parameters are
    intentional, not dead code.
    """
    per_step_cost = 1.0 if input_image is None else 1.7
    estimated = num_inference_steps * per_step_cost + 10
    return max(65, estimated)
|
|
|
|
|
@spaces.GPU(duration=get_duration)
def generate_image(
    prompt: str,
    input_image: Image.Image = None,
    num_inference_steps: int = 28,
    guidance_scale: float = 4.0,
    seed: int = 42,
    progress=gr.Progress()
):
    """
    Generate an image using Flux2 based on text prompt and optional input image.

    Args:
        prompt: Text description of the desired image
        input_image: Optional input image for image-to-image generation
        num_inference_steps: Number of denoising steps (higher = better quality but slower)
        guidance_scale: How closely to follow the prompt (higher = more strict)
        seed: Random seed for reproducibility (-1 for random)
        progress: Gradio progress tracker (injected by Gradio)

    Returns:
        The generated PIL image.

    Raises:
        gr.Error: user-facing error for an empty prompt or any generation failure.
    """
    print("=== Starting generation ===")  # fix: was an f-string with no placeholders
    print(f"Prompt: {prompt[:100]}...")
    print(f"CUDA available: {torch.cuda.is_available()}")

    # Validate user input before doing any expensive work.
    if not prompt or prompt.strip() == "":
        raise gr.Error("Please enter a prompt!")

    progress(0, desc="Loading model...")

    try:
        print("Loading pipeline...")
        pipeline = load_pipeline()
        print("Pipeline loaded successfully")

        progress(0.1, desc="Encoding prompt...")
        print("Encoding prompt...")

        # Prompt embedding happens on a remote service; surface failures as a
        # user-facing gr.Error with the original exception chained.
        try:
            prompt_embeds = remote_text_encoder(prompt)
            print(f"Prompt embeds shape: {prompt_embeds.shape}")
        except Exception as e:
            print(f"Error encoding prompt: {str(e)}")
            raise gr.Error(f"Failed to encode prompt. Please check your HuggingFace token. Error: {str(e)}") from e

        progress(0.3, desc="Generating image...")

        generator_device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Generator device: {generator_device}")

        # -1 means "pick a random seed"; resolve it so the actual seed is logged
        # and can be reused to reproduce the result.
        if seed == -1:
            seed = random.randint(0, 2**32 - 1)

        print(f"Using seed: {seed}")
        generator = torch.Generator(device=generator_device).manual_seed(int(seed))

        pipe_kwargs = {
            "prompt_embeds": prompt_embeds,
            "generator": generator,
            "num_inference_steps": int(num_inference_steps),
            "guidance_scale": float(guidance_scale),
        }

        # Optional image-to-image conditioning.
        if input_image is not None:
            pipe_kwargs["image"] = input_image
            progress(0.4, desc="Processing input image...")
            print("Processing with input image")

        print(f"Starting generation with {num_inference_steps} steps...")

        # inference_mode disables autograd bookkeeping for a speed/memory win.
        with torch.inference_mode():
            result = pipeline(**pipe_kwargs)
            image = result.images[0]

        print("Generation complete!")
        progress(1.0, desc="Done!")

        return image

    except gr.Error:
        # User-facing errors pass through unchanged.
        raise
    except Exception as e:
        error_msg = f"Error generating image: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)

        # Map common failure modes to friendlier messages; chain the cause
        # (fix: previously re-raised without `from e`, losing the traceback link).
        if "CUDA" in str(e):
            raise gr.Error(f"GPU Error: {str(e)}. The model requires GPU to run.") from e
        elif "token" in str(e).lower() or "401" in str(e):
            raise gr.Error("Authentication failed. Please ensure your HuggingFace token is set correctly.") from e
        elif "timeout" in str(e).lower():
            raise gr.Error("Request timed out. Please try again.") from e
        else:
            raise gr.Error(f"Error: {str(e)}") from e
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Two-column layout: prompt/image/parameters on the left, generated image on
# the right. `demo` is launched from the __main__ guard at the bottom.
with gr.Blocks(
    title="Flux2 Image Generator",
) as demo:
    gr.Markdown(
        """
        # 🎨 Flux2 Image Generator
        Generate stunning images using **FLUX.2-dev** with 4-bit quantization for efficient inference.

        Supports both **text-to-image** and **image-to-image** generation.

        ⚡ **Powered by Hugging Face Zero GPU** - Automatic GPU allocation on demand!
        """
    )

    with gr.Row():
        # Left column: user inputs and generation parameters.
        with gr.Column(scale=1):
            gr.Markdown("### 📝 Input")

            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Describe the image you want to generate...",
                lines=4,
                value="Realistic macro photograph of a hermit crab using a soda can as its shell, partially emerging from the can, captured with sharp detail and natural colors, on a sunlit beach with soft shadows and a shallow depth of field, with blurred ocean waves in the background."
            )

            # Optional reference image; when set, generate_image switches to
            # image-to-image mode.
            image_input = gr.Image(
                label="Input Image (Optional)",
                type="pil",
                sources=["upload", "clipboard"],
                height=300
            )

            gr.Markdown("### ⚙️ Parameters")

            with gr.Row():
                num_steps = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=28,
                    step=1,
                    label="Inference Steps",
                    info="More steps = better quality but slower"
                )

                guidance = gr.Slider(
                    minimum=1.0,
                    maximum=15.0,
                    value=4.0,
                    step=0.5,
                    label="Guidance Scale",
                    info="How closely to follow the prompt"
                )

            # precision=0 forces an integer seed; -1 requests a random one.
            seed_input = gr.Number(
                label="Seed",
                value=42,
                precision=0,
                info="Use -1 for random seed"
            )

            generate_btn = gr.Button(
                "🚀 Generate Image",
                variant="primary",
                size="lg",
            )

            gr.Markdown(
                """
                ### 💡 Tips
                - **Text-to-Image**: Just enter a prompt and click generate
                - **Image-to-Image**: Upload an image and describe the changes
                - Start with 28 steps for a good balance of quality and speed
                - Higher guidance scale follows your prompt more strictly
                - Use the same seed to reproduce results
                - First generation may take longer as the model loads
                """
            )

        # Right column: generated output and example prompts.
        with gr.Column(scale=1):
            gr.Markdown("### 🖼️ Output")

            output_image = gr.Image(
                label="Generated Image",
                type="pil",
                height=600
            )

            gr.Markdown(
                """
                ### 📊 Examples
                Try these prompts for inspiration!
                """
            )

            # Each example row matches the `inputs` order below:
            # [prompt, input_image, steps, guidance, seed].
            gr.Examples(
                examples=[
                    [
                        "A serene landscape with mountains at sunset, vibrant orange and pink sky, reflected in a calm lake, photorealistic",
                        None,
                        28,
                        4.0,
                        42
                    ],
                    [
                        "A futuristic cityscape at night, neon lights, flying cars, cyberpunk style, highly detailed",
                        None,
                        28,
                        4.0,
                        123
                    ],
                    [
                        "A cute robot reading a book in a cozy library, warm lighting, digital art style",
                        None,
                        28,
                        4.0,
                        456
                    ],
                    [
                        "Macro photography of a dew drop on a leaf, morning light, sharp focus, bokeh background",
                        None,
                        28,
                        4.0,
                        789
                    ],
                ],
                inputs=[prompt_input, image_input, num_steps, guidance, seed_input],
                outputs=output_image,
                # Caching would run the GPU model at build time; keep it off.
                cache_examples=False,
            )

    # Wire the button to the generation function.
    generate_btn.click(
        fn=generate_image,
        inputs=[prompt_input, image_input, num_steps, guidance, seed_input],
        outputs=output_image,
    )
|
|
|
|
|
if __name__ == "__main__":
    print("Launching Gradio interface...")
    # queue() bounds pending requests (max 20) so the shared GPU backend
    # is not oversubscribed; launch() starts the Gradio server.
    demo.queue(max_size=20).launch()