Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, Form | |
| from fastapi.responses import FileResponse, JSONResponse | |
| import uuid | |
| import os | |
| from PIL import Image | |
| import torch | |
| from diffusers import ( | |
| StableDiffusionControlNetPipeline, | |
| ControlNetModel, | |
| UniPCMultistepScheduler, | |
| StableDiffusionInpaintPipeline, | |
| StableDiffusionPipeline, | |
| ) | |
| from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection | |
| from ip_adapter.ip_adapter import IPAdapter | |
| import cv2 | |
| import numpy as np | |
# Application object plus the runtime settings shared by every model below.
app = FastAPI()

# Half precision is only selected together with a CUDA device; on CPU the
# models are kept in float32.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

# Scratch directory for uploaded inputs and generated outputs.
SAVE_DIR = "/tmp/kitchen_ai"
os.makedirs(SAVE_DIR, exist_ok=True)
# Load the Canny ControlNet and wrap it in a Stable Diffusion 1.5 pipeline.
# Any failure leaves `pipe_canny` as None so the module still imports; the
# handler that uses it checks for None before running.
print("⏳ Loading ControlNet Canny...")
try:
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny",
        torch_dtype=dtype,
        cache_dir="/tmp/hf_models",
    )
    pipe_canny = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=dtype,
        cache_dir="/tmp/hf_models",
    )
    pipe_canny = pipe_canny.to(device)
    # Replace the default scheduler with UniPC, reusing the existing config.
    pipe_canny.scheduler = UniPCMultistepScheduler.from_config(
        pipe_canny.scheduler.config
    )
    print("✅ ControlNet loaded.")
except Exception as exc:
    pipe_canny = None
    print("❌ ControlNet failed:", exc)
# Load the inpainting pipeline (model id already corrected).
# Any failure leaves `pipe_inpaint` as None so the module still imports; the
# handler that uses it checks for None before running.
print("⏳ Loading Inpainting model...")
try:
    pipe_inpaint = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",  # corrected model id
        torch_dtype=dtype,
        cache_dir="/tmp/hf_models",
    )
    pipe_inpaint = pipe_inpaint.to(device)
    print("✅ Inpainting model loaded.")
except Exception as exc:
    pipe_inpaint = None
    print("❌ Inpainting load failed:", exc)
# Load IP-Adapter on top of a plain Stable Diffusion 1.5 pipeline, together
# with the CLIP ViT-L/14 vision model and its image processor.
# Any failure leaves `ip_adapter` as None so the module still imports.
# NOTE(review): the IPAdapter constructor arguments are kept exactly as
# written; confirm they match the signature of the installed `ip_adapter`
# package and that the checkpoint path resolves at runtime.
print("⏳ Loading IP-Adapter...")
try:
    base_pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=dtype,
        cache_dir="/tmp/hf_models",
    )
    base_pipe = base_pipe.to(device)
    vision_model = CLIPVisionModelWithProjection.from_pretrained(
        "openai/clip-vit-large-patch14"
    )
    image_processor = CLIPImageProcessor.from_pretrained(
        "openai/clip-vit-large-patch14"
    )
    ip_adapter = IPAdapter(
        base_pipe,
        vision_model,
        image_processor,
        ip_ckpt="ip-adapter_sd15.safetensors",
    )
    print("✅ IP-Adapter loaded.")
except Exception as exc:
    ip_adapter = None
    print("❌ IP-Adapter failed:", exc)
