# Scraped file-page header (not code): author ginhang2209,
# last commit 2401f91 "Update app.py" (verified).
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse, JSONResponse
import uuid
import os
from PIL import Image
import torch
from diffusers import (
StableDiffusionControlNetPipeline,
ControlNetModel,
UniPCMultistepScheduler,
StableDiffusionInpaintPipeline,
StableDiffusionPipeline,
)
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ip_adapter.ip_adapter import IPAdapter
import cv2
import numpy as np
# FastAPI application that serves the image-transformation endpoints below.
app = FastAPI()

# Prefer the GPU when one is visible; half precision is used only on CUDA,
# CPU inference stays in float32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

# Scratch directory for uploaded inputs and generated outputs; created up
# front so the endpoints can write into it unconditionally.
SAVE_DIR = "/tmp/kitchen_ai"
os.makedirs(SAVE_DIR, exist_ok=True)
# Load the Canny-edge ControlNet and attach it to a Stable Diffusion v1.5
# pipeline. Any failure is reported and leaves pipe_canny as None, so the
# server still starts and /transform/ answers with an error instead.
print("⏳ Loading ControlNet Canny...")
try:
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny",
        torch_dtype=dtype,
        cache_dir="/tmp/hf_models",
    )
    pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=dtype,
        cache_dir="/tmp/hf_models",
    )
    pipe_canny = pipe_canny.to(device)
    # Replace the default scheduler with UniPC, reusing the existing config.
    pipe_canny.scheduler = UniPCMultistepScheduler.from_config(
        pipe_canny.scheduler.config
    )
    print("✅ ControlNet loaded.")
except Exception as e:
    pipe_canny = None
    print("❌ ControlNet failed:", e)
# Load the Stable Diffusion 2 inpainting pipeline (the model id was corrected
# to "stabilityai/stable-diffusion-2-inpainting"). On failure pipe_inpaint is
# None and /transform_inpaint/ reports the model as not ready.
print("⏳ Loading Inpainting model...")
try:
    pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=dtype,
        cache_dir="/tmp/hf_models",
    )
    pipe_inpaint = pipe_inpaint.to(device)
    print("✅ Inpainting model loaded.")
except Exception as e:
    pipe_inpaint = None
    print("❌ Inpainting load failed:", e)
# Load IP-Adapter on top of its own Stable Diffusion v1.5 pipeline.
# On any failure ip_adapter is None and /transform_ref/ reports it as not ready.
print("⏳ Loading IP-Adapter...")
try:
    # NOTE(review): this loads the same SD v1.5 checkpoint that pipe_canny
    # also loads, i.e. a second copy of the base model in memory.
    base_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=dtype, cache_dir="/tmp/hf_models"
    ).to(device)
    # NOTE(review): unlike base_pipe, these are loaded without the
    # dtype / device / cache_dir arguments.
    vision_model = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
    image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14")
    # NOTE(review): IPAdapter comes from a package not visible here. In the
    # upstream tencent-ailab implementation the constructor is
    # IPAdapter(sd_pipe, image_encoder_path, ip_ckpt, device, ...) -- it takes
    # a *path* to the image encoder rather than model objects, and this call
    # would bind ip_ckpt twice (positionally via image_processor and by
    # keyword), raising TypeError that the except below turns into
    # ip_adapter = None. Verify against the installed ip_adapter version.
    ip_adapter = IPAdapter(base_pipe, vision_model, image_processor, ip_ckpt="ip-adapter_sd15.safetensors")
    print("✅ IP-Adapter loaded.")
except Exception as e:
    ip_adapter = None
    print("❌ IP-Adapter failed:", e)
