"""HTTP API exposing background removal, upscaling, face retouching and
prompt-driven editing of uploaded images.

All models are loaded once at import time and shared by every request.
Each endpoint writes its result as a PNG into ``OUTPUT_DIR`` and returns the
generated file name.
"""

import os
import uuid

import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from PIL import Image
from rembg import remove
from realesrgan import RealESRGANer
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
from diffusers import StableDiffusionImg2ImgPipeline
import torch

app = FastAPI()

UPLOAD_DIR = "uploads"
OUTPUT_DIR = "outputs"
os.makedirs(UPLOAD_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Largest output factor accepted by /upscale. The network itself is x4; larger
# factors are produced by resizing, so an unbounded value would let a client
# request arbitrarily large allocations.
MAX_UPSCALE_FACTOR = 8

# ===== Load models ONCE =====

# Real-ESRGAN (x4 network)
model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=4,
)
upscaler = RealESRGANer(
    scale=4,
    model_path="https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth",
    model=model,
    tile=0,
    tile_pad=10,
    pre_pad=0,
    half=(DEVICE == "cuda"),  # fp16 only when running on the GPU
    device=DEVICE,
)

# GFPGAN (face restoration)
gfpgan = GFPGANer(
    model_path="https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/GFPGANv1.4.pth",
    upscale=1,
    arch="clean",
    channel_multiplier=2,
    device=DEVICE,
)

# Stable Diffusion Img2Img
dtype = torch.float16 if DEVICE == "cuda" else torch.float32
sd_pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=dtype,
).to(DEVICE)


# ===== Helpers =====

def _load_image(file: UploadFile, mode: str) -> Image.Image:
    """Decode the upload into a PIL image converted to *mode*.

    Raises:
        HTTPException: 400 when the upload cannot be decoded as an image.
    """
    try:
        return Image.open(file.file).convert(mode)
    except OSError as err:
        # PIL reports unknown formats and truncated/corrupt data as OSError
        # (UnidentifiedImageError is a subclass).
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a valid image"
        ) from err


def _pil_to_bgr(img: Image.Image) -> np.ndarray:
    """Return *img* (RGB) as a contiguous BGR uint8 array.

    Real-ESRGAN and GFPGAN follow the OpenCV convention and work on BGR
    numpy arrays, not PIL images.
    """
    return np.ascontiguousarray(np.asarray(img)[:, :, ::-1])


def _bgr_to_pil(arr: np.ndarray) -> Image.Image:
    """Return a BGR array produced by a model as an RGB PIL image."""
    return Image.fromarray(np.ascontiguousarray(arr[:, :, ::-1]))


def _save_output(img: Image.Image) -> str:
    """Save *img* as a uniquely named PNG in ``OUTPUT_DIR``; return the name."""
    fname = f"{uuid.uuid4().hex}.png"
    img.save(os.path.join(OUTPUT_DIR, fname))
    return fname


# ===== Endpoints =====
# NOTE(review): the handlers are ``async def`` but run blocking inference, so
# one request occupies the event loop until it finishes. This is kept as-is
# because it also serialises access to the shared models; moving to plain
# ``def`` (thread pool) would need a lock around inference -- confirm the
# models' thread-safety before changing.

@app.post("/remove-bg")
async def remove_bg(file: UploadFile = File(...)):
    """Remove the background of the uploaded image with rembg.

    Returns:
        ``{"file": <name>}`` -- name of the RGBA PNG written to ``OUTPUT_DIR``.
    """
    img = _load_image(file, "RGBA")
    out = remove(img)
    return {"file": _save_output(out)}


@app.post("/enhance")
async def enhance(file: UploadFile = File(...)):
    """Upscale the uploaded image with Real-ESRGAN at its native x4 factor.

    Returns:
        ``{"file": <name>}`` -- name of the PNG written to ``OUTPUT_DIR``.
    """
    img = _load_image(file, "RGB")
    # enhance() takes and returns BGR arrays; the second value is the image
    # mode string, which is not needed here.
    out, _ = upscaler.enhance(_pil_to_bgr(img))
    return {"file": _save_output(_bgr_to_pil(out))}


@app.post("/upscale")
async def upscale(file: UploadFile = File(...), scale: int = Form(2)):
    """Upscale the uploaded image with Real-ESRGAN to the requested factor.

    Args:
        file: The image to upscale.
        scale: Output size factor, between 1 and ``MAX_UPSCALE_FACTOR``.

    Returns:
        ``{"file": <name>}`` -- name of the PNG written to ``OUTPUT_DIR``.

    Raises:
        HTTPException: 400 when *scale* is out of range or the upload is not
            a valid image.
    """
    if not 1 <= scale <= MAX_UPSCALE_FACTOR:
        raise HTTPException(
            status_code=400,
            detail=f"scale must be between 1 and {MAX_UPSCALE_FACTOR}",
        )
    img = _load_image(file, "RGB")
    # Only ``outscale`` is passed. ``upscaler.scale`` is the network's native
    # factor and is shared by every request, so it must never be overwritten
    # with the requested output factor.
    out, _ = upscaler.enhance(_pil_to_bgr(img), outscale=scale)
    return {"file": _save_output(_bgr_to_pil(out))}


@app.post("/retouch")
async def retouch(file: UploadFile = File(...)):
    """Restore faces in the uploaded image with GFPGAN.

    Returns:
        ``{"file": <name>}`` -- name of the PNG written to ``OUTPUT_DIR``.
    """
    img = _load_image(file, "RGB")
    # With paste_back=True the third return value is the full image with the
    # restored faces pasted in (BGR, like the input).
    _, _, out = gfpgan.enhance(
        _pil_to_bgr(img),
        has_aligned=False,
        only_center_face=False,
        paste_back=True,
    )
    return {"file": _save_output(_bgr_to_pil(out))}


@app.post("/edit")
async def edit_image(
    file: UploadFile = File(...),
    prompt: str = Form(...),
    strength: float = Form(0.6),
):
    """Edit the uploaded image with Stable Diffusion img2img.

    The input is resized to 512x512 (aspect ratio is not preserved) before
    being passed to the pipeline.

    Args:
        file: The image to edit.
        prompt: Text prompt guiding the edit.
        strength: How strongly the input is transformed, in ``[0, 1]``.

    Returns:
        ``{"prompt": <prompt>, "file": <name>}`` -- the prompt echoed back and
        the name of the PNG written to ``OUTPUT_DIR``.

    Raises:
        HTTPException: 400 when *strength* is outside ``[0, 1]`` or the upload
            is not a valid image.
    """
    # The pipeline rejects values outside [0, 1]; report that as a client
    # error instead of letting it surface as a 500.
    if not 0.0 <= strength <= 1.0:
        raise HTTPException(
            status_code=400, detail="strength must be between 0 and 1"
        )
    img = _load_image(file, "RGB").resize((512, 512))
    result = sd_pipe(
        prompt=prompt,
        image=img,
        strength=strength,
        guidance_scale=7.5,
    ).images[0]
    return {
        "prompt": prompt,
        "file": _save_output(result),
    }