Spaces:
Running
on
Zero
Running
on
Zero
| # app.py — ZeroGPU対応版 | |
| import gradio as gr | |
| import spaces | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| import os | |
| import subprocess | |
| import traceback | |
| import base64 | |
| import io | |
| from pathlib import Path | |
| # FastAPI関連(ハイブリッド構成のため維持) | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
# Heavy objects, created lazily by initialize_pipelines(); None/False until then.
pipe = None
face_app = None
upsampler = None
UPSCALE_OK = False

# 0. Cache directory setup (runs at import time).
# Prefer the persistent /data volume when it exists and is writable;
# otherwise fall back to a per-user cache under the home directory.
PERSIST_BASE = Path("/data")
if PERSIST_BASE.exists() and os.access(PERSIST_BASE, os.W_OK):
    CACHE_ROOT = PERSIST_BASE / "instantid_cache"
else:
    CACHE_ROOT = Path.home() / ".cache" / "instantid_cache"

MODELS_DIR = CACHE_ROOT / "models"
LORA_DIR = MODELS_DIR / "Lora"
EMB_DIR = CACHE_ROOT / "embeddings"
UPSCALE_DIR = CACHE_ROOT / "realesrgan"

for p in (MODELS_DIR, LORA_DIR, EMB_DIR, UPSCALE_DIR):
    p.mkdir(parents=True, exist_ok=True)
def dl(url: str, dst: Path, attempts: int = 2):
    """Download *url* to *dst* with ``wget`` unless *dst* already holds data.

    The transfer is written to a sibling ``<name>.part`` file and renamed onto
    *dst* only after wget exits with status 0. ``wget -O`` creates its output
    file even when the download fails, so writing straight to *dst* could leave
    an empty or truncated file that every later run would accept as finished.

    Args:
        url: Source URL passed to wget.
        dst: Final destination path.
        attempts: How many times to invoke wget before giving up.

    Raises:
        RuntimeError: If every attempt fails (no partial file is left behind).
    """
    # A 0-byte file is treated as missing: it can only be failed-download residue.
    if dst.exists() and dst.stat().st_size > 0:
        return
    tmp = dst.with_name(dst.name + ".part")
    for i in range(1, attempts + 1):
        print(f"⬇ Downloading {dst.name} (try {i}/{attempts})")
        if subprocess.call(["wget", "-q", "-O", str(tmp), url]) == 0:
            tmp.replace(dst)  # atomic on the same filesystem
            return
    tmp.unlink(missing_ok=True)
    raise RuntimeError(f"download failed → {url}")
# 1. Asset download (runs at import time).
# Each (url, destination) pair is fetched by dl(), which skips files already present.
print("— Starting asset download check —")
BASE_CKPT = MODELS_DIR / "beautiful_realistic_asians_v7_fp16.safetensors"
IP_BIN_FILE = LORA_DIR / "ip-adapter-plus-face_sd15.bin"
LORA_FILE = LORA_DIR / "ip-adapter-faceid-plusv2_sd15_lora.safetensors"
for _asset_url, _asset_dst in (
    ("https://civitai.com/api/download/models/177164?type=Model&format=SafeTensor&size=pruned&fp=fp16", BASE_CKPT),
    ("https://huggingface.co/h94/IP-Adapter/resolve/main/models/ip-adapter-plus-face_sd15.bin", IP_BIN_FILE),
    ("https://huggingface.co/h94/IP-Adapter-FaceID/resolve/main/ip-adapter-faceid-plusv2_sd15_lora.safetensors", LORA_FILE),
):
    dl(_asset_url, _asset_dst)
del _asset_url, _asset_dst
print("— Asset download check finished —")
# 2. Pipeline initialisation (called once a GPU has been acquired).

# True once loading Real-ESRGAN has been tried. UPSCALE_OK alone cannot tell
# "never tried" from "tried and failed", which made the old guard retry the
# (failing) load on every call despite the "don't retry" intent.
_UPSCALE_ATTEMPTED = False


