import os
import uuid

import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import StableDiffusionImg2ImgPipeline
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from gfpgan import GFPGANer
from PIL import Image
from realesrgan import RealESRGANer
from rembg import remove
|
|
# The FastAPI application that the endpoint handlers below register on.
app = FastAPI()

# Working directories, relative to the process's current directory; both are
# created (if missing) at import time.
# NOTE(review): UPLOAD_DIR is created but no handler in view reads or writes it.
UPLOAD_DIR = "uploads"
OUTPUT_DIR = "outputs"
for _workdir in (UPLOAD_DIR, OUTPUT_DIR):
    os.makedirs(_workdir, exist_ok=True)

# Run every model on the GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
| |
|
|
| |
| |
# Generator architecture for the Real-ESRGAN "x4plus" weights: an RRDB network
# with 3 input / 3 output channels, built for a fixed upsampling factor of 4.
_RRDB_KWARGS = dict(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=4,
)
model = RRDBNet(**_RRDB_KWARGS)

# Pretrained weights are referenced by URL; RealESRGANer resolves the URL to a
# local file when it is constructed (presumably cached after the first fetch).
_ESRGAN_WEIGHTS_URL = (
    "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/"
    "RealESRGAN_x4plus.pth"
)

# Shared upsampler used by the enhance/upscale handlers.
upscaler = RealESRGANer(
    model=model,
    model_path=_ESRGAN_WEIGHTS_URL,
    scale=4,                 # kept equal to the RRDBNet scale above
    tile=0,                  # per Real-ESRGAN's options, 0 = no tiling
    tile_pad=10,
    pre_pad=0,
    half=DEVICE == "cuda",   # fp16 inference only when running on the GPU
    device=DEVICE,
)
|
|
| |
|
|
| |
# GFPGAN v1.4 weights, referenced by URL and resolved by GFPGANer on load.
_GFPGAN_WEIGHTS_URL = (
    "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/"
    "GFPGANv1.4.pth"
)

# Shared face restorer used by the retouch handler. upscale=1 keeps the output
# at the input resolution; no bg_upsampler is supplied, so (per GFPGANer's
# defaults) the non-face background is not enhanced.
gfpgan = GFPGANer(
    device=DEVICE,
    model_path=_GFPGAN_WEIGHTS_URL,
    arch="clean",
    channel_multiplier=2,
    upscale=1,
)
|
|
|
|
|
|
| |
# Half precision only on the GPU; CPU inference uses float32.
if DEVICE == "cuda":
    dtype = torch.float16
else:
    dtype = torch.float32

# Stable Diffusion v1.5 image-to-image pipeline, loaded at import time and
# moved to DEVICE.
# NOTE(review): the "runwayml/stable-diffusion-v1-5" Hub repo has reportedly
# been taken down; if loading fails, "stable-diffusion-v1-5/stable-diffusion-v1-5"
# may be a drop-in mirror - verify before switching.
_SD_MODEL_ID = "runwayml/stable-diffusion-v1-5"
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    _SD_MODEL_ID, torch_dtype=dtype
).to(DEVICE)
|
|
|
|
| |
|
|
@app.post("/remove-bg")
async def remove_bg(file: UploadFile = File(...)):
    """Cut the subject out of an uploaded image with rembg.

    The upload is decoded as RGBA, passed to ``rembg.remove`` and the result is
    written as a PNG (which preserves the alpha channel) into OUTPUT_DIR.

    Returns:
        ``{"file": <name>}`` - the generated file name inside OUTPUT_DIR.
    """
    source = Image.open(file.file).convert("RGBA")
    cutout = remove(source)

    out_name = uuid.uuid4().hex + ".png"
    cutout.save(os.path.join(OUTPUT_DIR, out_name))
    return {"file": out_name}
|
|
|
|
@app.post("/enhance")
async def enhance(file: UploadFile = File(...)):
    """Upscale an uploaded image 4x with Real-ESRGAN and save it as PNG.

    Returns:
        ``{"file": <name>}`` - the generated file name inside OUTPUT_DIR.

    NOTE(review): inference runs synchronously inside an async handler, so it
    blocks the event loop while it runs (which also serializes use of the
    shared ``upscaler``).
    """
    img = Image.open(file.file).convert("RGB")

    # RealESRGANer.enhance works on OpenCV-style BGR numpy arrays (it calls
    # .astype on its input and returns an ndarray), not PIL images: flip the
    # channel axis on the way in and back on the way out.
    bgr = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])
    out_bgr, _ = upscaler.enhance(bgr)
    out = Image.fromarray(np.ascontiguousarray(out_bgr[:, :, ::-1]))

    fname = f"{uuid.uuid4().hex}.png"
    out.save(os.path.join(OUTPUT_DIR, fname))
    return {"file": fname}
