# Hugging Face Space page residue (not code): app.py updated by frogleo, commit 7cf40f7 (verified)
import os
import torch

# Must be set before any CUDA work happens: force SDPA attention and cap the
# CUDA caching-allocator split size to reduce memory fragmentation.
os.environ["ATTN_IMPLEMENTATION"] = "sdpa"
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

import gc
import random
import warnings
import gradio as gr
import spaces
import logging
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from huggingface_hub import snapshot_download
from PIL import Image

# Enhanced logging configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)

# Pick the compute device once at import time; every model below is moved to it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Hugging Face token used to download the gated FLUX.1-dev weights.
huggingface_token = os.getenv("HF_TOKEN")

# Download the base model locally, skipping docs and git metadata.
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # Fix: the original pattern "*..gitattributes" (double dot) matched no
    # file at all; "*.gitattributes" actually excludes .gitattributes
    # (fnmatch "*" also matches the empty prefix).
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Upscaler ControlNet plus the FLUX pipeline, both in bfloat16.
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
).to(device)
pipe = FluxControlNetPipeline.from_pretrained(
    model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to(device)

# Upper bound for the seed slider and for randomized seeds.
MAX_SEED = 1000000
# Maximum number of output pixels (1 MP) the pipeline may produce.
MAX_PIXEL_BUDGET = 1024 * 1024
# -------------------- NSFW detection model loading --------------------
try:
    logger.info("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification

    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    logger.info("NSFW detector loaded successfully.")
except Exception as e:
    # Best-effort: if the detector cannot load, the app still runs and
    # downstream code skips NSFW checks (both globals stay None).
    logger.error(f"Failed to load NSFW detector: {e}")
    nsfw_model = None
    nsfw_processor = None
# -----------------------------------------------------------
class GenerationError(Exception):
    """Custom error type for failures inside the image-generation flow."""
def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Return True if *image* is classified as NSFW.

    Args:
        image: PIL image to classify.
        threshold: probability above which the image is flagged.

    Returns:
        True when the NSFW probability exceeds *threshold*; False when the
        image is judged safe or the detector failed to load at startup.
    """
    # Robustness fix: callers guard on these globals, but guard here too so
    # a direct call can never crash with a TypeError when the detector is
    # unavailable (both globals are None on load failure).
    if nsfw_model is None or nsfw_processor is None:
        return False
    inputs = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = nsfw_model(**inputs)
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
    # Index 1 is assumed to be the "nsfw" label of
    # Falconsai/nsfw_image_detection.
    # NOTE(review): nsfw_model.config.label2id["nsfw"] would be safer — confirm.
    nsfw_score = probs[0][1].item()
    return nsfw_score > threshold
# def process_input(input_image, upscale_factor, **kwargs):
# w, h = input_image.size
# w_original, h_original = w, h
# aspect_ratio = w / h
# was_resized = False
# if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
# warnings.warn(
# f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to ({int(aspect_ratio * MAX_PIXEL_BUDGET ** 0.5 // upscale_factor), int(MAX_PIXEL_BUDGET ** 0.5 // aspect_ratio // upscale_factor)}) pixels."
# )
# gr.Info(
# f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing input to ({int(aspect_ratio * MAX_PIXEL_BUDGET ** 0.5 // upscale_factor), int(MAX_PIXEL_BUDGET ** 0.5 // aspect_ratio // upscale_factor)}) pixels budget."
# )
# input_image = input_image.resize(
# (
# int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor),
# int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor),
# )
# )
# was_resized = True
# # resize to multiple of 8
# w, h = input_image.size
# w = w - w % 8
# h = h - h % 8
# return input_image.resize((w, h)), w_original, h_original, was_resized
def process_input(input_image, upscale_factor, **kwargs):
    """Prepare *input_image* so its upscaled output fits MAX_PIXEL_BUDGET.

    Args:
        input_image: PIL image to prepare.
        upscale_factor: integer upscale multiplier requested by the user.
        **kwargs: ignored; kept for interface compatibility.

    Returns:
        Tuple of (prepared image, original width, original height,
        was_resized flag). The prepared image's sides are multiples of 8
        (FLUX requirement) and its upscaled area stays within budget.
    """
    w, h = input_image.size
    w_original, h_original = w, h
    # 1. Total pixels the requested output would contain.
    total_output_pixels = w * h * (upscale_factor ** 2)
    was_resized = False
    # 2. If over budget, shrink the input uniformly. We want
    #    (w*k) * (h*k) * upscale_factor**2 == MAX_PIXEL_BUDGET, i.e.
    #    k = sqrt(MAX_PIXEL_BUDGET / (w * h * upscale_factor**2)).
    if total_output_pixels > MAX_PIXEL_BUDGET:
        scale_k = (MAX_PIXEL_BUDGET / total_output_pixels) ** 0.5
        new_w = int(w * scale_k)
        new_h = int(h * scale_k)
        input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
        was_resized = True
        logger.info(f"Resizing input from {w}x{h} to {new_w}x{new_h} to fit budget.")
        # gr.Info(f"Input resized to {new_w}x{new_h} due to memory limits.")
    # 3. Snap both sides down to multiples of 8 (FLUX model requirement).
    #    Fix: clamp to at least 8 so a tiny input (side < 8) can no longer
    #    collapse to a zero-size dimension, which would make the final
    #    resize produce an invalid/empty image.
    w, h = input_image.size
    w = max(8, (w // 8) * 8)
    h = max(8, (h // 8) * 8)
    return input_image.resize((w, h)), w_original, h_original, was_resized
# def process_input(input_image, upscale_factor):
# w, h = input_image.size
# w_original, h_original = w, h
# out_w = w * upscale_factor
# out_h = h * upscale_factor
# was_resized = False
# if out_w * out_h > MAX_PIXEL_BUDGET:
# scale = (MAX_PIXEL_BUDGET / (out_w * out_h)) ** 0.5
# new_out_w = int(out_w * scale)
# new_out_h = int(out_h * scale)
# # 反推输入尺寸
# new_in_w = max(8, new_out_w // upscale_factor)
# new_in_h = max(8, new_out_h // upscale_factor)
# # 对齐到 8 的倍数
# new_in_w -= new_in_w % 8
# new_in_h -= new_in_h % 8
# gr.Info(f"Output too large ({out_w}x{out_h}), resizing input to {new_in_w}x{new_in_h}, (target output {new_in_w * upscale_factor}x{new_in_h * upscale_factor})")
# input_image = input_image.resize((new_in_w, new_in_h))
# was_resized = True
# else:
# # 即便不 resize,也统一对齐到 8
# w -= w % 8
# h -= h % 8
# input_image = input_image.resize((w, h))
# return input_image, w_original, h_original, was_resized
# NOTE(review): module-level Progress object shared by infer/_infer below.
# Gradio normally expects gr.Progress() as a default parameter of the event
# handler so it binds to the active request — confirm this global pattern
# actually reports progress in the running Space.
progress=gr.Progress()
@spaces.GPU#(duration=42)
def _infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
):
    """Run the FLUX ControlNet upscaling pipeline on a single image.

    Returns:
        (image, info, seed) on success, or (None, info, None) on failure,
        where info is a dict with a "status" key ("success"/"failed") and,
        on failure, an "error" message. Also writes the result to
        "output.jpg" as a side effect.
    """

    def callback_fn(pipe, step, timestep, callback_kwargs):
        # Report per-step denoising progress back to the UI.
        print(f"[Step {step}] Timestep: {timestep}")
        progress_value = (step + 1.0) / num_inference_steps
        progress(progress_value, desc=f"Image upscaling, {step + 1}/{num_inference_steps} steps")
        return callback_kwargs

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    try:
        input_image, w_original, h_original, was_resized = process_input(
            input_image, upscale_factor
        )
        # Rescale with the upscale factor; the control image fixes the
        # output resolution.
        w, h = input_image.size
        control_image = input_image.resize((w * upscale_factor, h * upscale_factor))
        generator = torch.Generator().manual_seed(seed)
        image = pipe(
            prompt="",
            control_image=control_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=num_inference_steps,
            guidance_scale=3.5,
            height=control_image.size[1],
            width=control_image.size[0],
            generator=generator,
            callback_on_step_end=callback_fn,
        ).images[0]
        if was_resized:
            logger.info(f"Resizing output image to targeted {w_original * upscale_factor}x{h_original * upscale_factor} size.")
            # gr.Info(f"Resizing output image to targeted {w_original * upscale_factor}x{h_original * upscale_factor} size.")
        # NSFW check on the generated output.
        if nsfw_model and nsfw_processor:
            if detect_nsfw(image):
                msg = "Generated image contains NSFW content and cannot be displayed. Please upload a different image and try again."
                # Fix: raise the module's dedicated GenerationError (defined
                # above but previously never used) instead of a bare
                # Exception; both handlers below catch it identically, so
                # external behavior is unchanged.
                raise GenerationError(msg)
        # Resize back to the exact requested target size.
        image = image.resize((w_original * upscale_factor, h_original * upscale_factor))
        image.save("output.jpg")
        progress(1, desc="Complete")
        info = {
            "status": "success"
        }
        return image, info, seed
    except GenerationError as e:
        error_info = {
            "error": str(e),
            "status": "failed",
        }
        return None, error_info, None
    except Exception as e:
        error_info = {
            "error": str(e),
            "status": "failed",
        }
        return None, error_info, None
    finally:
        # Cleanup between requests. Fix: guard empty_cache() so CPU-only
        # runs never touch the CUDA allocator (which can raise on builds
        # without CUDA support).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
):
    """UI-facing wrapper: run the GPU worker and surface failures as gr.Error."""
    progress(0, desc="Starting")
    # Delegate the heavy lifting to the @spaces.GPU-decorated worker.
    image, info, out_seed = _infer(
        seed,
        randomize_seed,
        input_image,
        num_inference_steps,
        upscale_factor,
        controlnet_conditioning_scale,
    )
    # Turn a failure report into a Gradio error dialog for the user.
    if info["status"] == "failed":
        raise gr.Error(info["error"])
    return image, out_seed
examples = [
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
"An astronaut riding a green horse",
"A delicious ceviche cheesecake slice",
]
css = """
#col-container {
margin: 0 auto;
max-width: 1200px;
}
"""
title = "# AI Image Upscaler"
description = "Enhance your photos instantly with our high-performance AI. This tool restores details, removes noise, and increases resolution while maintaining stunning clarity."
note = "*Note: This space has daily usage limits. If you have reached the limit or need faster processing, please visit [AI Image Upscaler](https://www.aiimgupscaler.com) for unlimited generations and premium support.*"
# ---------------------------------------------------------------------------
# Gradio UI: input column (image + advanced settings) on the left, result on
# the right, plus cached examples and the Run button wiring.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css).queue() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(note)
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Input")
                input_im = gr.Image(label="Input Image", type="pil")
                with gr.Accordion("Advanced Settings", open=False):
                    # Diffusion steps: more steps = slower but higher quality.
                    num_inference_steps = gr.Slider(
                        label="Number of Inference Steps",
                        minimum=8,
                        maximum=50,
                        step=1,
                        value=28,
                    )
                    upscale_factor = gr.Slider(
                        label="Upscale Factor",
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=4,
                    )
                    # Strength of the ControlNet guidance from the low-res input.
                    controlnet_conditioning_scale = gr.Slider(
                        label="Controlnet Conditioning Scale",
                        minimum=0.1,
                        maximum=1.5,
                        step=0.1,
                        value=0.6,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                run_button = gr.Button("Run", variant="primary")
            with gr.Column():
                gr.Markdown("### Output")
                result = gr.Image(label="Result", show_label=False, interactive=False)
        # Pre-baked example inputs; "lazy" caches each result on first view.
        gr.Examples(
            examples=[
                [42, False, "examples/image_1.jpg", 28, 4, 0.6],
                [42, False, "examples/image_2.jpg", 28, 4, 0.6],
                [42, False, "examples/image_3.jpg", 28, 4, 0.6],
                [42, False, "examples/image_4.jpg", 28, 4, 0.6],
                [42, False, "examples/image_5.jpg", 28, 4, 0.6],
                [42, False, "examples/image_6.jpg", 28, 4, 0.6],
            ],
            inputs=[
                seed,
                randomize_seed,
                input_im,
                num_inference_steps,
                upscale_factor,
                controlnet_conditioning_scale,
            ],
            fn=infer,
            outputs=[result, seed],
            cache_examples="lazy",
        )
        # Main event wiring: Run button -> infer -> (result image, seed used).
        run_button.click(
            fn=infer,
            inputs=[
                seed,
                randomize_seed,
                input_im,
                num_inference_steps,
                upscale_factor,
                controlnet_conditioning_scale,
            ],
            outputs=[result, seed]
        )

if __name__ == "__main__":
    demo.launch()