Spaces:
Sleeping
Sleeping
| """ | |
| FlowInOne - HuggingFace Space Demo | |
| Unifying Multimodal Generation as Image-In Image-Out Flow Matching | |
| """ | |
| import os | |
| import sys | |
| import traceback | |
| import tempfile | |
| import numpy as np | |
| import torch | |
| import einops | |
| import gradio as gr | |
| from PIL import Image | |
| import spaces | |
| from huggingface_hub import hf_hub_download, snapshot_download | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
# ── Paths & Setup ────────────────────────────────────────────────────────────
# Side length, in pixels, that run_inference passes to the cropping helpers.
IMAGE_SIZE = 256
# Use the GPU when torch reports CUDA as available, otherwise the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Location of the config file; overridable via the CONFIG_PATH environment variable.
CONFIG_PATH = os.environ.get("CONFIG_PATH", "configs/flowinone_training_demo.py")
print("Đang tải các models từ Hugging Face Hub...")
# Both Hub calls run at import time; NNET_PATH is later handed to torch.load and
# JANUS_MODEL_PATH to the from_pretrained loaders.
NNET_PATH = hf_hub_download(repo_id="CSU-JPG/FlowInOne", filename="flowinone_256px.pth")
JANUS_MODEL_PATH = snapshot_download(repo_id="deepseek-ai/Janus-Pro-1B")
| # ── Helpers ─────────────────────────────────────────────────────────────────── | |
def unpreprocess(tensor):
    """Map a tensor from the [-1, 1] range to [0, 1].

    Values outside [-1, 1] are clamped to that range before rescaling, so the
    result always lies in [0, 1].
    """
    clipped = tensor.clamp(-1, 1)
    return (clipped + 1.0) / 2.0
def center_crop_arr(pil_image, image_size):
    """Resize a PIL image and return its central square as a NumPy array.

    Args:
        pil_image: Source PIL image.
        image_size: Side length of the square crop, in pixels.

    Returns:
        The ``image_size`` x ``image_size`` window cut from the centre of the
        resized image, as produced by ``np.array`` on the PIL image.
    """
    # Halve repeatedly with a box filter while the shorter side is still at
    # least twice the target.
    while min(*pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.BOX)

    # Final bicubic resize so that the shorter side equals image_size.
    ratio = image_size / min(*pil_image.size)
    target = tuple(round(side * ratio) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.BICUBIC)

    pixels = np.array(pil_image)
    height, width = pixels.shape[0], pixels.shape[1]
    top = (height - image_size) // 2
    left = (width - image_size) // 2
    return pixels[top: top + image_size, left: left + image_size]
def pil_to_tensor(pil_image, image_size, device):
    """Turn a PIL image into a batched, normalised tensor on ``device``.

    The image is converted to RGB, center-cropped with ``center_crop_arr``,
    rescaled via ``x / 127.5 - 1`` as float32, reordered from (h, w, c) to
    (c, h, w), and given a leading batch axis of size 1.
    """
    rgb = pil_image.convert("RGB")
    hwc = center_crop_arr(rgb, image_size)
    scaled = (hwc / 127.5 - 1.0).astype(np.float32)
    chw = einops.rearrange(scaled, "h w c -> c h w")
    return torch.from_numpy(chw).to(device).unsqueeze(0)
def tensor_to_pil(arr):
    """Convert a (c, h, w) tensor into a PIL image.

    The tensor is passed through ``unpreprocess``, moved to the CPU, scaled by
    255 and cast to uint8 (the cast truncates rather than rounds), then
    reordered to (h, w, c) for ``Image.fromarray``.
    """
    unit_range = unpreprocess(arr)
    pixels = unit_range.detach().cpu().numpy()
    pixels = (pixels * 255).astype(np.uint8)
    return Image.fromarray(einops.rearrange(pixels, "c h w -> h w c"))
# ── Model loading ─────────────────────────────────────────────────────────────
# Everything in this section runs once at import time and defines the
# module-level _config, _autoencoder, _nnet, _vl_chat_processor and _vl_gpt
# objects that run_inference reads.
print("Loading config...")
# Import the submodule explicitly: `__import__("importlib").util` only works
# when something else has already imported importlib.util.
import importlib.util

import ml_collections

# Execute the config file as a throwaway module and take its get_config() result.
spec = importlib.util.spec_from_file_location("cfg", CONFIG_PATH)
cfg_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cfg_module)
_config = cfg_module.get_config()

print("Loading VAE...")
import libs.autoencoder as autoencoder_lib