# Helper
def prepare_canny(image_path, size=(768, 768), low_threshold=100, high_threshold=200):
    """Build a Canny edge map to use as a ControlNet conditioning image.

    Args:
        image_path: Path to an image file readable by OpenCV.
        size: (width, height) the image is resized to before edge detection.
            Defaults to 768x768 (raised from 512 for more detailed output).
        low_threshold: Lower hysteresis threshold passed to cv2.Canny.
        high_threshold: Upper hysteresis threshold passed to cv2.Canny.

    Returns:
        A 3-channel PIL image in which every channel holds the edge map.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check cv2.resize fails with an opaque assertion error.
        raise ValueError(f"Could not read image: {image_path}")
    img = cv2.resize(img, size)
    edges = cv2.Canny(img, low_threshold, high_threshold)
    # Replicate the single edge channel into 3 channels for the pipeline.
    edges = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edges)
# Routes
@app.post("/transform/")
async def transform_image(prompt: str = Form(...), image: UploadFile = File(...)):
    """Regenerate an uploaded image from *prompt*, guided by its Canny edges.

    The upload is written to a temporary file in SAVE_DIR, turned into an
    edge map with prepare_canny, and fed to the ControlNet pipeline. The
    result is saved as a PNG in SAVE_DIR and its /download/ URL is returned.

    Returns:
        JSON {"image_url": ...} on success; status 500 if the ControlNet
        pipeline failed to load; status 400 if the upload cannot be decoded.
    """
    if pipe_canny is None:
        return JSONResponse({"error": "Model not available"}, status_code=500)
    input_path = os.path.join(SAVE_DIR, f"input_{uuid.uuid4().hex}.png")
    output_path = os.path.join(SAVE_DIR, f"output_{uuid.uuid4().hex}.png")
    try:
        with open(input_path, "wb") as f:
            f.write(await image.read())
        try:
            control_image = prepare_canny(input_path)
        except (cv2.error, ValueError):
            # Undecodable upload: OpenCV either raises or yields no image.
            return JSONResponse({"error": "Invalid image"}, status_code=400)
        # Inference is synchronous inside an async endpoint, so it blocks the
        # event loop until generation finishes.
        result = pipe_canny(prompt=prompt, image=control_image, num_inference_steps=25).images[0]
        result.save(output_path)
    finally:
        # Remove the uploaded copy even when decoding or inference fails.
        if os.path.exists(input_path):
            os.remove(input_path)
    return JSONResponse({"image_url": f"/download/{os.path.basename(output_path)}"})
@app.post("/transform_inpaint/")
async def transform_inpaint(prompt: str = Form(...), image: UploadFile = File(...), mask: UploadFile = File(...)):
    """Inpaint the masked region of an uploaded image according to *prompt*.

    Both uploads are written to temporary files in SAVE_DIR, decoded with PIL
    and resized to 512x512 (image as "RGB", mask as single-channel "L"), then
    passed to the inpainting pipeline. The result is saved as a PNG in
    SAVE_DIR and its /download/ URL is returned.

    Returns:
        JSON {"image_url": ...} on success; status 500 if the inpainting
        pipeline failed to load; status 400 if either upload cannot be decoded.
    """
    if pipe_inpaint is None:
        return JSONResponse({"error": "Inpaint model not ready"}, status_code=500)
    input_path = os.path.join(SAVE_DIR, f"inpaint_input_{uuid.uuid4().hex}.png")
    mask_path = os.path.join(SAVE_DIR, f"mask_{uuid.uuid4().hex}.png")
    output_path = os.path.join(SAVE_DIR, f"inpaint_output_{uuid.uuid4().hex}.png")
    try:
        with open(input_path, "wb") as f:
            f.write(await image.read())
        with open(mask_path, "wb") as f:
            f.write(await mask.read())
        try:
            # `with` closes the file handles PIL opens lazily.
            with Image.open(input_path) as src:
                init_image = src.convert("RGB").resize((512, 512))
            with Image.open(mask_path) as src:
                mask_image = src.convert("L").resize((512, 512))
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # undecodable data and OSError for truncated files.
            return JSONResponse({"error": "Invalid image or mask"}, status_code=400)
        result = pipe_inpaint(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
        result.save(output_path)
    finally:
        # Remove the uploaded copies even when decoding or inference fails.
        for path in (input_path, mask_path):
            if os.path.exists(path):
                os.remove(path)
    return JSONResponse({"image_url": f"/download/{os.path.basename(output_path)}"})
@app.post("/transform_ref/")
async def transform_ref(prompt: str = Form(...), image: UploadFile = File(...), ref_image: UploadFile = File(...)):
    """Generate an image from *prompt* with IP-Adapter and a reference image.

    The input image (resized to 512x512 RGB) and the reference image (resized
    to 224x224 RGB) are written to temporary files in SAVE_DIR, decoded, and
    passed to ip_adapter.generate with scale 0.6 and a fixed seed of 42. The
    first generated image is saved as a PNG in SAVE_DIR.

    Returns:
        JSON {"image_url": ..., "status": "success"} on success; status 500 if
        IP-Adapter failed to load; status 400 if either upload cannot be decoded.
    """
    if ip_adapter is None:
        return JSONResponse({"error": "IP-Adapter not ready"}, status_code=500)
    input_path = os.path.join(SAVE_DIR, f"ref_input_{uuid.uuid4().hex}.png")
    ref_path = os.path.join(SAVE_DIR, f"ref_img_{uuid.uuid4().hex}.png")
    output_path = os.path.join(SAVE_DIR, f"ref_output_{uuid.uuid4().hex}.png")
    try:
        with open(input_path, "wb") as f:
            f.write(await image.read())
        with open(ref_path, "wb") as f:
            f.write(await ref_image.read())
        try:
            # `with` closes the file handles PIL opens lazily.
            with Image.open(input_path) as src:
                pil_image = src.convert("RGB").resize((512, 512))
            with Image.open(ref_path) as src:
                ref_image_pil = src.convert("RGB").resize((224, 224))
        except OSError:
            # PIL raises UnidentifiedImageError (an OSError subclass) for
            # undecodable data and OSError for truncated files.
            return JSONResponse({"error": "Invalid image or reference image"}, status_code=400)
        # NOTE(review): in the upstream tencent-ailab IPAdapter.generate,
        # pil_image is the image prompt and extra kwargs (such as ref_image)
        # are forwarded to the underlying diffusers pipeline call -- confirm
        # the installed version accepts ref_image.
        images = ip_adapter.generate(
            pil_image=pil_image,
            ref_image=ref_image_pil,
            prompt=prompt,
            scale=0.6,
            seed=42,
        )
        images[0].save(output_path)
    finally:
        # The uploads are only needed for generation; always remove them so
        # SAVE_DIR does not fill with input copies.
        for path in (input_path, ref_path):
            if os.path.exists(path):
                os.remove(path)
    return JSONResponse({
        "image_url": f"/download/{os.path.basename(output_path)}",
        "status": "success"
    })
@app.get("/download/{filename}")
async def get_image(filename: str):
    """Serve a generated PNG from SAVE_DIR by its bare file name.

    Only names that resolve to a regular file directly inside SAVE_DIR are
    served. Anything else (directories such as "." or "..", paths that escape
    SAVE_DIR, missing files) yields a 404.
    """
    save_root = os.path.realpath(SAVE_DIR)
    file_path = os.path.realpath(os.path.join(save_root, filename))
    # exists() alone would also accept directories like "..", which
    # FileResponse cannot send; require a regular file inside SAVE_DIR.
    if os.path.dirname(file_path) != save_root or not os.path.isfile(file_path):
        return JSONResponse({"error": "File not found"}, status_code=404)
    return FileResponse(file_path, media_type="image/png", filename=filename)