|
|
|
|
@app.post("/upscale")
async def upscale(file: UploadFile = File(...), scale: int = Form(2)):
    """Upscale an uploaded image by ``scale`` using the shared x4 Real-ESRGAN.

    The network always produces 4x output; RealESRGANer's ``outscale`` then
    resizes that result to ``scale`` times the input size.

    Args:
        file: the uploaded image.
        scale: requested output factor relative to the input (must be >= 1).

    Returns:
        ``{"file": <name>}`` - the generated file name inside OUTPUT_DIR.

    Raises:
        HTTPException(400): if ``scale`` is less than 1.
    """
    if scale < 1:
        raise HTTPException(status_code=400, detail="scale must be >= 1")

    img = Image.open(file.file).convert("RGB")

    # Do NOT assign upscaler.scale here: the upscaler is shared by every
    # request, and its `scale` must match the x4 network because RealESRGANer
    # uses it for padding/cropping. Mutating it desynchronizes the model and
    # leaks into later /enhance and /upscale calls. `outscale` alone picks the
    # output size.
    bgr = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])  # PIL RGB -> BGR
    out_bgr, _ = upscaler.enhance(bgr, outscale=scale)
    out = Image.fromarray(np.ascontiguousarray(out_bgr[:, :, ::-1]))  # BGR -> RGB

    fname = f"{uuid.uuid4().hex}.png"
    out.save(os.path.join(OUTPUT_DIR, fname))
    return {"file": fname}
|
|
|
|
@app.post("/retouch")
async def retouch(file: UploadFile = File(...)):
    """Restore faces in an uploaded image with GFPGAN and save it as PNG.

    Faces are detected on the full image (not pre-aligned, all faces, not only
    the centre one), restored, and pasted back into the original image.

    Returns:
        ``{"file": <name>}`` - the generated file name inside OUTPUT_DIR.
    """
    img = Image.open(file.file).convert("RGB")

    # GFPGANer.enhance expects an OpenCV-style BGR ndarray (it reads img.shape
    # and converts BGR->RGB internally) and returns the pasted-back image as a
    # BGR ndarray, so convert in both directions to keep colours correct.
    bgr = np.ascontiguousarray(np.asarray(img)[:, :, ::-1])
    _, _, restored_bgr = gfpgan.enhance(
        bgr,
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    out = Image.fromarray(np.ascontiguousarray(restored_bgr[:, :, ::-1]))

    fname = f"{uuid.uuid4().hex}.png"
    out.save(os.path.join(OUTPUT_DIR, fname))
    return {"file": fname}
|
|
|
|
@app.post("/edit")
async def edit_image(
    file: UploadFile = File(...),
    prompt: str = Form(...),
    strength: float = Form(0.6),
):
    """Edit an uploaded image with Stable Diffusion img2img.

    Args:
        file: the source image; it is resized to 512x512 before editing.
        prompt: text prompt guiding the edit.
        strength: how far to move away from the source image, in (0, 1].

    Returns:
        ``{"prompt": <prompt>, "file": <name>}`` where <name> is the generated
        file inside OUTPUT_DIR.

    Raises:
        HTTPException(400): if ``strength`` is outside (0, 1] (or NaN).
    """
    # The pipeline raises ValueError for strength outside [0, 1], and
    # strength 0 leaves it no denoising steps; either would surface as a 500.
    # Reject them as client errors instead. `not (a < x <= b)` also catches NaN.
    if not 0.0 < strength <= 1.0:
        raise HTTPException(
            status_code=400, detail="strength must be in the range (0, 1]"
        )

    # NOTE: a fixed 512x512 resize does not preserve the aspect ratio.
    img = Image.open(file.file).convert("RGB").resize((512, 512))

    result = sd_pipe(
        prompt=prompt,
        image=img,
        strength=strength,
        guidance_scale=7.5,
    ).images[0]

    fname = f"{uuid.uuid4().hex}.png"
    result.save(os.path.join(OUTPUT_DIR, fname))
    return {
        "prompt": prompt,
        "file": fname,
    }
|
|