Spaces:
Running on Zero
# Imports and process-level environment configuration.  The os.environ
# assignments sit up here, interleaved with the imports, so they are in
# place before any CUDA context / attention backend is initialised.
import os
import torch

# Presumably read by transformers/diffusers to select the SDPA attention
# implementation — TODO confirm against the installed library versions.
os.environ["ATTN_IMPLEMENTATION"] = "sdpa"
# Cap CUDA allocator split size to reduce fragmentation on long-running Spaces.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'

import gc
import random
import warnings
import gradio as gr
import spaces
import logging
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from huggingface_hub import snapshot_download
from PIL import Image
# Configure root logging once at import time; every module-level message
# below goes through this handler with a timestamped prefix.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt='%Y-%m-%d %H:%M:%S')
logger = logging.getLogger(__name__)

# Prefer the GPU when one is visible; all model placement below honours this.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Authentication token for gated models (FLUX.1-dev requires accepting the
# license); supplied via the HF_TOKEN environment variable / Space secret.
huggingface_token = os.getenv("HF_TOKEN")

# Download the base FLUX.1-dev weights into a local directory.  Markdown and
# .gitattributes files are skipped — they are not needed at inference time.
# BUG FIX: the original pattern "*..gitattributes" (double dot) matched
# nothing; "*.gitattributes" actually matches the repo's .gitattributes.
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)

# Upscaler ControlNet + FLUX pipeline, both in bfloat16 to halve memory use.
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler", torch_dtype=torch.bfloat16
).to(device)
pipe = FluxControlNetPipeline.from_pretrained(
    model_path, controlnet=controlnet, torch_dtype=torch.bfloat16
)
pipe.to(device)