def initialize_pipelines():
    """Lazily build the face analyser, the SD1.5 pipeline and the optional upscaler.

    Safe to call on every request: each component is built only while its
    module-level global is still unset. Components are assembled in local
    variables and published to the globals only after they are fully
    configured, so an exception half-way through (e.g. a failed IP-Adapter
    download) leaves the global at ``None`` and the next call retries, instead
    of silently reusing a half-configured object.

    Side effects:
        Sets the module globals ``face_app``, ``pipe``, ``upsampler``,
        ``UPSCALE_OK`` and ``_UPSCALE_ATTEMPTED``.

    Raises:
        Whatever the diffusers / insightface loaders raise. Upscaler failures
        are caught and only disable upscaling.
    """
    global pipe, face_app, upsampler, UPSCALE_OK, _UPSCALE_ATTEMPTED
    # Heavy imports are deferred until this function runs.
    from diffusers import (
        AutoencoderKL,
        ControlNetModel,
        DPMSolverMultistepScheduler,
        StableDiffusionPipeline,
    )
    from insightface.app import FaceAnalysis

    print("--- Initializing Pipelines (GPU is now available) ---")
    device = torch.device("cuda")  # assumes a CUDA device is attached (ZeroGPU)
    dtype = torch.float16

    # Face analysis (insightface buffalo_l models cached under CACHE_ROOT).
    if face_app is None:
        print("Initializing FaceAnalysis...")
        providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        analyser = FaceAnalysis(name="buffalo_l", root=str(CACHE_ROOT), providers=providers)
        analyser.prepare(ctx_id=0, det_size=(640, 640))
        face_app = analyser
        print("FaceAnalysis initialized.")

    # Main Stable Diffusion pipeline.
    if pipe is None:
        print("Loading ControlNet...")
        controlnet = ControlNetModel.from_pretrained("InstantX/InstantID", subfolder="ControlNetModel", torch_dtype=dtype)
        print("Loading StableDiffusionPipeline...")
        sd = StableDiffusionPipeline.from_single_file(
            BASE_CKPT, torch_dtype=dtype, safety_checker=None, use_safetensors=True, clip_skip=2
        )
        print("Moving pipeline to GPU...")
        sd.to(device)
        print("Loading VAE...")
        sd.vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype).to(device)
        # NOTE(review): this ControlNet is never moved to `device`, and a plain
        # StableDiffusionPipeline does not appear to consult a `controlnet`
        # attribute — confirm whether a ControlNet pipeline was intended.
        sd.controlnet = controlnet
        print("Configuring Scheduler...")
        sd.scheduler = DPMSolverMultistepScheduler.from_config(
            sd.scheduler.config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
        )
        print("Loading IP-Adapter and LoRA...")
        sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name=IP_BIN_FILE.name)
        sd.load_lora_weights(str(LORA_DIR), weight_name=LORA_FILE.name)
        sd.set_ip_adapter_scale(0.65)
        pipe = sd  # publish only once fully configured
        print("Main pipeline initialized.")

    # Optional upscaler: tried once; a failure disables upscaling for good.
    if not _UPSCALE_ATTEMPTED:
        _UPSCALE_ATTEMPTED = True
        print("Checking for Upscaler...")
        try:
            from basicsr.archs.rrdb_arch import RRDBNet
            from realesrgan import RealESRGAN
            # Keyword arguments: basicsr's RRDBNet takes `scale` as its third
            # positional parameter, so the old positional call passed it twice.
            rrdb = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=8)
            candidate = RealESRGAN(device, rrdb, scale=8)
            # NOTE(review): nothing in this file downloads this weight file.
            candidate.load_weights(str(UPSCALE_DIR / "RealESRGAN_x8plus.pth"))
            upsampler = candidate
            UPSCALE_OK = True
            print("Upscaler initialized successfully.")
        except Exception as e:
            UPSCALE_OK = False  # record the failure; upscaling stays disabled
            print(f"Real-ESRGAN disabled → {e}")
    print("--- All pipelines ready ---")