# Helper
def prepare_canny(image_path):
    """Build a Canny edge map from an image file for use as a control image.

    Args:
        image_path: Path of the image file to read with OpenCV.

    Returns:
        A 768x768 RGB ``PIL.Image`` produced by converting the single-channel
        edge map to three channels.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports an unreadable file by returning None rather than
        # raising, which would otherwise surface as an opaque resize error.
        raise ValueError(f"Could not read image: {image_path}")
    # 768 instead of 512: a larger edge map when more detail is wanted.
    img = cv2.resize(img, (768, 768))
    edge = cv2.Canny(img, 100, 200)
    edge = cv2.cvtColor(edge, cv2.COLOR_GRAY2RGB)
    return Image.fromarray(edge)
# Routes
async def transform_image(prompt: str = Form(...), image: UploadFile = File(...)):
    """Generate an image from a prompt, guided by the Canny edges of an upload.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text; confirm it is registered on ``app``.

    Args:
        prompt: Text prompt passed to the ControlNet pipeline.
        image: Uploaded image whose edge map becomes the control image.

    Returns:
        A JSON response whose ``image_url`` is ``/download/<name>`` for the
        saved PNG, or a 500 JSON error when the pipeline failed to load.
    """
    if pipe_canny is None:
        return JSONResponse({"error": "Model not available"}, status_code=500)

    input_path = os.path.join(SAVE_DIR, f"input_{uuid.uuid4().hex}.png")
    output_path = os.path.join(SAVE_DIR, f"output_{uuid.uuid4().hex}.png")
    try:
        with open(input_path, "wb") as f:
            f.write(await image.read())
        control_image = prepare_canny(input_path)
        # The pipeline call is synchronous and runs inside this coroutine.
        result = pipe_canny(
            prompt=prompt, image=control_image, num_inference_steps=25
        ).images[0]
        result.save(output_path)
    finally:
        # The upload is only needed to build the edge map; remove it even when
        # decoding or inference raised so failed requests do not leave files.
        try:
            os.remove(input_path)
        except FileNotFoundError:
            pass
    return JSONResponse({"image_url": f"/download/{os.path.basename(output_path)}"})
async def transform_inpaint(prompt: str = Form(...), image: UploadFile = File(...), mask: UploadFile = File(...)):
    """Inpaint an uploaded image using an uploaded mask and a text prompt.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text; confirm it is registered on ``app``.

    Args:
        prompt: Text prompt passed to the inpainting pipeline.
        image: Uploaded source image; converted to RGB and resized to 512x512.
        mask: Uploaded mask; converted to single-channel "L" and resized to
            512x512.

    Returns:
        A JSON response whose ``image_url`` is ``/download/<name>`` for the
        saved PNG, or a 500 JSON error when the pipeline failed to load.
    """
    if pipe_inpaint is None:
        return JSONResponse({"error": "Inpaint model not ready"}, status_code=500)

    input_path = os.path.join(SAVE_DIR, f"inpaint_input_{uuid.uuid4().hex}.png")
    mask_path = os.path.join(SAVE_DIR, f"mask_{uuid.uuid4().hex}.png")
    output_path = os.path.join(SAVE_DIR, f"inpaint_output_{uuid.uuid4().hex}.png")
    try:
        with open(input_path, "wb") as f:
            f.write(await image.read())
        with open(mask_path, "wb") as f:
            f.write(await mask.read())

        # Open inside `with` so the file handles are closed before the temp
        # files are removed; convert() returns an independent in-memory copy.
        with Image.open(input_path) as src:
            init_image = src.convert("RGB").resize((512, 512))
        with Image.open(mask_path) as src:
            mask_image = src.convert("L").resize((512, 512))

        # The pipeline call is synchronous and runs inside this coroutine.
        result = pipe_inpaint(
            prompt=prompt, image=init_image, mask_image=mask_image
        ).images[0]
        result.save(output_path)
    finally:
        # Remove both uploads even when decoding or inference raised, so that
        # failed requests do not leave files behind in SAVE_DIR.
        for path in (input_path, mask_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
    return JSONResponse({"image_url": f"/download/{os.path.basename(output_path)}"})
async def transform_ref(prompt: str = Form(...), image: UploadFile = File(...), ref_image: UploadFile = File(...)):
    """Generate an image with IP-Adapter from an input, a reference and a prompt.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text; confirm it is registered on ``app``.

    Args:
        prompt: Text prompt passed to ``ip_adapter.generate``.
        image: Uploaded input image; converted to RGB and resized to 512x512.
        ref_image: Uploaded reference image; converted to RGB and resized to
            224x224.

    Returns:
        A JSON response with ``image_url`` (``/download/<name>`` of the saved
        PNG) and ``status``, or a 500 JSON error when IP-Adapter failed to
        load.
    """
    if ip_adapter is None:
        return JSONResponse({"error": "IP-Adapter not ready"}, status_code=500)

    input_path = os.path.join(SAVE_DIR, f"ref_input_{uuid.uuid4().hex}.png")
    ref_path = os.path.join(SAVE_DIR, f"ref_img_{uuid.uuid4().hex}.png")
    output_path = os.path.join(SAVE_DIR, f"ref_output_{uuid.uuid4().hex}.png")
    try:
        with open(input_path, "wb") as f:
            f.write(await image.read())
        with open(ref_path, "wb") as f:
            f.write(await ref_image.read())

        # Open inside `with` so the file handles are closed before the temp
        # files are removed; convert() returns an independent in-memory copy.
        with Image.open(input_path) as src:
            pil_image = src.convert("RGB").resize((512, 512))
        with Image.open(ref_path) as src:
            ref_image_pil = src.convert("RGB").resize((224, 224))

        # NOTE(review): `generate` is called with both `pil_image` and
        # `ref_image`; confirm the installed IPAdapter accepts a `ref_image`
        # keyword.
        images = ip_adapter.generate(
            pil_image=pil_image,
            ref_image=ref_image_pil,
            prompt=prompt,
            scale=0.6,
            seed=42,
        )
        images[0].save(output_path)
    finally:
        # The sibling handlers delete their uploads; do the same here, and do
        # it even when decoding or generation raised.
        for path in (input_path, ref_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
    return JSONResponse({
        "image_url": f"/download/{os.path.basename(output_path)}",
        "status": "success"
    })
async def get_image(filename: str):
    """Serve a previously generated PNG from SAVE_DIR by file name.

    NOTE(review): no route decorator is visible on this function in the
    reviewed text; the handlers above return ``/download/<name>`` URLs, so it
    is presumably mounted there. Confirm against the app's routing.

    Args:
        filename: Bare file name of an image inside SAVE_DIR.

    Returns:
        A ``FileResponse`` with media type ``image/png`` when the file exists,
        otherwise a 404 JSON error.
    """
    # `filename` is client-supplied: accept only a bare name so a value with a
    # directory component can never resolve outside SAVE_DIR.
    if not filename or os.path.basename(filename) != filename:
        return JSONResponse({"error": "File not found"}, status_code=404)

    file_path = os.path.join(SAVE_DIR, filename)
    # isfile (not exists) so names such as ".." that resolve to a directory
    # are answered with 404 instead of being handed to FileResponse.
    if not os.path.isfile(file_path):
        return JSONResponse({"error": "File not found"}, status_code=404)
    return FileResponse(file_path, media_type="image/png", filename=filename)