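# app.py — RealEdit (InstructPix2Pix): a FastAPI image-editing API with a
# Gradio UI, intended to run as a Docker-based Hugging Face Space.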
import contextlib
import io
import torch
import requests
import PIL.Image
import PIL.ImageOps
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import Response
import gradio as gr
from diffusers import (
StableDiffusionInstructPix2PixPipeline,
EulerAncestralDiscreteScheduler
)
# =========================
# Config
# =========================
MODEL_ID = "peter-sushko/RealEdit"
FIXED_STEPS = 50
FIXED_GUIDANCE_SCALE = 2.0
# =========================
# App
# =========================
app = FastAPI(title="RealEdit API")
print("Loading RealEdit model...")
# Detect device
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
dtype = torch.float16 if use_cuda else torch.float32
print(f"Using device: {device}, dtype: {dtype}")
# Load pipeline
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
MODEL_ID,
torch_dtype=dtype,
safety_checker=None
)
pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
pipe.scheduler.config
)
pipe = pipe.to(device)
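# Optional memory saver (left commented out since it depends on the deployment
# hardware): attention slicing lowers peak VRAM at some cost in speed.
# pipe.enable_attention_slicing()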
print("Model loaded successfully!")
# =========================
# Core inference
# =========================
@torch.inference_mode()
def run_inference(image, prompt):
    """Run a single InstructPix2Pix edit with the fixed step/guidance settings."""
    # Autocast to fp16 on GPU; on CPU use a no-op context so the pipeline
    # call only has to be written once.
    ctx = torch.autocast("cuda") if device == "cuda" else contextlib.nullcontext()
    with ctx:
        result = pipe(
            prompt=prompt,
            image=image,
            num_inference_steps=FIXED_STEPS,
            image_guidance_scale=FIXED_GUIDANCE_SCALE
        ).images[0]
    return result
def load_image_from_url(url: str):
    """Download an image, fix its EXIF orientation, and return it as RGB."""
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    # Open from the fully downloaded bytes rather than response.raw, which
    # does not transparently decode gzip/deflate-encoded bodies.
    image = PIL.Image.open(io.BytesIO(response.content))
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
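def png_response(img: PIL.Image.Image) -> Response:
    """Serialize a PIL image into a PNG HTTP response (shared by both endpoints)."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    buf.seek(0)
    return Response(content=buf.getvalue(), media_type="image/png")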
# =========================
# API: upload image
# =========================
@app.post("/edit")
async def edit_image_api(
prompt: str = Form(...),
image: UploadFile = File(...)
):
    """Edit an uploaded image according to the prompt; returns a PNG."""
    input_image = PIL.Image.open(image.file)
    input_image = PIL.ImageOps.exif_transpose(input_image)
    input_image = input_image.convert("RGB")
    output_image = run_inference(input_image, prompt)
    return png_response(output_image)
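# Example request (hypothetical host and file names):
#   curl -X POST http://localhost:7860/edit \
#        -F "prompt=give him a crown" -F "image=@input.png" -o output.png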
# =========================
# API: image URL
# =========================
@app.post("/edit_url")
async def edit_image_from_url(
image_url: str = Form(...),
prompt: str = Form(...)
):
    """Edit an image fetched from a URL according to the prompt; returns a PNG."""
    input_image = load_image_from_url(image_url)
    output_image = run_inference(input_image, prompt)
    return png_response(output_image)
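# Example request (hypothetical host and image URL):
#   curl -X POST http://localhost:7860/edit_url \
#        -F "prompt=make it night" -F "image_url=https://example.com/photo.jpg" \
#        -o output.png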
# =========================
# Gradio UI
# =========================
def gradio_edit(image, prompt):
    # Guard against an empty image input from the UI before hitting the pipeline.
    if image is None:
        raise gr.Error("Please upload an image first.")
    return run_inference(image, prompt)
gradio_ui = gr.Interface(
fn=gradio_edit,
inputs=[
gr.Image(type="pil", label="Input Image"),
gr.Textbox(label="Edit Prompt", value="give him a crown")
],
outputs=gr.Image(type="pil", label="Output Image"),
title="RealEdit (InstructPix2Pix)",
description=(
"Fixed settings: "
f"steps={FIXED_STEPS}, guidance_scale={FIXED_GUIDANCE_SCALE}"
)
)
# ⚠️ Mount UI at ROOT for Hugging Face Spaces
app = gr.mount_gradio_app(app, gradio_ui, path="/")
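# Local dev entry point (a sketch; on Hugging Face Spaces the container's CMD
# typically launches uvicorn directly, and 7860 is the conventional Spaces port).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)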