# Upper bound for the seed slider, and the maximum number of output pixels
# the upscaler may produce (1 megapixel) — see process_input().
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 1024 * 1024
# -------------------- NSFW detection model loading --------------------
# Best-effort: if the classifier cannot be downloaded or loaded, the app
# keeps running and simply skips the NSFW check (both globals stay None,
# and _infer guards on them before calling detect_nsfw).
try:
    logger.info("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification

    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    logger.info("NSFW detector loaded successfully.")
except Exception as e:
    # Broad catch is deliberate: a missing detector should degrade
    # gracefully rather than break the whole Space.
    logger.error(f"Failed to load NSFW detector: {e}")
    nsfw_model = None
    nsfw_processor = None
# -----------------------------------------------------------
class GenerationError(Exception):
    """Raised when the upscaling pipeline fails in an expected, reportable way.

    The message is surfaced to the user by ``infer`` via ``gr.Error``.
    """
def detect_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Classify *image* with the Falconsai detector.

    Returns True when the NSFW probability exceeds *threshold*.  Callers
    must ensure the module-level ``nsfw_model``/``nsfw_processor`` loaded
    successfully (see the guard in ``_infer``).
    """
    batch = nsfw_processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = nsfw_model(**batch).logits
    # Index 1 is assumed to be the NSFW class for this checkpoint.
    score = torch.nn.functional.softmax(logits, dim=-1)[0][1].item()
    return score > threshold
| # def process_input(input_image, upscale_factor, **kwargs): | |
| # w, h = input_image.size | |
| # w_original, h_original = w, h | |
| # aspect_ratio = w / h | |
| # was_resized = False | |
| # if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET: | |
| # warnings.warn( | |
| # f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing to ({int(aspect_ratio * MAX_PIXEL_BUDGET ** 0.5 // upscale_factor), int(MAX_PIXEL_BUDGET ** 0.5 // aspect_ratio // upscale_factor)}) pixels." | |
| # ) | |
| # gr.Info( | |
| # f"Requested output image is too large ({w * upscale_factor}x{h * upscale_factor}). Resizing input to ({int(aspect_ratio * MAX_PIXEL_BUDGET ** 0.5 // upscale_factor), int(MAX_PIXEL_BUDGET ** 0.5 // aspect_ratio // upscale_factor)}) pixels budget." | |
| # ) | |
| # input_image = input_image.resize( | |
| # ( | |
| # int(aspect_ratio * MAX_PIXEL_BUDGET**0.5 // upscale_factor), | |
| # int(MAX_PIXEL_BUDGET**0.5 // aspect_ratio // upscale_factor), | |
| # ) | |
| # ) | |
| # was_resized = True | |
| # # resize to multiple of 8 | |
| # w, h = input_image.size | |
| # w = w - w % 8 | |
| # h = h - h % 8 | |
| # return input_image.resize((w, h)), w_original, h_original, was_resized | |
def process_input(input_image, upscale_factor, **kwargs):
    """Pre-size *input_image* so the upscaled output fits MAX_PIXEL_BUDGET.

    Args:
        input_image: PIL image (anything exposing ``.size`` / ``.resize``).
        upscale_factor: integer upscale multiplier (>= 1).

    Returns:
        ``(image, w_original, h_original, was_resized)`` where *image* has
        dimensions that are multiples of 8 (a FLUX requirement) and whose
        upscaled area does not exceed ``MAX_PIXEL_BUDGET``.
    """
    w, h = input_image.size
    w_original, h_original = w, h

    # Total pixels the pipeline would have to generate at this factor.
    total_output_pixels = w * h * (upscale_factor ** 2)
    was_resized = False

    if total_output_pixels > MAX_PIXEL_BUDGET:
        # Uniform scale k such that (w*k) * (h*k) * upscale_factor**2 == budget.
        scale_k = (MAX_PIXEL_BUDGET / total_output_pixels) ** 0.5
        # Clamp to 8 px so the multiple-of-8 alignment below cannot hit zero.
        new_w = max(8, int(w * scale_k))
        new_h = max(8, int(h * scale_k))
        input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
        was_resized = True
        logger.info(f"Resizing input from {w}x{h} to {new_w}x{new_h} to fit budget.")

    # FLUX requires dimensions divisible by 8.  BUG FIX: clamp to 8 so that
    # tiny inputs (< 8 px on a side) no longer collapse to a zero dimension,
    # which previously crashed the final resize.
    w, h = input_image.size
    w = max(8, (w // 8) * 8)
    h = max(8, (h // 8) * 8)
    return input_image.resize((w, h)), w_original, h_original, was_resized
| # def process_input(input_image, upscale_factor): | |
| # w, h = input_image.size | |
| # w_original, h_original = w, h | |
| # out_w = w * upscale_factor | |
| # out_h = h * upscale_factor | |
| # was_resized = False | |
| # if out_w * out_h > MAX_PIXEL_BUDGET: | |
| # scale = (MAX_PIXEL_BUDGET / (out_w * out_h)) ** 0.5 | |
| # new_out_w = int(out_w * scale) | |
| # new_out_h = int(out_h * scale) | |
| # # 反推输入尺寸 | |
| # new_in_w = max(8, new_out_w // upscale_factor) | |
| # new_in_h = max(8, new_out_h // upscale_factor) | |
| # # 对齐到 8 的倍数 | |
| # new_in_w -= new_in_w % 8 | |
| # new_in_h -= new_in_h % 8 | |
| # gr.Info(f"Output too large ({out_w}x{out_h}), resizing input to {new_in_w}x{new_in_h}, (target output {new_in_w * upscale_factor}x{new_in_h * upscale_factor})") | |
| # input_image = input_image.resize((new_in_w, new_in_h)) | |
| # was_resized = True | |
| # else: | |
| # # 即便不 resize,也统一对齐到 8 | |
| # w -= w % 8 | |
| # h -= h % 8 | |
| # input_image = input_image.resize((w, h)) | |
| # return input_image, w_original, h_original, was_resized | |
# NOTE(review): gr.Progress is normally injected per-event by declaring
# `progress=gr.Progress()` as a *parameter* of the event handler; a single
# module-level instance shared across requests may misbehave under
# concurrency — TODO confirm against the Gradio version in use.
progress = gr.Progress()

# NOTE(review): "(duration=42)" looks like a remnant of a stripped
# @spaces.GPU(duration=42) decorator; without it a ZeroGPU Space may not
# request a GPU for this function — confirm before deploying.
#(duration=42)
def _infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
):
    """Run the FLUX ControlNet upscaling pipeline.

    Returns ``(image, info, seed)`` where *info* is a status dict.  On any
    failure this returns ``(None, {"error": ..., "status": "failed"}, None)``
    instead of raising, so the UI wrapper ``infer`` decides how to report.
    """
    def callback_fn(pipe, step, timestep, callback_kwargs):
        # Diffusers step-end callback: forward per-step progress to the UI.
        print(f"[Step {step}] Timestep: {timestep}")
        progress_value = (step+1.0)/num_inference_steps
        progress(progress_value, desc=f"Image upscaling, {step + 1}/{num_inference_steps} steps")
        return callback_kwargs

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    try:
        # Shrink the input (if needed) so the upscaled output stays inside
        # MAX_PIXEL_BUDGET with dimensions that are multiples of 8.
        input_image, w_original, h_original, was_resized = process_input(
            input_image, upscale_factor
        )
        # rescale with upscale factor
        w, h = input_image.size
        control_image = input_image.resize((w * upscale_factor, h * upscale_factor))
        # CPU generator seeded for reproducible sampling at a fixed seed.
        generator = torch.Generator().manual_seed(seed)
        image = pipe(
            prompt="",
            control_image=control_image,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            num_inference_steps=num_inference_steps,
            guidance_scale=3.5,
            height=control_image.size[1],
            width=control_image.size[0],
            generator=generator,
            callback_on_step_end=callback_fn,
        ).images[0]
        if was_resized:
            logger.info(f"Resizing output image to targeted {w_original * upscale_factor}x{h_original * upscale_factor} size.")
            # gr.Info(f"Resizing output image to targeted {w_original * upscale_factor}x{h_original * upscale_factor} size.")
        # NSFW check — skipped entirely when the detector failed to load.
        if nsfw_model and nsfw_processor:
            if detect_nsfw(image):
                msg = "Generated image contains NSFW content and cannot be displayed. Please upload a different image and try again."
                raise Exception(msg)
        # resize to target desired size
        image = image.resize((w_original * upscale_factor, h_original * upscale_factor))
        image.save("output.jpg")
        progress(1, desc="Complete")
        info = {
            "status": "success"
        }
        return image, info, seed
    except GenerationError as e:
        error_info = {
            "error": str(e),
            "status": "failed",
        }
        return None, error_info, None
    except Exception as e:
        # Broad catch keeps the Gradio worker alive; the message is shown
        # to the user by infer() via gr.Error.
        error_info = {
            "error": str(e),
            "status": "failed",
        }
        return None, error_info, None
    finally:
        # Release GPU memory between requests.
        torch.cuda.empty_cache()
        gc.collect()
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
):
    """Gradio event handler: delegate to ``_infer`` and surface failures.

    Returns ``(image, seed)`` on success; raises ``gr.Error`` carrying the
    failure message otherwise.
    """
    progress(0, desc="Starting")
    result_image, run_info, used_seed = _infer(
        seed,
        randomize_seed,
        input_image,
        num_inference_steps,
        upscale_factor,
        controlnet_conditioning_scale,
    )
    if run_info["status"] == "failed":
        # Re-raise as a Gradio error so the UI shows it to the user.
        raise gr.Error(run_info["error"])
    return result_image, used_seed
# NOTE(review): these text prompts appear unused — the gr.Examples block in
# the UI supplies its own image-based example rows.  Kept for parity with
# the original; safe-to-remove candidate.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Page-level CSS: center the layout column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 1200px;
}
"""

# Static UI copy rendered as Markdown at the top of the page.
title = "# AI Image Upscaler"
description = "Enhance your photos instantly with our high-performance AI. This tool restores details, removes noise, and increases resolution while maintaining stunning clarity."
note = "*Note: This space has daily usage limits. If you have reached the limit or need faster processing, please visit [AI Image Upscaler](https://www.aiimgupscaler.com) for unlimited generations and premium support.*"
# ---------------------------- UI definition ----------------------------
# NOTE(review): the nesting below is reconstructed from a flattened source;
# exact component placement (e.g. whether the seed controls sit inside the
# accordion) affects layout only, not wiring or behavior.
with gr.Blocks(css=css).queue() as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(note)
        with gr.Row():
            # Left column: input image + advanced generation settings.
            with gr.Column():
                gr.Markdown("### Input")
                input_im = gr.Image(label="Input Image", type="pil")
                with gr.Accordion("Advanced Settings", open=False):
                    num_inference_steps = gr.Slider(
                        label="Number of Inference Steps",
                        minimum=8,
                        maximum=50,
                        step=1,
                        value=28,
                    )
                    upscale_factor = gr.Slider(
                        label="Upscale Factor",
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=4,
                    )
                    controlnet_conditioning_scale = gr.Slider(
                        label="Controlnet Conditioning Scale",
                        minimum=0.1,
                        maximum=1.5,
                        step=0.1,
                        value=0.6,
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                run_button = gr.Button("Run", variant="primary")
            # Right column: upscaled result.
            with gr.Column():
                gr.Markdown("### Output")
                result = gr.Image(label="Result", show_label=False, interactive=False)
        # Pre-baked example rows; "lazy" caching computes outputs on demand.
        gr.Examples(
            examples=[
                [42, False, "examples/image_1.jpg", 28, 4, 0.6],
                [42, False, "examples/image_2.jpg", 28, 4, 0.6],
                [42, False, "examples/image_3.jpg", 28, 4, 0.6],
                [42, False, "examples/image_4.jpg", 28, 4, 0.6],
                [42, False, "examples/image_5.jpg", 28, 4, 0.6],
                [42, False, "examples/image_6.jpg", 28, 4, 0.6],
            ],
            inputs=[
                seed,
                randomize_seed,
                input_im,
                num_inference_steps,
                upscale_factor,
                controlnet_conditioning_scale,
            ],
            fn=infer,
            outputs=[result, seed],
            cache_examples="lazy",
        )
    # Wire the Run button to the inference handler.
    run_button.click(
        fn=infer,
        inputs=[
            seed,
            randomize_seed,
            input_im,
            num_inference_steps,
            upscale_factor,
            controlnet_conditioning_scale,
        ],
        outputs=[result, seed]
    )
if __name__ == "__main__":
    # Launch the Gradio server when executed directly (Space entrypoint).
    demo.launch()