"""RealEdit (InstructPix2Pix) inference service.

Exposes the peter-sushko/RealEdit image-editing model three ways:
  * POST /edit      — multipart upload (form fields: prompt, image)
  * POST /edit_url  — form fields: image_url, prompt
  * GET  /          — Gradio UI (mounted at root for Hugging Face Spaces)

Sampling settings are fixed (see FIXED_STEPS / FIXED_GUIDANCE_SCALE).
"""

import contextlib
import io

import torch
import requests
import PIL.Image
import PIL.ImageOps
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import Response
import gradio as gr
from diffusers import (
    StableDiffusionInstructPix2PixPipeline,
    EulerAncestralDiscreteScheduler,
)

# =========================
# Config
# =========================
MODEL_ID = "peter-sushko/RealEdit"
FIXED_STEPS = 50
FIXED_GUIDANCE_SCALE = 2.0

# =========================
# App
# =========================
app = FastAPI(title="RealEdit API")

print("Loading RealEdit model...")

# Detect device: fp16 on CUDA, fp32 on CPU (fp16 is unreliable on CPU).
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
dtype = torch.float16 if use_cuda else torch.float32

print(f"Using device: {device}, dtype: {dtype}")

# Load pipeline once at startup; safety checker disabled intentionally
# (editing model — inputs are user-supplied photos).
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    MODEL_ID,
    torch_dtype=dtype,
    safety_checker=None,
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
    pipe.scheduler.config
)
pipe = pipe.to(device)

print("Model loaded successfully!")


# =========================
# Core inference
# =========================
@torch.inference_mode()
def run_inference(image, prompt):
    """Run one InstructPix2Pix edit and return the resulting PIL image.

    Args:
        image: PIL.Image.Image (RGB) to edit.
        prompt: natural-language edit instruction.

    Returns:
        The first edited image produced by the pipeline.
    """
    # autocast only makes sense on CUDA; nullcontext keeps a single
    # pipeline call instead of duplicating it per device branch.
    autocast_ctx = (
        torch.autocast("cuda") if device == "cuda" else contextlib.nullcontext()
    )
    with autocast_ctx:
        result = pipe(
            prompt=prompt,
            image=image,
            num_inference_steps=FIXED_STEPS,
            image_guidance_scale=FIXED_GUIDANCE_SCALE,
        ).images[0]
    return result


def _prepare_image(image):
    """Normalize an input image: apply EXIF orientation, force RGB."""
    return PIL.ImageOps.exif_transpose(image).convert("RGB")


def _png_response(image):
    """Encode a PIL image as PNG and wrap it in an HTTP response."""
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")


def load_image_from_url(url: str):
    """Fetch an image over HTTP and return it as a normalized RGB PIL image.

    Raises:
        requests.HTTPError: on non-2xx responses (via raise_for_status).
    """
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    # BUGFIX: the previous version passed `response.raw` (from a
    # stream=True request) to PIL. `.raw` is the *undecoded* transport
    # stream, so servers using Content-Encoding (gzip/deflate) delivered
    # compressed bytes and Image.open failed. `response.content` is
    # always the decoded body.
    image = PIL.Image.open(io.BytesIO(response.content))
    return _prepare_image(image)


# =========================
# API: upload image
# =========================
@app.post("/edit")
async def edit_image_api(
    prompt: str = Form(...),
    image: UploadFile = File(...),
):
    """Edit an uploaded image per `prompt`; returns PNG bytes."""
    input_image = _prepare_image(PIL.Image.open(image.file))
    output_image = run_inference(input_image, prompt)
    return _png_response(output_image)


# =========================
# API: image URL
# =========================
@app.post("/edit_url")
async def edit_image_from_url(
    image_url: str = Form(...),
    prompt: str = Form(...),
):
    """Fetch an image from `image_url`, edit it per `prompt`; returns PNG."""
    input_image = load_image_from_url(image_url)
    output_image = run_inference(input_image, prompt)
    return _png_response(output_image)


# =========================
# Gradio UI
# =========================
def gradio_edit(image, prompt):
    """Gradio callback: thin wrapper around run_inference."""
    return run_inference(image, prompt)


gradio_ui = gr.Interface(
    fn=gradio_edit,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Textbox(label="Edit Prompt", value="give him a crown"),
    ],
    outputs=gr.Image(type="pil", label="Output Image"),
    title="RealEdit (InstructPix2Pix)",
    description=(
        "Fixed settings: "
        f"steps={FIXED_STEPS}, guidance_scale={FIXED_GUIDANCE_SCALE}"
    ),
)

# ⚠️ Mount UI at ROOT for Hugging Face Spaces
app = gr.mount_gradio_app(app, gradio_ui, path="/")