"""
FlowInOne - HuggingFace Space Demo
Unifying Multimodal Generation as Image-In Image-Out Flow Matching
"""
import os
import sys
import traceback
import tempfile

import numpy as np
import torch
import einops
import gradio as gr
from PIL import Image
import spaces
from huggingface_hub import hf_hub_download, snapshot_download

# Directory containing this script; project-local packages are imported
# relative to it.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _APP_DIR)

# ── Paths & Setup ────────────────────────────────────────────────────────────
IMAGE_SIZE = 256
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# The default config is resolved against the script directory (like sys.path
# above) so the app also starts when launched from another working directory.
# An explicit CONFIG_PATH environment variable is used verbatim.
CONFIG_PATH = os.environ.get(
    "CONFIG_PATH",
    os.path.join(_APP_DIR, "configs", "flowinone_training_demo.py"),
)

print("Đang tải các models từ Hugging Face Hub...")
NNET_PATH = hf_hub_download(
    repo_id="CSU-JPG/FlowInOne", filename="flowinone_256px.pth"
)
JANUS_MODEL_PATH = snapshot_download(repo_id="deepseek-ai/Janus-Pro-1B")


# ── Helpers ──────────────────────────────────────────────────────────────────
def unpreprocess(tensor):
    """Clamp *tensor* to [-1, 1] and map it linearly onto [0, 1]."""
    tensor = (tensor.clamp(-1, 1) + 1.0) / 2.0
    return tensor


def center_crop_arr(pil_image, image_size):
    """Return a centered ``image_size`` x ``image_size`` crop as a NumPy array.

    The image is first halved with a BOX filter while its shorter side is at
    least twice ``image_size``, then resized with BICUBIC so the shorter side
    equals ``image_size``, and finally cropped around the center.
    """
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]


def pil_to_tensor(pil_image, image_size, device):
    """Convert a PIL image to a ``(1, 3, image_size, image_size)`` tensor.

    The image is converted to RGB, center-cropped, scaled from [0, 255] to
    [-1, 1] as float32 and moved to *device*.
    """
    arr = center_crop_arr(pil_image.convert("RGB"), image_size)
    arr = (arr / 127.5 - 1.0).astype(np.float32)
    t = torch.from_numpy(einops.rearrange(arr, "h w c -> c h w")).to(device)
    return t.unsqueeze(0)


def tensor_to_pil(arr):
    """Convert a ``(c, h, w)`` tensor with values in [-1, 1] to a PIL image.

    The tensor is cast to float32 before leaving torch, because bfloat16
    tensors cannot be converted to NumPy. Values are rounded to the nearest
    integer level instead of truncated, which avoids a systematic darkening
    of up to one level per pixel.
    """
    arr = unpreprocess(arr.detach().float())
    arr = (arr * 255.0).round().cpu().numpy().astype(np.uint8)
    arr = einops.rearrange(arr, "c h w -> h w c")
    return Image.fromarray(arr)
# ── Model loading ────────────────────────────────────────────────────────────
# Imported explicitly: `importlib.util` is a submodule and is not guaranteed
# to be reachable as an attribute of a bare `importlib` import.
import importlib.util

print("Loading config...")
import ml_collections

spec = importlib.util.spec_from_file_location("cfg", CONFIG_PATH)
cfg_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cfg_module)
_config = cfg_module.get_config()

print("Loading VAE...")
import libs.autoencoder as autoencoder_lib

vae_weight_path = hf_hub_download(
    repo_id="stabilityai/sd-vae-ft-mse-original",
    filename="vae-ft-mse-840000-ema-pruned.ckpt",
)
# Cache a bare state dict extracted from the checkpoint in the temp directory.
fixed_vae_path = os.path.join(tempfile.gettempdir(), "vae_extracted_v2.pth")
if not os.path.exists(fixed_vae_path):
    # NOTE(security): torch.load unpickles the downloaded checkpoint; only
    # load checkpoints from repositories you trust.
    ckpt = torch.load(vae_weight_path, map_location="cpu")
    real_state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    # Drop the EMA bookkeeping entries before handing the weights to the
    # autoencoder (presumably they are not model parameters — TODO confirm).
    for k in ["model_ema.decay", "model_ema.num_updates"]:
        real_state_dict.pop(k, None)
    # Write to a sibling file and rename it into place, so an interrupted
    # save never leaves a truncated cache that passes the exists() check.
    partial_vae_path = f"{fixed_vae_path}.{os.getpid()}.tmp"
    torch.save(real_state_dict, partial_vae_path)
    os.replace(partial_vae_path, fixed_vae_path)

