Spaces:
Running on Zero
| import spaces | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import cv2 | |
| import hashlib | |
| import os | |
| from PIL import Image | |
| from diffusers import FluxPipeline | |
| from insightface.app import FaceAnalysis | |
| from insightface.model_zoo import get_model | |
# --- Global configuration ---
# Hugging Face repository id of the diffusion model built by load_models_on_gpu().
MODEL_ID = "black-forest-labs/FLUX.1-dev"
# Access token taken from the environment (None when the variable is unset);
# it is forwarded as `token=` when the pipeline is downloaded.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Model handles start out empty and are filled in lazily on first use, so that
# heavy construction happens inside the request handler rather than at import.
face_app = swapper = pipe = None
def load_models_on_gpu():
    """Lazily build the face analyser, the face swapper and the FLUX pipeline.

    Each model is constructed only the first time it is missing and is cached
    in the module-level globals, so later calls are cheap no-ops. Call this
    from inside the GPU-allocated request handler.

    The insightface models (detector/attribute model and swapper) are created
    with ONNX Runtime's ``CPUExecutionProvider``. Only the FLUX pipeline is set
    up for GPU use, with model CPU offload enabled.

    If ``inswapper_128.onnx`` is not found (relative to the current working
    directory), ``swapper`` stays ``None``. Callers must then skip the face
    swap. A warning is logged so the degraded mode is visible in the logs
    instead of silent.
    """
    global face_app, swapper, pipe
    if face_app is None:
        face_app = FaceAnalysis(name="buffalo_l", providers=['CPUExecutionProvider'])
        face_app.prepare(ctx_id=0, det_size=(640, 640))
    if swapper is None:
        model_file = 'inswapper_128.onnx'
        if os.path.exists(model_file):
            swapper = get_model(model_file, providers=['CPUExecutionProvider'])
        else:
            # Previously this fell through silently and every request produced
            # an image with a random face; make the missing asset visible.
            import logging
            logging.getLogger(__name__).warning(
                "%s not found in %s; face swap will be skipped.",
                model_file,
                os.getcwd(),
            )
    if pipe is None:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            token=HF_TOKEN,
        )
        # Keep sub-models on the CPU and move each to the accelerator only
        # while it runs, trading speed for lower peak VRAM.
        pipe.enable_model_cpu_offload()
def upscale_image(image, scale=2, sigma=3.0, amount=0.5):
    """Enlarge *image* with Lanczos resampling, then sharpen it with an unsharp mask.

    The defaults reproduce the original fixed behaviour: 2x enlargement, a
    Gaussian blur with sigma 3, and weights 1.5 / -0.5.

    Args:
        image: PIL image to enlarge.
        scale: Enlargement factor applied to both width and height. Non-integer
            factors are rounded to whole pixels.
        sigma: Standard deviation of the Gaussian blur used to build the mask.
        amount: Sharpening strength. The output is
            ``(1 + amount) * upscaled - amount * blurred``.

    Returns:
        A new PIL image of size ``(round(w * scale), round(h * scale))``.
    """
    pixels = np.array(image)
    w, h = image.size
    # cv2.resize takes (width, height) and needs integer sizes.
    new_size = (int(round(w * scale)), int(round(h * scale)))
    upscaled = cv2.resize(pixels, new_size, interpolation=cv2.INTER_LANCZOS4)
    # ksize=(0, 0) lets OpenCV derive the kernel size from sigma.
    blurred = cv2.GaussianBlur(upscaled, (0, 0), sigma)
    # Unsharp mask: boost the image and subtract the blur. addWeighted
    # saturates, so uint8 results stay within [0, 255].
    sharpened = cv2.addWeighted(upscaled, 1.0 + amount, blurred, -amount, 0)
    return Image.fromarray(sharpened)
def _face_area(face):
    """Return the pixel area of an insightface detection's bounding box."""
    return (face.bbox[2] - face.bbox[0]) * (face.bbox[3] - face.bbox[1])


