File size: 2,981 Bytes
7edb31d
 
 
 
 
f07b0b4
 
 
 
7edb31d
 
 
9180ecd
7edb31d
 
 
9f4de68
7edb31d
 
 
9f4de68
7edb31d
 
 
9f4de68
7edb31d
 
9f4de68
7edb31d
 
 
9180ecd
7edb31d
 
 
67eaec9
 
 
9180ecd
9f4de68
7edb31d
 
 
9f4de68
 
 
7edb31d
67eaec9
f07b0b4
 
 
67eaec9
7edb31d
 
 
 
f07b0b4
7edb31d
 
9f4de68
f07b0b4
7edb31d
9f4de68
 
7edb31d
9f4de68
 
7edb31d
 
 
 
 
 
 
 
9f4de68
7edb31d
9f4de68
7edb31d
9180ecd
7edb31d
67eaec9
7edb31d
 
9180ecd
f07b0b4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import asyncio
import os, io, base64, threading, traceback

import numpy as np
import torch
import uvicorn
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import HTMLResponse, JSONResponse
from PIL import Image, ImageDraw

# ASGI application; the route decorators below register their handlers on it.
app = FastAPI()
# Model-loading state: written by load_model() (run on a background thread)
# and read by the /status and /tryon handlers.
MODEL_LOADED: bool = False  # set to True once the pipeline is ready to use
LOAD_ERROR: str = ""        # load failure message; "" while none is recorded
pipe         = None         # inpainting pipeline, assigned by load_model()

def load_model():
    """Load the Stable Diffusion inpainting pipeline into the module globals.

    On success, assigns ``pipe`` and sets ``MODEL_LOADED`` to True. On any
    exception, ``MODEL_LOADED`` is left unchanged and ``LOAD_ERROR`` is set
    to a non-empty description of the failure. A non-empty value is required
    so that callers checking ``LOAD_ERROR`` can tell "failed" apart from
    "still loading". The traceback is printed to stderr.
    """
    global pipe, MODEL_LOADED, LOAD_ERROR
    try:
        print("📥 Loading model on CPU...")
        # Imported lazily so the heavy import happens on the loader thread,
        # not at module import time.
        from diffusers import StableDiffusionInpaintPipeline
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            torch_dtype=torch.float32,   # CPU needs float32
            safety_checker=None,
            requires_safety_checker=False,
        )
        # CPU ONLY — no .to("cuda"), no cpu_offload
        pipe.enable_attention_slicing()
        MODEL_LOADED = True
        print("✅ Model ready on CPU!")
    except Exception as e:
        # str(e) is "" for exceptions raised without a message (e.g.
        # RuntimeError()); an empty LOAD_ERROR would read as "no failure",
        # so fall back to the exception's type name.
        LOAD_ERROR = str(e) or type(e).__name__
        traceback.print_exc()
        print(f"❌ {LOAD_ERROR}")

# Kick off model loading at import time on a separate thread so the module
# (and the server) finishes starting without waiting for the download/load.
# daemon=True: this thread will not keep the process alive at interpreter exit.
threading.Thread(target=load_model, daemon=True).start()

def pil_to_b64(img):
    """Serialize *img* as PNG and return those bytes as base64 text.

    *img* only needs a ``save(fp, format=...)`` method (e.g. a PIL image).
    """
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        png_bytes = stream.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()

def make_mask(size):
    """Build the fixed inpainting mask for an image of ``size`` = (width, height).

    White (255) marks pixels the pipeline may repaint; everything else stays
    black (0). The white area is the union of three axis-aligned rectangles
    given as fractions of the image dimensions. The mask is returned in RGB mode.
    """
    width, height = size
    # (left, top, right, bottom) as fractions of (width, height).
    # NOTE(review): presumably a central upper-body band plus two edge strips
    # for the arms/sleeves — confirm against the intended garment region.
    regions = (
        (0.05, 0.18, 0.95, 0.68),
        (0.0, 0.18, 0.15, 0.58),
        (0.85, 0.18, 1.0, 0.58),
    )
    canvas = Image.new("L", size, 0)
    painter = ImageDraw.Draw(canvas)
    for left, top, right, bottom in regions:
        box = [width * left, height * top, width * right, height * bottom]
        painter.rectangle(box, fill=255)
    return canvas.convert("RGB")