vae_weight_path = hf_hub_download(
    repo_id="stabilityai/sd-vae-ft-mse-original",
    filename="vae-ft-mse-840000-ema-pruned.ckpt",
)
# A cleaned copy of the state dict is cached in the temp directory so the
# extraction below is skipped whenever the file already exists.
fixed_vae_path = os.path.join(tempfile.gettempdir(), "vae_extracted_v2.pth")
if not os.path.exists(fixed_vae_path):
    ckpt = torch.load(vae_weight_path, map_location="cpu")
    # Some checkpoints nest the weights under "state_dict"; others are the dict itself.
    real_state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    # Drop the EMA bookkeeping entries before saving the extracted weights.
    for k in ["model_ema.decay", "model_ema.num_updates"]:
        real_state_dict.pop(k, None)
    # Save under a temporary name and rename into place, so an interrupted
    # write can never leave a truncated file at fixed_vae_path (which the
    # existence check above would then trust on every later start).
    partial_vae_path = f"{fixed_vae_path}.{os.getpid()}.tmp"
    torch.save(real_state_dict, partial_vae_path)
    os.replace(partial_vae_path, fixed_vae_path)
_config.autoencoder.pretrained_path = fixed_vae_path
_autoencoder = autoencoder_lib.get_model(**_config.autoencoder).to(DEVICE).eval()

print("Loading NNet...")
import utils

_nnet = utils.get_nnet(**_config.nnet)
_nnet.load_state_dict(torch.load(NNET_PATH, map_location="cpu"))
_nnet.to(DEVICE).eval()

print("Loading Janus-Pro-1B...")
# NOTE(review): MultiModalityCausalLM is not referenced below; presumably the
# import is needed for its side effects (model registration) — confirm before removing.
from libs.janus.models import MultiModalityCausalLM, VLChatProcessor
from transformers import AutoModelForCausalLM

_vl_chat_processor = VLChatProcessor.from_pretrained(JANUS_MODEL_PATH)
# Load the language model, cast to half precision, move to DEVICE, set eval mode.
_vl_gpt = AutoModelForCausalLM.from_pretrained(
    JANUS_MODEL_PATH, trust_remote_code=True, use_safetensors=False
).half().to(DEVICE).eval()
| # ── Inference ───────────────────────────────────────────────────────────────── | |
def run_inference(pil_input, text_prompt, cfg_scale, sample_steps, skip_cross_atten):
    """Run one sampling pass for an input image and text prompt.

    Args:
        pil_input: PIL image; converted to RGB and center-cropped to
            IMAGE_SIZE x IMAGE_SIZE before use.
        text_prompt: Instruction text; ``None`` or an empty value becomes "".
        cfg_scale: Guidance scale forwarded to the ODE solver.
        sample_steps: Number of sampling steps forwarded to the ODE solver.
        skip_cross_atten: Boolean forwarded to the network and solver as
            ``use_cross_atten_mask``.

    Returns:
        ``(image, status)``: the decoded PIL image and a success message, or
        ``(None, message)`` carrying the error text and traceback when any
        step raises. Exceptions are not propagated to the caller.
    """
    tmp_path = None
    try:
        from diffusion.flow_matching import ODEEulerFlowMatchingSolver
        import utils

        # 1. Preprocess the image (only the cropped image is used; no text is
        #    drawn onto it).
        input_tensor = pil_to_tensor(pil_input, IMAGE_SIZE, DEVICE)
        arr = center_crop_arr(pil_input.convert("RGB"), IMAGE_SIZE)
        cropped_pil = Image.fromarray(arr)

        # 2. Pass the text prompt straight to the VLM as part of the chat turn.
        question = text_prompt.strip() if text_prompt else ""
        sft_format = _vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=[
                {"role": "<|User|>", "content": f"<image_placeholder>\n{question}"},
                {"role": "<|Assistant|>", "content": ""},
            ],
            sft_format=_vl_chat_processor.sft_format,
            system_prompt=_vl_chat_processor.system_prompt,
        )
        cached_input_ids = _vl_chat_processor.tokenizer.encode(sft_format)

        # The embedding helper takes image paths, so the crop goes through a
        # temporary PNG. Writing through the open handle (instead of reopening
        # by name) also works on platforms that forbid a second open.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name
            cropped_pil.save(tmp, format="PNG")

        # NOTE(review): the "skip" checkbox value is passed unchanged as
        # `use_cross_atten_mask`; confirm the polarity matches the model's expectation.
        use_cross_atten_mask = torch.tensor([skip_cross_atten], dtype=torch.bool, device=DEVICE)

        # Nothing here needs gradients, so the whole pipeline runs under no_grad.
        with torch.no_grad():
            contexts, token_mask = utils.get_input_image_embeddings_and_masks(
                batch_input_images=[tmp_path],
                vl_chat_processor=_vl_chat_processor,
                vl_gpt=_vl_gpt,
                device=DEVICE,
                question=question,  # pass the question to the context generator
                num_image_tokens=576,
                output_tokens=576,
                accelerator=None,
                cached_input_ids=cached_input_ids,
            )

            input_moments = _autoencoder(input_tensor, fn="encode_moments")
            input_latent = _autoencoder.sample(input_moments)

            # Only the shape of this noise tensor is used below.
            z_gaussian = torch.randn(1, *_config.z_shape, device=DEVICE)
            with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
                z_x0, _, _ = _nnet(
                    contexts, text_encoder=True, shape=z_gaussian.shape,
                    mask=token_mask, use_cross_atten_mask=use_cross_atten_mask,
                )
            z_init = z_x0.reshape(z_gaussian.shape)

            ode_solver = ODEEulerFlowMatchingSolver(
                _nnet, bdv_model_fn=None, step_size_type="step_in_dsigma", guidance_scale=cfg_scale,
            )
            z, _ = ode_solver.sample(
                x_T=z_init, batch_size=1, sample_steps=sample_steps,
                unconditional_guidance_scale=cfg_scale,
                has_null_indicator=hasattr(_config.nnet.model_args, "cfg_indicator"),
                image_latent=input_latent, use_cross_atten_mask=use_cross_atten_mask,
            )

            with torch.cuda.amp.autocast(enabled=(DEVICE == "cuda")):
                output = _autoencoder.decode(z)
            output_pil = tensor_to_pil(output[0])

        return output_pil, "✅ Inference thành công!"
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        return None, f"❌ Lỗi: {str(e)}\n\n{traceback.format_exc()}"
    finally:
        # Remove the temporary PNG on both success and failure paths.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the result.
                pass
