"""FastAPI service exposing Stable Diffusion image-transformation endpoints.

Three generation routes are provided, each backed by a model that is loaded
once at import time:

* ``POST /transform/``          - ControlNet (Canny edges) guided generation.
* ``POST /transform_inpaint/``  - mask-based inpainting.
* ``POST /transform_ref/``      - IP-Adapter generation with a reference image.

Generated images are written to ``SAVE_DIR`` and served back through
``GET /download/{filename}``.
"""
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import FileResponse, JSONResponse
from contextlib import suppress
import uuid
import os
from PIL import Image
import torch
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
    StableDiffusionInpaintPipeline,
    StableDiffusionPipeline,
)
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ip_adapter.ip_adapter import IPAdapter
import cv2
import numpy as np

app = FastAPI()

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; CPU falls back to float32.
dtype = torch.float16 if device == "cuda" else torch.float32

SAVE_DIR = "/tmp/kitchen_ai"
HF_CACHE_DIR = "/tmp/hf_models"
os.makedirs(SAVE_DIR, exist_ok=True)

# Message returned (HTTP 400) when an uploaded file cannot be decoded.
_INVALID_IMAGE_ERROR = "Uploaded file is not a valid image"

# ---------------------------------------------------------------------------
# Model loading
#
# Each model is loaded independently and best-effort: a failure is printed and
# the corresponding global is set to None, so the matching route answers with
# an error instead of the whole application failing to start.
# ---------------------------------------------------------------------------

# Load ControlNet
print("⏳ Loading ControlNet Canny...")
try:
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny",
        torch_dtype=dtype,
        cache_dir=HF_CACHE_DIR,
    )
    pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=dtype,
        cache_dir=HF_CACHE_DIR,
    ).to(device)
    pipe_canny.scheduler = UniPCMultistepScheduler.from_config(
        pipe_canny.scheduler.config
    )
    print("✅ ControlNet loaded.")
except Exception as e:
    pipe_canny = None
    print("❌ ControlNet failed:", e)

# Load Inpainting (fixed)
print("⏳ Loading Inpainting model...")
try:
    pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",  # fixed model id
        torch_dtype=dtype,
        cache_dir=HF_CACHE_DIR,
    ).to(device)
    print("✅ Inpainting model loaded.")
except Exception as e:
    pipe_inpaint = None
    print("❌ Inpainting load failed:", e)

# Load IP-Adapter
print("⏳ Loading IP-Adapter...")
try:
    base_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=dtype,
        cache_dir=HF_CACHE_DIR,
    ).to(device)
    vision_model = CLIPVisionModelWithProjection.from_pretrained(
        "openai/clip-vit-large-patch14"
    )
    image_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-large-patch14"
    )
    # NOTE(review): the IPAdapter constructor arguments (and the keyword
    # arguments passed to `generate` in transform_ref) are kept exactly as
    # written; verify them against the installed `ip_adapter` package.
    ip_adapter = IPAdapter(
        base_pipe,
        vision_model,
        image_processor,
        ip_ckpt="ip-adapter_sd15.safetensors",
    )
    print("✅ IP-Adapter loaded.")
except Exception as e:
    ip_adapter = None
    print("❌ IP-Adapter failed:", e)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _new_path(prefix: str) -> str:
    """Return a unique ``<SAVE_DIR>/<prefix>_<uuid hex>.png`` path."""
    return os.path.join(SAVE_DIR, f"{prefix}_{uuid.uuid4().hex}.png")


async def _save_upload(upload: UploadFile, path: str) -> None:
    """Write the full content of *upload* to *path*."""
    with open(path, "wb") as f:
        f.write(await upload.read())


def _remove_quietly(*paths: str) -> None:
    """Delete each of *paths*, ignoring files that do not exist."""
    for path in paths:
        with suppress(FileNotFoundError):
            os.remove(path)


def _load_image(path: str, mode: str, size: tuple) -> Image.Image:
    """Open *path*, convert it to *mode* and resize it to *size*.

    The source file handle is closed before returning; the converted copy
    does not depend on it.

    Raises:
        OSError: if the file cannot be read or decoded as an image.
    """
    with Image.open(path) as img:
        return img.convert(mode).resize(size)