@app.get("/", response_class=HTMLResponse)
async def index():
    """Serve the front-end page read from ``/app/index.html``.

    The file is re-read on every request, so edits are picked up without a
    restart.
    """
    # Context manager so the handle is closed deterministically, and explicit
    # UTF-8 so decoding does not depend on the container's locale settings.
    with open("/app/index.html", encoding="utf-8") as page:
        return HTMLResponse(page.read())

@app.get("/status")
async def status():
    """Report model-loading state.

    ``loaded`` is the current MODEL_LOADED flag; ``error`` is LOAD_ERROR
    (empty string while no load failure has been recorded).
    """
    report = dict(loaded=MODEL_LOADED, error=LOAD_ERROR)
    return report

# Guards the shared pipeline. Inference runs on executor threads, and one
# pipeline object (including its scheduler state) is reused by every request,
# so two requests must never run it at the same time.
_PIPE_LOCK = threading.Lock()


def _run_inpaint(person_img, mask_img, size):
    """Run the inpainting pipeline synchronously and return the first image.

    This is a blocking call. ``tryon`` dispatches it to a worker thread so the
    event loop, and therefore other endpoints, stays responsive while it runs.
    ``size`` is the PIL-style (width, height) of both input images.
    """
    prompt   = "Person wearing a clean stylish garment, photorealistic, high quality fashion photo, same pose, same background"
    negative = "nude, deformed, blurry, bad anatomy, extra limbs, watermark, logo, text, disfigured"
    with _PIPE_LOCK:
        return pipe(
            prompt=prompt,
            negative_prompt=negative,
            image=person_img,
            mask_image=mask_img,
            height=size[1],
            width=size[0],
            num_inference_steps=25,
            guidance_scale=7.5,
            strength=0.95,
        ).images[0]


@app.post("/tryon")
async def tryon(person: UploadFile = File(...), garment: UploadFile = File(...)):
    """Repaint the masked clothing region of the uploaded person photo.

    Returns JSON ``{"status": "ok", "image": <base64 PNG>}`` on success.
    Failures return ``{"status": ..., "message": ...}`` with these codes:

    - 503 with status ``loading`` while the model is still loading.
    - 500 with status ``error`` if model loading failed or processing raised.
    - 400 with status ``error`` if ``person`` cannot be decoded as an image.

    NOTE(review): ``garment`` is accepted but never read. The output is driven
    only by the fixed text prompt, not by the uploaded garment image.
    """
    if not MODEL_LOADED:
        if LOAD_ERROR:
            # Loading failed permanently; telling the client to "retry" would
            # make it poll forever.
            return JSONResponse({"status": "error", "message": f"Model failed to load: {LOAD_ERROR}"}, status_code=500)
        return JSONResponse({"status":"loading","message":"Model still loading, please wait and retry."}, status_code=503)
    try:
        SIZE = (512, 768)  # PIL (width, height)
        raw = await person.read()
        try:
            person_img = Image.open(io.BytesIO(raw)).convert("RGB").resize(SIZE)
        except (OSError, Image.DecompressionBombError) as e:
            # Undecodable / oversized upload is a client error, not a server fault.
            return JSONResponse({"status": "error", "message": f"Invalid person image: {e}"}, status_code=400)
        mask_img = make_mask(SIZE)

        # The diffusion call is synchronous; running it directly here would
        # block the event loop (and /status) for the whole inference.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _run_inpaint, person_img, mask_img, SIZE)

        return JSONResponse({"status":"ok","image": pil_to_b64(result)})
    except Exception as e:
        traceback.print_exc()
        return JSONResponse({"status":"error","message":str(e)}, status_code=500)

if __name__ == "__main__":
    # Listen on all interfaces. The port defaults to 7860 and can be
    # overridden with the PORT environment variable.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", "7860")))