| # ── Gradio UI ───────────────────────────────────────────────────────────────── | |
def predict_wrapper(image_data, text_prompt, cfg_scale, sample_steps, skip_cross_atten):
    """Validate the editor payload, flatten it to RGB, and run inference.

    Args:
        image_data: Mapping from the image editor; the "background" and
            "composite" entries are read as PIL images (or ``None``).
        text_prompt: Prompt text forwarded to ``run_inference``.
        cfg_scale: Guidance scale; coerced to ``float``.
        sample_steps: Number of sampling steps; coerced to ``int``.
        skip_cross_atten: Flag forwarded unchanged to ``run_inference``.

    Returns:
        The ``(image, status)`` pair from ``run_inference``, or
        ``(None, warning)`` when no background image was provided.
    """
    if image_data is None or image_data.get("background") is None:
        return None, "⚠️ Vui lòng upload ảnh đầu vào."

    final_image = image_data.get("composite")
    if final_image is None:
        pil_input = image_data["background"].convert("RGB")
    elif final_image.mode in ("RGBA", "LA") or (
        final_image.mode == "P" and "transparency" in final_image.info
    ):
        # Normalise to RGBA first: "LA" has two bands and "P" has one, so
        # indexing band 3 on the unconverted image raises IndexError.
        rgba = final_image.convert("RGBA")
        # Flatten transparency onto a white canvas using the alpha band as mask.
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.split()[3])
        pil_input = background
    else:
        pil_input = final_image.convert("RGB")

    return run_inference(pil_input, text_prompt, float(cfg_scale), int(sample_steps), skip_cross_atten)
# Build the Gradio page: inputs and settings in the left column, the result
# image and status box in the right column.
with gr.Blocks(title="FlowInOne Demo - Visual Editing") as demo:
    # The Markdown text is kept flush-left so its rendering does not depend on
    # the component dedenting the string.
    gr.Markdown(
        """
# 🌊 FlowInOne Demo - Visual Editing
**Unifying Multimodal Generation as Image-In Image-Out Flow Matching**
*Lưu ý: Model ở độ phân giải 256px nên kết quả sẽ không sắc nét. Hãy dùng cọ đỏ bôi lên khu vực cần sửa.*
"""
    )
    with gr.Row():
        with gr.Column():
            # Image editor configured with type="pil"; predict_wrapper reads the
            # "background" and "composite" entries of its value.
            input_img = gr.ImageEditor(
                type="pil",
                label="📥 Ảnh đầu vào (Tô cọ đỏ lên vùng cần sửa)",
                brush=gr.Brush(colors=["#FF0000", "#000000", "#FFFFFF", "#0000FF"])
            )
            text_prompt = gr.Textbox(
                label="Text Prompt (Ví dụ: remove the dog)", lines=2
            )
            # Advanced settings, collapsed by default (open=False).
            with gr.Accordion("⚙️ Cài đặt nâng cao", open=False):
                cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="CFG Scale")
                sample_steps = gr.Slider(minimum=10, maximum=50, value=20, step=5, label="Số bước sampling")
                skip_cross_atten = gr.Checkbox(value=False, label="Skip Cross Attention")
            # NOTE(review): nesting reconstructed from a de-indented source; the
            # button is assumed to sit in the column, outside the accordion — confirm.
            run_btn = gr.Button("🚀 Chạy Inference", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="📤 Ảnh đầu ra (256x256)")
            status_txt = gr.Textbox(label="Trạng thái", interactive=False)
    # Clicking the button calls predict_wrapper; its (image, status) pair fills
    # the two output components in order.
    run_btn.click(
        fn=predict_wrapper,
        inputs=[input_img, text_prompt, cfg_scale, sample_steps, skip_cross_atten],
        outputs=[output_img, status_txt],
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()