# ig / app.py — saved from a Hugging Face Space file view.
# Author: fomext · Last commit: bf480f1 "Update app.py" (verified) · Size: 3.19 kB
import os
import uuid

import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from PIL import Image
from rembg import remove
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from diffusers import StableDiffusionImg2ImgPipeline
import torch
app = FastAPI()

# Working directories, created at import time when missing.
# NOTE(review): UPLOAD_DIR is not referenced by the endpoints visible in this
# file; confirm whether anything else still needs it.
UPLOAD_DIR = "uploads"
OUTPUT_DIR = "outputs"
for _workdir in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_workdir, exist_ok=True)
del _workdir

# Torch device shared by every model below.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# ===== Load models ONCE =====
# Every model is built at import time and shared by all request handlers.

# ===== Real-ESRGAN =====
# Native upscale factor of the RealESRGAN_x4plus weights. The RRDBNet
# architecture and the RealESRGANer wrapper must agree on it, so it lives in
# one place. Other output sizes are requested per call via ``outscale``; this
# value should never be changed at runtime.
ESRGAN_NATIVE_SCALE = 4

model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=ESRGAN_NATIVE_SCALE,
)
upscaler = RealESRGANer(
    scale=ESRGAN_NATIVE_SCALE,
    model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    model=model,
    tile=0,  # 0 = no tiling: the whole image is processed in a single pass
    tile_pad=10,
    pre_pad=0,
    half=(DEVICE == "cuda"),  # fp16 weights only when running on the GPU
    device=DEVICE,
)
# ===== GFPGAN (face restoration) =====
# upscale=1: restored output keeps the input image's resolution.
gfpgan = GFPGANer(
    "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth",
    device=DEVICE,
    arch="clean",
    channel_multiplier=2,
    upscale=1,
)
# ===== Stable Diffusion Img2Img =====
# Half precision is used only on CUDA; CPU inference stays in fp32 because
# many CPU kernels lack half-precision support.
dtype = torch.float16 if DEVICE == "cuda" else torch.float32

# The original "runwayml/stable-diffusion-v1-5" repository was removed from
# the Hugging Face Hub, which makes from_pretrained fail at startup. The same
# v1.5 weights are published under "stable-diffusion-v1-5/stable-diffusion-v1-5".
# Set SD_MODEL_ID to point at a different (e.g. local or private) checkpoint.
SD_MODEL_ID = os.environ.get(
    "SD_MODEL_ID", "stable-diffusion-v1-5/stable-diffusion-v1-5"
)
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    SD_MODEL_ID,
    torch_dtype=dtype,
).to(DEVICE)
# ===== Endpoints =====
@app.post("/remove-bg")
async def remove_bg(file: UploadFile = File(...)):
    """Cut the background out of an uploaded image with ``rembg``.

    The upload is decoded with Pillow and converted to RGBA before being
    passed to ``rembg.remove``. The cut-out is written as a PNG into
    ``OUTPUT_DIR`` under a random hex filename.

    Returns:
        ``{"file": <generated PNG filename>}``.

    NOTE: the work runs synchronously inside this coroutine, so it blocks
    the event loop for its duration.
    """
    source = Image.open(file.file).convert("RGBA")
    cutout = remove(source)
    name = uuid.uuid4().hex + ".png"
    cutout.save(os.path.join(OUTPUT_DIR, name))
    return {"file": name}
@app.post("/enhance")
async def enhance(file: UploadFile = File(...)):
    """Upscale an uploaded image with the shared Real-ESRGAN ``upscaler``.

    The upload is decoded with Pillow, converted to RGB and enhanced at the
    upscaler's native factor (no ``outscale`` is passed). The result is
    written as a PNG into ``OUTPUT_DIR``.

    Returns:
        ``{"file": <generated PNG filename>}``.

    NOTE: inference runs synchronously inside this coroutine, so it blocks
    the event loop for its duration.
    """
    img = Image.open(file.file).convert("RGB")
    # RealESRGANer.enhance takes an OpenCV-style ndarray (H, W, C in BGR
    # order), not a PIL image. Flip RGB -> BGR; the contiguous copy keeps
    # OpenCV happy with the reversed-stride view.
    bgr_in = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])
    bgr_out, _ = upscaler.enhance(bgr_in)
    # The result is a BGR uint8 ndarray (it has no ``.save``): flip it back
    # to RGB and wrap it as a PIL image before writing the PNG.
    out = Image.fromarray(np.ascontiguousarray(bgr_out[:, :, ::-1]))
    fname = f"{uuid.uuid4().hex}.png"
    out.save(os.path.join(OUTPUT_DIR, fname))
    return {"file": fname}