def prepare_canny(image_path):
    """Build the Canny edge control image for the file at *image_path*.

    The image is resized to 768x768, run through ``cv2.Canny`` with
    thresholds 100/200, and returned as a 3-channel PIL image.

    Raises:
        ValueError: if OpenCV cannot read/decode the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Could not decode image: {image_path}")
    # Increased from 512 for a more detailed image.
    img = cv2.resize(img, (768, 768))
    edge = cv2.Canny(img, 100, 200)
    edge = cv2.cvtColor(edge, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edge)


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------

@app.post("/transform/")
async def transform_image(prompt: str = Form(...), image: UploadFile = File(...)):
    """Generate an image from *prompt*, guided by the Canny edges of *image*.

    Returns JSON with an ``image_url`` pointing at the download route, a 500
    error if the ControlNet pipeline failed to load, or a 400 error if the
    upload cannot be decoded.
    """
    if pipe_canny is None:
        return JSONResponse({"error": "Model not available"}, status_code=500)

    input_path = _new_path("input")
    output_path = _new_path("output")
    try:
        await _save_upload(image, input_path)
        try:
            control_image = prepare_canny(input_path)
        except ValueError:
            return JSONResponse({"error": _INVALID_IMAGE_ERROR}, status_code=400)
        result = pipe_canny(
            prompt=prompt, image=control_image, num_inference_steps=25
        ).images[0]
        result.save(output_path)
    finally:
        # The upload is only needed during generation; remove it even when
        # decoding or inference fails.
        _remove_quietly(input_path)

    return JSONResponse({"image_url": f"/download/{os.path.basename(output_path)}"})


@app.post("/transform_inpaint/")
async def transform_inpaint(
    prompt: str = Form(...),
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
):
    """Inpaint *image* according to *prompt* inside the area given by *mask*.

    Both uploads are resized to 512x512 (the mask is converted to
    single-channel "L" mode) before being passed to the inpainting pipeline.

    Returns JSON with an ``image_url``, a 500 error if the inpainting
    pipeline failed to load, or a 400 error if an upload cannot be decoded.
    """
    if pipe_inpaint is None:
        return JSONResponse({"error": "Inpaint model not ready"}, status_code=500)

    input_path = _new_path("inpaint_input")
    mask_path = _new_path("mask")
    output_path = _new_path("inpaint_output")
    try:
        await _save_upload(image, input_path)
        await _save_upload(mask, mask_path)
        try:
            init_image = _load_image(input_path, "RGB", (512, 512))
            mask_image = _load_image(mask_path, "L", (512, 512))
        except OSError:
            return JSONResponse({"error": _INVALID_IMAGE_ERROR}, status_code=400)
        result = pipe_inpaint(
            prompt=prompt, image=init_image, mask_image=mask_image
        ).images[0]
        result.save(output_path)
    finally:
        _remove_quietly(input_path, mask_path)

    return JSONResponse({"image_url": f"/download/{os.path.basename(output_path)}"})


@app.post("/transform_ref/")
async def transform_ref(
    prompt: str = Form(...),
    image: UploadFile = File(...),
    ref_image: UploadFile = File(...),
):
    """Generate an image from *prompt* and *image* using *ref_image* as reference.

    *image* is resized to 512x512 and *ref_image* to 224x224 before being
    handed to the IP-Adapter with ``scale=0.6`` and a fixed ``seed=42``.

    Returns JSON with ``image_url`` and ``status``, a 500 error if the
    IP-Adapter failed to load, or a 400 error if an upload cannot be decoded.
    """
    if ip_adapter is None:
        return JSONResponse({"error": "IP-Adapter not ready"}, status_code=500)

    input_path = _new_path("ref_input")
    ref_path = _new_path("ref_img")
    output_path = _new_path("ref_output")
    try:
        await _save_upload(image, input_path)
        await _save_upload(ref_image, ref_path)
        try:
            pil_image = _load_image(input_path, "RGB", (512, 512))
            ref_image_pil = _load_image(ref_path, "RGB", (224, 224))
        except OSError:
            return JSONResponse({"error": _INVALID_IMAGE_ERROR}, status_code=400)
        images = ip_adapter.generate(
            pil_image=pil_image,
            ref_image=ref_image_pil,
            prompt=prompt,
            scale=0.6,
            seed=42,
        )
        images[0].save(output_path)
    finally:
        # Previously these uploads were never deleted by this route.
        _remove_quietly(input_path, ref_path)

    return JSONResponse({
        "image_url": f"/download/{os.path.basename(output_path)}",
        "status": "success"
    })


@app.get("/download/{filename}")
async def get_image(filename: str):
    """Serve a previously generated PNG from ``SAVE_DIR``.

    Only a bare file name that refers to a regular file inside ``SAVE_DIR``
    is accepted; anything else (including ``..`` or names containing path
    separators) yields a 404.
    """
    safe_name = os.path.basename(filename)
    file_path = os.path.join(SAVE_DIR, safe_name)
    if safe_name != filename or not os.path.isfile(file_path):
        return JSONResponse({"error": "File not found"}, status_code=404)
    return FileResponse(file_path, media_type="image/png", filename=filename)