_config.autoencoder.pretrained_path = fixed_vae_path
_autoencoder = autoencoder_lib.get_model(**_config.autoencoder).to(DEVICE).eval()

print("Loading NNet...")
import utils

_nnet = utils.get_nnet(**_config.nnet)
_nnet.load_state_dict(torch.load(NNET_PATH, map_location="cpu"))
_nnet.to(DEVICE).eval()

print("Loading Janus-Pro-1B...")
from libs.janus.models import MultiModalityCausalLM, VLChatProcessor
from transformers import AutoModelForCausalLM

_vl_chat_processor = VLChatProcessor.from_pretrained(JANUS_MODEL_PATH)
_vl_gpt = AutoModelForCausalLM.from_pretrained(
    JANUS_MODEL_PATH, trust_remote_code=True, use_safetensors=False
).half().to(DEVICE).eval()


# ── Inference ────────────────────────────────────────────────────────────────
@spaces.GPU
def run_inference(pil_input, text_prompt, cfg_scale, sample_steps, skip_cross_atten):
    """Run one FlowInOne edit and return ``(output_image, status_message)``.

    Args:
        pil_input: Input PIL image; it is center-cropped to ``IMAGE_SIZE``.
        text_prompt: Instruction text; ``None`` or empty becomes ``""``.
        cfg_scale: Guidance scale forwarded to the ODE solver.
        sample_steps: Number of solver steps.
        skip_cross_atten: Boolean wrapped in a one-element bool tensor and
            forwarded as ``use_cross_atten_mask``.

    Returns:
        ``(PIL.Image, success message)`` on success, or ``(None, error
        message with traceback)`` if any exception is raised. Exceptions are
        reported to the UI rather than propagated.
    """
    tmp_path = None
    try:
        from diffusion.flow_matching import ODEEulerFlowMatchingSolver

        # 1. Preprocess the image (use only the cropped image, no text is
        #    rendered onto it).
        input_tensor = pil_to_tensor(pil_input, IMAGE_SIZE, DEVICE)
        arr = center_crop_arr(pil_input.convert("RGB"), IMAGE_SIZE)
        cropped_pil = Image.fromarray(arr)

        # 2. Pass the text prompt straight to the VLM.
        # NOTE(review): no image placeholder token is visible in this prompt;
        # confirm the embedding helper inserts the image tokens itself.
        question = text_prompt.strip() if text_prompt else ""
        sft_format = _vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=[
                {"role": "<|User|>", "content": f"\n{question}"},
                {"role": "<|Assistant|>", "content": ""},
            ],
            sft_format=_vl_chat_processor.sft_format,
            system_prompt=_vl_chat_processor.system_prompt,
        )
        cached_input_ids = _vl_chat_processor.tokenizer.encode(sft_format)

        # The embedding helper takes file paths, so the crop goes through a
        # temporary PNG. The handle is closed before writing and the path is
        # recorded immediately so the `finally` below always cleans it up.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name
        cropped_pil.save(tmp_path)

        use_cross_atten_mask = torch.tensor(
            [skip_cross_atten], dtype=torch.bool, device=DEVICE
        )
        amp_enabled = DEVICE == "cuda"

        # Everything below is pure inference, so autograd stays disabled.
        with torch.no_grad():
            contexts, token_mask = utils.get_input_image_embeddings_and_masks(
                batch_input_images=[tmp_path],
                vl_chat_processor=_vl_chat_processor,
                vl_gpt=_vl_gpt,
                device=DEVICE,
                question=question,  # pass the question to the context generator
                num_image_tokens=576,
                output_tokens=576,
                accelerator=None,
                cached_input_ids=cached_input_ids,
            )

            input_moments = _autoencoder(input_tensor, fn="encode_moments")
            input_latent = _autoencoder.sample(input_moments)

            # Only the shape of this noise tensor is used below.
            z_gaussian = torch.randn(1, *_config.z_shape, device=DEVICE)

            with torch.autocast(device_type="cuda", enabled=amp_enabled):
                z_x0, _, _ = _nnet(
                    contexts,
                    text_encoder=True,
                    shape=z_gaussian.shape,
                    mask=token_mask,
                    use_cross_atten_mask=use_cross_atten_mask,
                )
                z_init = z_x0.reshape(z_gaussian.shape)

                ode_solver = ODEEulerFlowMatchingSolver(
                    _nnet,
                    bdv_model_fn=None,
                    step_size_type="step_in_dsigma",
                    guidance_scale=cfg_scale,
                )
                z, _ = ode_solver.sample(
                    x_T=z_init,
                    batch_size=1,
                    sample_steps=sample_steps,
                    unconditional_guidance_scale=cfg_scale,
                    has_null_indicator=hasattr(_config.nnet.model_args, "cfg_indicator"),
                    image_latent=input_latent,
                    use_cross_atten_mask=use_cross_atten_mask,
                )

            with torch.autocast(device_type="cuda", enabled=amp_enabled):
                output = _autoencoder.decode(z)

        output_pil = tensor_to_pil(output[0])
        return output_pil, "✅ Inference thành công!"

    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"❌ Lỗi: {str(e)}\n\n{traceback.format_exc()}"

    finally:
        # Best-effort cleanup on both the success and the failure path.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                pass