@app.post("/upscale")
async def upscale(file: UploadFile = File(...), scale: int = Form(2)):
    """Upscale an uploaded image to ``scale`` times its original size.

    Real-ESRGAN always runs at its native factor. The ``outscale`` argument
    then resizes that result to the requested factor. The output is written
    as a PNG into ``OUTPUT_DIR``.

    Args:
        file: Uploaded source image.
        scale: Requested output factor (form field, default 2); must be >= 1.

    Returns:
        ``{"file": <generated PNG filename>}``.

    Raises:
        HTTPException: 422 when ``scale`` is less than 1.

    NOTE: inference runs synchronously inside this coroutine, so it blocks
    the event loop for its duration.
    """
    # A zero or negative factor would otherwise crash inside cv2.resize and
    # surface as an opaque 500.
    if scale < 1:
        raise HTTPException(status_code=422, detail="scale must be >= 1")
    img = Image.open(file.file).convert("RGB")
    # RealESRGANer.enhance takes and returns BGR ndarrays, not PIL images.
    bgr_in = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])
    # Do NOT assign ``upscaler.scale``. It is the network's native factor,
    # which RealESRGANer uses for its internal padding and cropping, and the
    # object is shared by every request. Changing it corrupts this call and
    # all later ones, /enhance included. ``outscale`` alone sets the size.
    bgr_out, _ = upscaler.enhance(bgr_in, outscale=scale)
    out = Image.fromarray(np.ascontiguousarray(bgr_out[:, :, ::-1]))
    fname = f"{uuid.uuid4().hex}.png"
    out.save(os.path.join(OUTPUT_DIR, fname))
    return {"file": fname}
@app.post("/retouch")
async def retouch(file: UploadFile = File(...)):
    """Restore faces in an uploaded image with the shared GFPGAN model.

    All detected faces are processed (``only_center_face=False``) and pasted
    back into the full image (``paste_back=True``). The restored image is
    written as a PNG into ``OUTPUT_DIR``.

    Returns:
        ``{"file": <generated PNG filename>}``.

    NOTE: inference runs synchronously inside this coroutine, so it blocks
    the event loop for its duration.
    """
    img = Image.open(file.file).convert("RGB")
    # GFPGANer.enhance expects an OpenCV-style BGR ndarray (as produced by
    # cv2.imread), not a PIL image.
    bgr_in = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])
    _, _, restored_bgr = gfpgan.enhance(
        bgr_in,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    # The restored image comes back in BGR order; convert to RGB so the saved
    # PNG does not have red and blue swapped.
    out = Image.fromarray(np.ascontiguousarray(restored_bgr[:, :, ::-1]))
    fname = f"{uuid.uuid4().hex}.png"
    out.save(os.path.join(OUTPUT_DIR, fname))
    return {"file": fname}
@app.post("/edit")
async def edit_image(
    file: UploadFile = File(...),
    prompt: str = Form(...),
    strength: float = Form(0.6),
):
    """Edit an uploaded image with Stable Diffusion img2img guided by ``prompt``.

    The upload is converted to RGB and resized to 512x512; the aspect ratio
    is not preserved. It is then passed to the shared ``sd_pipe`` with
    ``guidance_scale=7.5``. The first generated image is written as a PNG
    into ``OUTPUT_DIR``.

    Args:
        file: Uploaded source image.
        prompt: Text prompt describing the desired edit.
        strength: How strongly the source image is transformed; must lie in
            (0, 1]. Defaults to 0.6.

    Returns:
        ``{"prompt": prompt, "file": <generated PNG filename>}``.

    Raises:
        HTTPException: 422 when ``strength`` is outside (0, 1] or is NaN.

    NOTE: inference runs synchronously inside this coroutine, so it blocks
    the event loop until generation finishes.
    """
    # diffusers raises ValueError for strength outside [0, 1], which would
    # surface here as an opaque 500. Strength 0 leaves img2img with no
    # denoising steps to run. Reject both up front as client errors; the
    # negated comparison also rejects NaN.
    if not 0.0 < strength <= 1.0:
        raise HTTPException(
            status_code=422,
            detail="strength must be greater than 0 and at most 1",
        )
    img = Image.open(file.file).convert("RGB").resize((512, 512))
    result = sd_pipe(
        prompt=prompt,
        image=img,
        strength=strength,
        guidance_scale=7.5,
    ).images[0]
    fname = f"{uuid.uuid4().hex}.png"
    result.save(os.path.join(OUTPUT_DIR, fname))
    return {
        "prompt": prompt,
        "file": fname,
    }