# 4. Core generation logic — prompt templates.
# BASE_PROMPT has one "{subject}" placeholder, filled with str.format();
# its lines are joined with newlines, the negative-prompt groups with ", ".
BASE_PROMPT = "\n".join((
    "(masterpiece:1.2), best quality, ultra-realistic, RAW photo, 8k,",
    "photo of {subject},",
    "cinematic lighting, golden hour, rim light, shallow depth of field,",
    "textured skin, high detail, shot on Canon EOS R5, 85 mm f/1.4, ISO 200,",
    "<lora:ip-adapter-faceid-plusv2_sd15_lora:0.65>, (face),",
    "(aesthetic:1.1), (cinematic:0.8)",
))
NEG_PROMPT = ", ".join((
    "ng_deepnegative_v1_75t, CyberRealistic_Negative-neg, UnrealisticDream",
    "(worst quality:2), (low quality:1.8), lowres, (jpeg artifacts:1.2)",
    "painting, sketch, illustration, drawing, cartoon, anime, cgi, render, 3d",
    "monochrome, grayscale, text, logo, watermark, signature, username",
    "(MajicNegative_V2:0.8), bad hands, extra digits, fused fingers, malformed limbs",
    "missing arms, missing legs, (badhandv4:0.7), BadNegAnatomyV1-neg, skin blemishes, acnes, age spot, glans",
))
# ZeroGPU entry point: the decorator attaches a GPU to each call for at most 60 s.
@spaces.GPU(duration=60)
def _generate_core(face_img, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
    """Generate one image conditioned on *face_img*; shared by the UI and the REST endpoint.

    Args:
        face_img: PIL image of the reference face, passed to the pipeline as
            ``ip_adapter_image`` (and ``image``).
        subject: Subject text inserted into BASE_PROMPT; blank or None falls
            back to "a beautiful 20yo woman".
        add_prompt: Optional extra positive-prompt text (appended after ", ").
        add_neg: Optional extra negative-prompt text (appended after ", ").
        cfg: Guidance scale.
        ip_scale: IP-Adapter scale, set on the shared pipeline for this call.
        steps: Requested step count; the pipeline runs ``steps + 5`` steps.
        w, h: Output width/height in pixels.
        upscale: Run the upscaler, if it loaded successfully.
        up_factor: ``outscale`` passed to the upscaler.
        progress: Gradio progress tracker.

    Returns:
        The generated (optionally upscaled) ``PIL.Image.Image``.
    """
    # Builds the pipelines on first use (no-op once they exist in this process).
    # NOTE(review): if ZeroGPU runs each call in a fresh worker process, the
    # globals set here may not persist between calls — confirm.
    initialize_pipelines()
    progress(0, desc="Generating image...")

    prompt = BASE_PROMPT.format(subject=((subject or "").strip() or "a beautiful 20yo woman"))
    if add_prompt:
        prompt += ", " + add_prompt
    neg = NEG_PROMPT + (", " + add_neg if add_neg else "")

    pipe.set_ip_adapter_scale(ip_scale)
    # NOTE(review): `image` / `controlnet_conditioning_scale` only have an effect
    # on a ControlNet pipeline; confirm they are used by StableDiffusionPipeline.
    # The +5 on the step count is kept from the original; its reason is not visible here.
    result = pipe(
        prompt=prompt,
        negative_prompt=neg,
        ip_adapter_image=face_img,
        image=face_img,
        controlnet_conditioning_scale=0.9,
        num_inference_steps=int(steps) + 5,
        guidance_scale=cfg,
        width=int(w),
        height=int(h),
    ).images[0]

    if upscale and UPSCALE_OK:
        progress(0.8, desc="Upscaling...")
        # RGB <-> BGR by reversing the channel axis; `cv2` was used here but
        # never imported, so this path used to raise NameError.
        bgr = np.ascontiguousarray(np.asarray(result)[:, :, ::-1])
        up, _ = upsampler.enhance(bgr, outscale=up_factor)
        result = Image.fromarray(np.ascontiguousarray(up[:, :, ::-1]))
    return result
# Wrapper called by the Gradio UI.
def generate_ui(face_np, subject, add_prompt, add_neg, cfg, ip_scale, steps, w, h, upscale, up_factor, progress=gr.Progress(track_tqdm=True)):
    """Check that a face image was uploaded, then delegate to ``_generate_core``.

    ``face_np`` is the NumPy array produced by the ``gr.Image(type="numpy")``
    input; it is converted to a PIL image before generation. All other
    arguments are forwarded unchanged.

    Raises:
        gr.Error: If no face image was provided.
    """
    if face_np is not None:
        return _generate_core(
            Image.fromarray(face_np),
            subject, add_prompt, add_neg,
            cfg, ip_scale, steps, w, h,
            upscale, up_factor, progress,
        )
    raise gr.Error("顔画像をアップロードしてください。")
# 5. Gradio UI definition
with gr.Blocks() as demo:
    gr.Markdown("# InstantID – Beautiful Realistic Asians v7 (ZeroGPU)")
    with gr.Row():
        # Left column: inputs and (collapsed) advanced settings.
        with gr.Column():
            face_in = gr.Image(type="numpy", label="顔写真")
            subj_in = gr.Textbox(placeholder="e.g. woman in black suit, smiling", label="被写体説明")
            add_in = gr.Textbox(label="追加プロンプト")
            addneg_in = gr.Textbox(label="追加ネガティブ")
            with gr.Accordion("詳細設定", open=False):
                ip_sld = gr.Slider(minimum=0, maximum=1.5, value=0.65, step=0.05, label="IP‑Adapter scale")
                cfg_sld = gr.Slider(minimum=1, maximum=15, value=6, step=0.5, label="CFG")
                step_sld = gr.Slider(minimum=10, maximum=50, value=20, step=1, label="Steps")
                w_sld = gr.Slider(minimum=512, maximum=1024, value=512, step=64, label="幅")
                h_sld = gr.Slider(minimum=512, maximum=1024, value=768, step=64, label="高さ")
                up_ck = gr.Checkbox(value=True, label="アップスケール")
                up_fac = gr.Slider(minimum=1, maximum=8, value=2, step=1, label="倍率")
            btn = gr.Button("生成", variant="primary")
        # Right column: generated image.
        with gr.Column():
            out_img = gr.Image(label="結果")

    # Enable the Gradio request queue (the original notes it is required for normal operation).
    demo.queue()
    # The input order must match generate_ui's positional parameters.
    btn.click(
        fn=generate_ui,
        inputs=[
            face_in, subj_in, add_in, addneg_in,
            cfg_sld, ip_sld, step_sld, w_sld, h_sld,
            up_ck, up_fac,
        ],
        outputs=out_img,
    )
# 6. FastAPI app; the Gradio UI is mounted onto it at "/".
app = FastAPI()


# REST endpoint sharing _generate_core with the Gradio UI. Without this route
# decorator the handler was never registered and therefore unreachable.
@app.post("/predict")
async def predict_endpoint(
    face_image: UploadFile = File(...),
    subject: str = Form("a woman"),
    add_prompt: str = Form(""),
    add_neg: str = Form(""),
    cfg: float = Form(6.0),
    ip_scale: float = Form(0.65),
    steps: int = Form(20),
    w: int = Form(512),
    h: int = Form(768),
    upscale: bool = Form(True),
    up_factor: float = Form(2.0)
):
    """Generate an image from a multipart upload and return it as base64 PNG.

    The form fields mirror the Gradio controls and are forwarded unchanged
    to ``_generate_core``.

    Returns:
        ``{"image_base64": <base64-encoded PNG>}``.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image;
            500 for any other failure (the traceback is printed).
    """
    from fastapi.concurrency import run_in_threadpool
    from PIL import UnidentifiedImageError

    try:
        contents = await face_image.read()
        try:
            # Normalise to RGB: uploads may be RGBA, palette or grayscale.
            pil_image = Image.open(io.BytesIO(contents)).convert("RGB")
        except UnidentifiedImageError as e:
            raise HTTPException(status_code=400, detail=f"invalid image: {e}") from e
        # Generation is long and blocking; run it in a worker thread so the
        # event loop (which also serves the mounted Gradio app) stays responsive.
        result_pil_image = await run_in_threadpool(
            _generate_core,
            pil_image, subject, add_prompt, add_neg, cfg, ip_scale,
            steps, w, h, upscale, up_factor,
        )
        buffered = io.BytesIO()
        result_pil_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image_base64": img_str}
    except HTTPException:
        raise  # keep deliberate status codes (e.g. 400) intact
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e)) from e
# Mount the Gradio Blocks at the root of the FastAPI app; routes registered on
# `app` before this call take precedence over the mounted UI.
# NOTE(review): nothing in this file starts a server; presumably the app is
# served externally (e.g. `uvicorn app:app`) — confirm.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")
print("Application startup script finished. Waiting for requests.")