# ZeroGPU only attaches a GPU while a @spaces.GPU-decorated function runs.
# The CUDA generator and the FLUX pipeline below need it. The duration is a
# budget estimate: 28 FLUX steps with CPU offload, plus first-call model
# construction, may exceed the 60 s default. Tune as needed.
@spaces.GPU(duration=120)
def generate_vton_final(face_image, body_type, height_ft):
    """Generate a full-body photo of a person and paste the uploaded face onto it.

    Args:
        face_image: PIL image from the upload widget, or None if nothing was
            uploaded.
        body_type: Build descriptor inserted into the prompt (e.g. "slim").
        height_ft: Height in feet inserted into the prompt.

    Returns:
        ``(image, status)``: the upscaled PIL result and a status string. If no
        face was uploaded or detected, returns ``(None, error_message)``. The
        status notes when the face swap had to be skipped.
    """
    if face_image is None:
        return None, "Please upload a face image."

    # Build (or reuse) the models now that we are inside the GPU allocation.
    load_models_on_gpu()

    # 1. Face analysis. insightface/OpenCV expect BGR. Converting to RGB first
    #    guards against RGBA/greyscale uploads breaking COLOR_RGB2BGR.
    cv_img = cv2.cvtColor(np.array(face_image.convert("RGB")), cv2.COLOR_RGB2BGR)
    faces = face_app.get(cv_img)
    if not faces:
        return None, "No face detected in the upload."
    # Use the largest face as the subject, consistent with the target choice in step 4.
    source_face = max(faces, key=_face_area)
    gender = "man" if source_face.gender == 1 else "woman"

    # 2. Prompt and deterministic seed. The seed depends only on gender and
    #    build, so changing only the height reuses the same seed.
    profile_seed = int(hashlib.md5(f"{gender}-{body_type}".encode()).hexdigest(), 16) % (10**9)
    generator = torch.Generator("cuda").manual_seed(profile_seed)
    # Format the height to one decimal so slider float noise
    # (e.g. 5.800000000000001) does not leak into the prompt.
    prompt = (
        f"A full body 8k professional photo of a {gender}, {body_type} build, {height_ft:.1f}ft tall. "
        "Standing in a relaxed, natural pose, facing the camera. "
        "Wearing stylish casual clothing, clean studio background, sharp focus, cinematic lighting."
    )

    # 3. Generation.
    gen_img = pipe(
        prompt=prompt,
        height=1024, width=768,
        guidance_scale=3.5,
        num_inference_steps=28,
        generator=generator,
    ).images[0]

    # 4. Face swap onto the largest face in the generated image, when possible.
    swap_note = ""
    if swapper is None:
        swap_note = " | Face swap skipped: inswapper_128.onnx not loaded"
    else:
        res_cv = cv2.cvtColor(np.array(gen_img), cv2.COLOR_RGB2BGR)
        target_faces = face_app.get(res_cv)
        if target_faces:
            target_face = max(target_faces, key=_face_area)
            res_cv = swapper.get(res_cv, target_face, source_face, paste_back=True)
            gen_img = Image.fromarray(cv2.cvtColor(res_cv, cv2.COLOR_BGR2RGB))
        else:
            swap_note = " | Face swap skipped: no face found in generated image"

    # 5. HD upscale.
    return upscale_image(gen_img), f"Success | Seed: {profile_seed}{swap_note}"
# --- GRADIO INTERFACE ---
# Two-column layout: input controls on the left; generated image and status on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 💎 AI Virtual Model Engine")
    with gr.Row():
        with gr.Column():
            face_in = gr.Image(type="pil", label="Step 1: Upload Face")
            body_in = gr.Radio(
                choices=["slim", "muscular", "average"],
                value="average",
                label="Step 2: Body Build",
            )
            h_in = gr.Slider(
                minimum=4.5,
                maximum=7.0,
                value=5.8,
                step=0.1,
                label="Step 3: Height (ft)",
            )
            btn = gr.Button("Generate High-Res Model", variant="primary")
        with gr.Column():
            img_out = gr.Image(label="Final Result")
            status = gr.Textbox(label="Logs")
    # Inputs map positionally onto generate_vton_final(face_image, body_type, height_ft);
    # its (image, status) return value fills the two output components.
    btn.click(
        fn=generate_vton_final,
        inputs=[face_in, body_in, h_in],
        outputs=[img_out, status],
    )

demo.launch()