# Source: frogleo's Hugging Face Space — app.py (commit 89efdd8)
import io
import gradio as gr
import numpy as np
import random
import spaces
import torch
from diffusers import Flux2KleinPipeline
from PIL import Image
import base64
import gc
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# Model repository IDs for the 9B variants.
REPO_ID_REGULAR = "black-forest-labs/FLUX.2-klein-base-9B"
REPO_ID_DISTILLED = "black-forest-labs/FLUX.2-klein-9B"

# Load both 9B models once at startup.
print("Loading 9B Regular model...")
pipe_regular = Flux2KleinPipeline.from_pretrained(REPO_ID_REGULAR, torch_dtype=dtype)
# Use the detected device instead of hard-coding "cuda" so the app can still
# start on a CPU-only machine (the fallback computed above was never used).
pipe_regular.to(device)

print("Loading 9B Distilled model...")
pipe_distilled = Flux2KleinPipeline.from_pretrained(REPO_ID_DISTILLED, torch_dtype=dtype)
pipe_distilled.to(device)

# Mode label (as shown in the UI radio) -> pipeline instance.
pipes = {
    "Distilled (4 steps)": pipe_distilled,
    "Base (50 steps)": pipe_regular,
}

# Per-mode defaults pushed into the sliders when the user switches modes.
DEFAULT_STEPS = {
    "Distilled (4 steps)": 4,
    "Base (50 steps)": 50,
}
DEFAULT_CFG = {
    "Distilled (4 steps)": 1.0,
    "Base (50 steps)": 4.0,
}
# -------------------- NSFW detection model loading --------------------
try:
    print("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification
    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    print("NSFW detector loaded successfully.")
except Exception as e:
    # Fail open: generation still works without the detector; callers check
    # these globals for None before invoking it.
    print(f"Failed to load NSFW detector: {e}")
    nsfw_model = None
    nsfw_processor = None
# -----------------------------------------------------------
class GenerationError(Exception):
    """Raised when image generation fails for a user-reportable reason."""
def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Classify an image with the NSFW detector.

    Args:
        image: PIL image to classify.
        threshold: Probability above which the image is flagged.

    Returns:
        bool: True if the NSFW probability exceeds ``threshold``.
        False (fail-open, matching the callers' behavior) when the
        detector failed to load.
    """
    # Defensive guard: callers check these globals before calling, but a
    # direct call (e.g. via the MCP API) must not crash with AttributeError
    # when the detector is unavailable.
    if nsfw_model is None or nsfw_processor is None:
        return False
    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = nsfw_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Label index 1 = NSFW — presumed from the Falconsai model card; verify.
    nsfw_score = probs[0][1].item()
    return nsfw_score > threshold
def image_to_data_uri(img):
    """
    Serialize a PIL Image as a base64 PNG data URI.

    Args:
        img: The PIL Image to convert.
    Returns:
        str: A data URI string containing the base64-encoded PNG image.
    """
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return "data:image/png;base64," + encoded
def get_dimensions_from_image(image_list, width, height, match_input_ratio):
    """Derive generation dimensions from the first input image's aspect ratio.

    Args:
        image_list: Gallery value — a list of (image, caption) tuples, or None.
        width: Fallback width used when the ratio is not matched.
        height: Fallback height used when the ratio is not matched.
        match_input_ratio: When True (and images exist), compute dimensions
            from the first image instead of the sliders.

    Returns:
        tuple: (matched, gen_width, gen_height, source_width, source_height).
        gen_width/gen_height are multiples of 8 clamped to [256, 1024];
        source_* are the pre-rounding target dimensions used to resize the
        final output back to the exact input aspect ratio.
    """
    if not image_list or not match_input_ratio:
        return False, width, height, width, height  # fall back to slider values

    # Gallery items are (image, caption) tuples; only the first image matters.
    img_width, img_height = image_list[0][0].size
    aspect_ratio = img_width / img_height

    # Pin the long side to 1024 and scale the short side proportionally.
    if aspect_ratio >= 1:  # landscape or square
        source_width = 1024
        source_height = int(1024 / aspect_ratio)
    else:  # portrait
        source_height = 1024
        source_width = int(1024 * aspect_ratio)

    # The diffusion pipeline requires dimensions divisible by 8.
    new_width = round(source_width / 8) * 8
    new_height = round(source_height / 8) * 8

    # Clamp to the supported range (minimum 256, maximum 1024).
    new_width = max(256, min(1024, new_width))
    new_height = max(256, min(1024, new_height))

    return True, new_width, new_height, source_width, source_height
def update_steps_from_mode(mode_choice):
    """
    Look up the recommended inference settings for the selected mode.

    Args:
        mode_choice (str): Either "Distilled (4 steps)" or "Base (50 steps)".
    Returns:
        tuple: (num_inference_steps, guidance_scale) defaults for that mode.
    """
    steps = DEFAULT_STEPS[mode_choice]
    cfg = DEFAULT_CFG[mode_choice]
    return steps, cfg
# NOTE(review): gr.Progress() is normally injected by Gradio as a per-event
# parameter; creating one at module level and calling it from the handler is
# unconventional — confirm progress updates actually reach the UI.
progress=gr.Progress()
@spaces.GPU(duration=85)
def _infer(
    prompt: str,
    input_images=None,
    mode_choice: str = "Distilled (4 steps)",
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 4.0,
    match_input_ratio: bool = False,
):
    """Generate an image with the selected FLUX.2 pipeline on the GPU worker.

    Args:
        prompt: Text prompt.
        input_images: Optional Gallery value — list of (image, caption) tuples.
        mode_choice: Key into the module-level ``pipes`` dict.
        seed: RNG seed (replaced when ``randomize_seed`` is True).
        randomize_seed: Pick a random seed in [0, MAX_SEED].
        width: Fallback output width when the input ratio is not matched.
        height: Fallback output height when the input ratio is not matched.
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance scale.
        match_input_ratio: Derive output size from the first input image.

    Returns:
        tuple: (image, seed, {"status": "success"}) on success, or
        (None, None, {"status": "failed", "error": msg}) on any error.
    """
    if progress:
        progress(0, desc="Starting generation...")

    def callback_fn(pipe, step, timestep, callback_kwargs):
        # Stream per-step progress back to the UI.
        print(f"[Step {step}] Timestep: {timestep}")
        progress_value = (step + 1.0) / num_inference_steps
        progress(progress_value, desc=f"Image generating, {step + 1}/{num_inference_steps} steps")
        return callback_kwargs

    # Coerce string inputs to proper types (MCP clients may send strings).
    if isinstance(seed, str):
        seed = int(seed)
    if isinstance(randomize_seed, str):
        randomize_seed = randomize_seed.lower() == "true"
    if isinstance(width, str):
        width = int(width)
    if isinstance(height, str):
        height = int(height)
    if isinstance(num_inference_steps, str):
        num_inference_steps = int(num_inference_steps)
    if isinstance(guidance_scale, str):
        guidance_scale = float(guidance_scale)
    if isinstance(match_input_ratio, str):
        match_input_ratio = match_input_ratio.lower() == "true"

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    try:
        # Select the pipeline for the requested mode.
        pipe = pipes[mode_choice]

        # Unpack the Gallery value into a plain list of PIL images.
        image_list = None
        if input_images:
            image_list = [item[0] for item in input_images]

        matched, new_width, new_height, source_width, source_height = get_dimensions_from_image(
            input_images, width, height, match_input_ratio
        )

        # Screen every input image before spending GPU time on generation.
        if image_list:
            for img in image_list:
                if nsfw_model and nsfw_processor and detect_nsfw(img):
                    # Use the dedicated exception type instead of bare
                    # Exception so callers can distinguish policy failures.
                    raise GenerationError("The input image contains NSFW content and cannot be generated. Please modify the input image or prompt and try again.")

        final_prompt = prompt
        generator = torch.Generator(device=device).manual_seed(seed)

        pipe_kwargs = {
            "prompt": final_prompt,
            "height": new_height,
            "width": new_width,
            "num_inference_steps": num_inference_steps,
            "guidance_scale": guidance_scale,
            "generator": generator,
            "callback_on_step_end": callback_fn,
        }
        # Only pass reference images when the user supplied some.
        if image_list is not None:
            pipe_kwargs["image"] = image_list

        image = pipe(**pipe_kwargs).images[0]

        # Post-process: resize back to the pre-rounding target size so the
        # output matches the input aspect ratio exactly (LANCZOS for quality).
        if matched:
            image = image.resize((source_width, source_height), resample=Image.LANCZOS)

        # Screen the generated image as well.
        if nsfw_model and nsfw_processor and detect_nsfw(image):
            raise GenerationError("Generated image contains NSFW content and cannot be displayed. Please modify the input image or prompt and try again.")

        progress(1, desc="Complete")
        return image, seed, {"status": "success"}
    except Exception as e:
        # GenerationError and unexpected errors were handled identically by
        # the original duplicate branches; report both the same way and let
        # the caller surface the message as a gr.Error.
        return None, None, {"error": str(e), "status": "failed"}
    finally:
        # Free GPU memory between requests; guard the CUDA call so this
        # cleanup does not itself raise on CPU-only machines.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
def infer(
    prompt: str,
    input_images=None,
    mode_choice: str = "Distilled (4 steps)",
    seed: int = 42,
    randomize_seed: bool = False,
    width: int = 1024,
    height: int = 1024,
    num_inference_steps: int = 4,
    guidance_scale: float = 4.0,
    match_input_ratio: bool = False,
):
    """UI-facing wrapper around the GPU worker.

    Delegates to ``_infer`` and converts any reported failure into a
    ``gr.Error`` so Gradio displays it to the user; on success returns
    the generated image and the seed that was actually used.
    """
    image, used_seed, info = _infer(
        prompt,
        input_images,
        mode_choice,
        seed,
        randomize_seed,
        width,
        height,
        num_inference_steps,
        guidance_scale,
        match_input_ratio,
    )
    if info["status"] == "failed":
        raise gr.Error(info["error"])
    return image, used_seed
# Text-only example prompts shown under the output panel.
examples = [
    ["Create a vase on a table in living room, the color of the vase is a gradient of color, starting with #02eb3c color and finishing with #edfa3c. The flowers inside the vase have the color #ff0088"],
    ["Photorealistic infographic showing the complete Berlin TV Tower (Fernsehturm) from ground base to antenna tip, full vertical view with entire structure visible including concrete shaft, metallic sphere, and antenna spire. Slight upward perspective angle looking up toward the iconic sphere, perfectly centered on clean white background. Left side labels with thin horizontal connector lines: the text '368m' in extra large bold dark grey numerals (#2D3748) positioned at exactly the antenna tip with 'TOTAL HEIGHT' in small caps below. The text '207m' in extra large bold with 'TELECAFÉ' in small caps below, with connector line touching the sphere precisely at the window level. Right side label with horizontal connector line touching the sphere's equator: the text '32m' in extra large bold dark grey numerals with 'SPHERE DIAMETER' in small caps below. Bottom section arranged in three balanced columns: Left - Large text '986' in extra bold dark grey with 'STEPS' in caps below. Center - 'BERLIN TV TOWER' in bold caps with 'FERNSEHTURM' in lighter weight below. Right - 'INAUGURATED' in bold caps with 'OCTOBER 3, 1969' below. All typography in modern sans-serif font (such as Inter or Helvetica), color #2D3748, clean minimal technical diagram style. Horizontal connector lines are thin, precise, and clearly visible, touching the tower structure at exact corresponding measurement points. Professional architectural elevation drawing aesthetic with dynamic low angle perspective creating sense of height and grandeur, poster-ready infographic design with perfect visual hierarchy."],
    ["Soaking wet capybara taking shelter under a banana leaf in the rainy jungle, close up photo"],
    ["A kawaii die-cut sticker of a chubby orange cat, featuring big sparkly eyes and a happy smile with paws raised in greeting and a heart-shaped pink nose. The design should have smooth rounded lines with black outlines and soft gradient shading with pink cheeks."],
]
# Multi-image example: prompt plus reference image files bundled with the Space.
examples_images = [
    ["The person from image 1 is petting the cat from image 2, the bird from image 3 is next to them", ["woman1.webp", "cat_window.webp", "bird.webp"]]
]
css = """
#col-container {
margin: 0 auto;
max-width: 1200px;
}
"""
title = "# AI Image Generator from Image"
description = "Transform your photos into stunning art with our AI Image Generator! Simply upload an image to remix, restyle, and reimagine it instantly. This Space offers a limited free demo. If you reach your usage limit or want faster generation speeds, please visit our full platform at [AI Image Generator from Image](https://www.image2image.ai/) to continue creating."
# ---------------------------------------------------------------------------
# Gradio UI wiring. Declarative layout: component creation order defines the
# page; event bindings come after the components.
# NOTE(review): the paste this file came from lost its indentation; the
# nesting below is reconstructed from the context-manager structure.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(title)
        gr.Markdown(description)
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input")
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=2,
                        placeholder="Enter your prompt",
                        container=False,
                        scale=3
                    )
                    run_button = gr.Button("Run", scale=1)
                with gr.Accordion("Input image(s) (optional)", open=True):
                    input_images = gr.Gallery(
                        label="Input Image(s)",
                        type="pil",
                        columns=3,
                        rows=1,
                    )
                mode_choice = gr.Radio(
                    label="Mode",
                    choices=["Distilled (4 steps)", "Base (50 steps)"],
                    value="Distilled (4 steps)",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Group():
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                minimum=256,
                                maximum=MAX_IMAGE_SIZE,
                                step=8,
                                value=1024,
                            )
                            height = gr.Slider(
                                label="Height",
                                minimum=256,
                                maximum=MAX_IMAGE_SIZE,
                                step=8,
                                value=1024,
                            )
                        match_input = gr.Checkbox(label="Match input ratio", value=False)
                    with gr.Row():
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=4,
                        )
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=1.0,
                        )
            with gr.Column():
                gr.Markdown("### Output")
                result = gr.Image(label="Result", show_label=False)
        # Cached example galleries: text-only prompts, then prompt + images.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt],
            outputs=[result, seed],
            cache_examples=True,
            cache_mode="lazy"
        )
        gr.Examples(
            examples=examples_images,
            fn=infer,
            inputs=[prompt, input_images],
            outputs=[result, seed],
            cache_examples=True,
            cache_mode="lazy"
        )
    # Auto-update steps when mode changes
    mode_choice.change(
        fn=update_steps_from_mode,
        inputs=[mode_choice],
        outputs=[num_inference_steps, guidance_scale]
    )
    # Bind both the Run button and Enter-in-prompt to the same handler.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, input_images, mode_choice, seed, randomize_seed, width, height, num_inference_steps, guidance_scale, match_input],
        outputs=[result, seed],
        api_name="generate"  # Explicit API name for MCP tool
    )
# Launch with MCP server enabled
demo.launch(mcp_server=True)