# ── Gradio UI ────────────────────────────────────────────────────────────────
def predict_wrapper(image_data, text_prompt, cfg_scale, sample_steps, skip_cross_atten):
    """Turn the ImageEditor payload into an RGB image and run inference.

    ``image_data`` is the editor dict; its ``"composite"`` entry is preferred
    and ``"background"`` is the fallback. Images with transparency are
    flattened onto a white background.

    Returns:
        The ``(image, status)`` pair from ``run_inference``, or ``(None,
        warning)`` when no image was provided.
    """
    if image_data is None or image_data.get("background") is None:
        return None, "⚠️ Vui lòng upload ảnh đầu vào."

    final_image = image_data.get("composite")
    if final_image is None:
        pil_input = image_data["background"].convert("RGB")
    elif final_image.mode in ("RGBA", "LA") or (
        final_image.mode == "P" and "transparency" in final_image.info
    ):
        # Normalise to RGBA first: "LA" and "P" images have fewer than four
        # bands, so indexing split()[3] on them would raise IndexError.
        rgba = final_image.convert("RGBA")
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.getchannel("A"))
        pil_input = background
    else:
        pil_input = final_image.convert("RGB")

    return run_inference(
        pil_input, text_prompt, float(cfg_scale), int(sample_steps), skip_cross_atten
    )


with gr.Blocks(title="FlowInOne Demo - Visual Editing") as demo:
    gr.Markdown(
        """
        # 🌊 FlowInOne Demo - Visual Editing
        **Unifying Multimodal Generation as Image-In Image-Out Flow Matching**

        *Lưu ý: Model ở độ phân giải 256px nên kết quả sẽ không sắc nét. Hãy dùng cọ đỏ bôi lên khu vực cần sửa.*
        """
    )

    with gr.Row():
        with gr.Column():
            input_img = gr.ImageEditor(
                type="pil",
                label="📥 Ảnh đầu vào (Tô cọ đỏ lên vùng cần sửa)",
                brush=gr.Brush(colors=["#FF0000", "#000000", "#FFFFFF", "#0000FF"]),
            )
            text_prompt = gr.Textbox(
                label="Text Prompt (Ví dụ: remove the dog)",
                lines=2,
            )

            with gr.Accordion("⚙️ Cài đặt nâng cao", open=False):
                cfg_scale = gr.Slider(
                    minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="CFG Scale"
                )
                sample_steps = gr.Slider(
                    minimum=10, maximum=50, value=20, step=5, label="Số bước sampling"
                )
                skip_cross_atten = gr.Checkbox(value=False, label="Skip Cross Attention")

            run_btn = gr.Button("🚀 Chạy Inference", variant="primary")

        with gr.Column():
            output_img = gr.Image(type="pil", label="📤 Ảnh đầu ra (256x256)")
            status_txt = gr.Textbox(label="Trạng thái", interactive=False)

    run_btn.click(
        fn=predict_wrapper,
        inputs=[input_img, text_prompt, cfg_scale, sample_steps, skip_cross_atten],
        outputs=[output_img, status_txt],
    )

if __name__ == "__main__":
    demo.launch()