# FLowInOne_demo / app.py — Hugging Face Space file by oedevs
# (commit 252ccd9, "Update app.py", verified).
"""
FlowInOne - HuggingFace Space Demo
Unifying Multimodal Generation as Image-In Image-Out Flow Matching
"""
import os
import sys
import traceback
import tempfile
import numpy as np
import torch
import einops
import gradio as gr
from PIL import Image
import spaces
from huggingface_hub import hf_hub_download, snapshot_download
# Directory containing this file. Used both for local imports (libs/, utils,
# diffusion/) and for the default config path, so the app behaves the same
# regardless of the process's current working directory.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _APP_DIR)
# ── Paths & Setup ────────────────────────────────────────────────────────────
IMAGE_SIZE = 256  # side length (px) of the square crop fed to the models
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# An explicit CONFIG_PATH environment variable still takes precedence; the
# default is anchored to the app directory instead of the CWD.
CONFIG_PATH = os.environ.get(
    "CONFIG_PATH", os.path.join(_APP_DIR, "configs", "flowinone_training_demo.py")
)
# Message: "Loading models from the Hugging Face Hub..."
print("Đang tải các models từ Hugging Face Hub...")
# FlowInOne network weights (single file) and the full Janus-Pro-1B snapshot
# (processor + model files) are fetched from the Hub at import time.
NNET_PATH = hf_hub_download(repo_id="CSU-JPG/FlowInOne", filename="flowinone_256px.pth")
JANUS_MODEL_PATH = snapshot_download(repo_id="deepseek-ai/Janus-Pro-1B")
# ── Helpers ───────────────────────────────────────────────────────────────────
def unpreprocess(tensor):
    """Map a tensor from the model range [-1, 1] back to [0, 1].

    Inputs are clamped to [-1, 1] first, so the result always lies in [0, 1].
    """
    clipped = tensor.clamp(-1, 1)
    return clipped.add(1.0).div(2.0)
def center_crop_arr(pil_image, image_size):
    """Resize a PIL image and cut out its central ``image_size`` square.

    Steps:
      1. While the shorter side is at least ``2 * image_size``, halve both
         sides with a box filter.
      2. Bicubically rescale so the shorter side equals ``image_size``.
      3. Return the centered ``image_size x image_size`` window of the pixel
         array produced by ``np.array``.
    """
    img = pil_image
    # Coarse downscale by repeated halving.
    while min(img.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in img.size)
        img = img.resize(halved, resample=Image.BOX)

    # Fine rescale so the short side lands on image_size.
    ratio = image_size / min(img.size)
    target = tuple(round(side * ratio) for side in img.size)
    pixels = np.array(img.resize(target, resample=Image.BICUBIC))

    # Centered crop window (rows first, then columns).
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    return pixels[top:top + image_size, left:left + image_size]
def pil_to_tensor(pil_image, image_size, device):
    """Convert a PIL image into a 1 x 3 x H x W float32 tensor in [-1, 1].

    The image is converted to RGB, center-cropped to ``image_size`` with
    ``center_crop_arr``, rescaled from [0, 255] to [-1, 1], moved to
    ``device`` and given a leading batch dimension of size 1.
    """
    rgb = pil_image.convert("RGB")
    hwc = center_crop_arr(rgb, image_size)
    normalized = (hwc / 127.5 - 1.0).astype(np.float32)
    # Channels-last (H, W, C) -> channels-first (C, H, W).
    chw = torch.from_numpy(normalized).permute(2, 0, 1)
    return chw.to(device)[None]
def tensor_to_pil(arr):
    """Convert a C x H x W tensor in [-1, 1] into an 8-bit PIL image.

    Values are clamped and mapped to [0, 1] by ``unpreprocess``, scaled by
    255 and truncated (not rounded) to uint8. The channel axis is then moved
    last before building the image.
    """
    unit = unpreprocess(arr).detach().cpu().numpy()
    as_uint8 = (unit * 255).astype(np.uint8)
    return Image.fromarray(np.transpose(as_uint8, (1, 2, 0)))
# ── Model loading ─────────────────────────────────────────────────────────────
print("Loading config...")
import importlib.util
# NOTE(review): not referenced in this file; presumably needed by the config
# module itself — confirm before removing.
import ml_collections
# Load the config module straight from CONFIG_PATH. `importlib.util` is a
# submodule: it is only guaranteed to be an attribute of `importlib` once it
# has been imported explicitly (the previous `__import__("importlib").util`
# relied on some other code having imported it first).
spec = importlib.util.spec_from_file_location("cfg", CONFIG_PATH)
if spec is None or spec.loader is None:
    # spec_from_file_location returns None when it cannot pick a loader
    # (e.g. an unexpected file suffix); fail with a clear message.
    raise ImportError(f"Cannot load config module from {CONFIG_PATH!r}")
cfg_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(cfg_module)
_config = cfg_module.get_config()
print("Loading VAE...")
import libs.autoencoder as autoencoder_lib
vae_weight_path = hf_hub_download(repo_id="stabilityai/sd-vae-ft-mse-original", filename="vae-ft-mse-840000-ema-pruned.ckpt")
# The raw checkpoint is converted once into a plain state dict and cached in
# the system temp dir, so later startups skip the conversion. Conversion
# unwraps a top-level "state_dict" key (if present) and drops the two EMA
# bookkeeping entries.
fixed_vae_path = os.path.join(tempfile.gettempdir(), "vae_extracted_v2.pth")
if not os.path.exists(fixed_vae_path):
    # NOTE(review): torch.load unpickles the file; acceptable only because it
    # comes from a fixed, trusted Hub repository.
    ckpt = torch.load(vae_weight_path, map_location="cpu")
    real_state_dict = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
    for k in ["model_ema.decay", "model_ema.num_updates"]:
        real_state_dict.pop(k, None)
    # Save to a process-unique sibling file, then rename atomically. A crash
    # mid-save must never leave a truncated file at fixed_vae_path, because
    # the exists() check above would otherwise reuse it on every restart.
    _partial_vae_path = f"{fixed_vae_path}.{os.getpid()}.tmp"
    torch.save(real_state_dict, _partial_vae_path)
    os.replace(_partial_vae_path, fixed_vae_path)
_config.autoencoder.pretrained_path = fixed_vae_path
_autoencoder = autoencoder_lib.get_model(**_config.autoencoder).to(DEVICE).eval()
print("Loading NNet...")
import utils
# Build the FlowInOne network from the config's `nnet` section and load the
# weights downloaded from CSU-JPG/FlowInOne (NNET_PATH). Tensors are loaded
# onto the CPU first (map_location="cpu") and the module is moved to DEVICE.
# NOTE(review): torch.load unpickles the checkpoint — fine for a trusted repo.
_nnet = utils.get_nnet(**_config.nnet)
_nnet.load_state_dict(torch.load(NNET_PATH, map_location="cpu"))
# Assuming get_nnet returns a torch.nn.Module: .to() and .eval() modify the
# module in place, so discarding the return value is intentional.
_nnet.to(DEVICE).eval()
print("Loading Janus-Pro-1B...")
# NOTE(review): MultiModalityCausalLM is not referenced below; presumably the
# import is kept for its side effect of registering the Janus architecture
# with transformers' Auto* classes — confirm before removing.
from libs.janus.models import MultiModalityCausalLM, VLChatProcessor
from transformers import AutoModelForCausalLM
# Chat processor: supplies the tokenizer, SFT template and system prompt used
# by run_inference to build the VLM prompt.
_vl_chat_processor = VLChatProcessor.from_pretrained(JANUS_MODEL_PATH)
# The Janus VLM that produces the conditioning embeddings. use_safetensors=False
# loads the pickled weight files from the snapshot. Weights are cast to
# float16 unconditionally.
# NOTE(review): on the CPU fallback (DEVICE == "cpu") fp16 inference may be
# unsupported or very slow for some ops — confirm whether float32 is needed
# there (and whether the embedding helper expects half-precision inputs).
_vl_gpt = AutoModelForCausalLM.from_pretrained(
    JANUS_MODEL_PATH, trust_remote_code=True, use_safetensors=False
).half().to(DEVICE).eval()
# ── Inference ─────────────────────────────────────────────────────────────────
@spaces.GPU
def run_inference(pil_input, text_prompt, cfg_scale, sample_steps, skip_cross_atten):
    """Run one FlowInOne image-in / image-out edit on a single image.

    Args:
        pil_input: PIL image to edit. It is converted to RGB and center-cropped
            to IMAGE_SIZE x IMAGE_SIZE.
        text_prompt: Instruction sent to the Janus VLM together with the image.
            ``None`` or blank is treated as an empty prompt.
        cfg_scale: Guidance scale passed to the ODE solver (both as
            ``guidance_scale`` and ``unconditional_guidance_scale``).
        sample_steps: Number of sampling steps for the Euler flow solver.
        skip_cross_atten: Stored in the boolean ``use_cross_atten_mask``
            tensor handed to the network and the solver.

    Returns:
        ``(output_pil, status_message)``. On any exception the image is
        ``None`` and the message carries the error plus traceback, so the UI
        receives a status string instead of an exception.
    """
    tmp_path = None
    try:
        from diffusion.flow_matching import ODEEulerFlowMatchingSolver
        import utils

        # 1. Preprocess the image: the VAE receives a normalized tensor; the
        #    VLM receives the same center crop as a plain image (no text is
        #    drawn onto it).
        input_tensor = pil_to_tensor(pil_input, IMAGE_SIZE, DEVICE)
        arr = center_crop_arr(pil_input.convert("RGB"), IMAGE_SIZE)
        cropped_pil = Image.fromarray(arr)

        # 2. Send the text prompt directly to the VLM alongside the image.
        question = text_prompt.strip() if text_prompt else ""
        sft_format = _vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=[
                {"role": "<|User|>", "content": f"<image_placeholder>\n{question}"},
                {"role": "<|Assistant|>", "content": ""},
            ],
            sft_format=_vl_chat_processor.sft_format,
            system_prompt=_vl_chat_processor.system_prompt,
        )
        cached_input_ids = _vl_chat_processor.tokenizer.encode(sft_format)

        # The embedding helper takes image *paths*, so the crop round-trips
        # through a temporary PNG. The handle is closed before writing, and
        # the file is removed in `finally` on success and failure alike.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name
        cropped_pil.save(tmp_path)

        use_cross_atten_mask = torch.tensor([skip_cross_atten], dtype=torch.bool, device=DEVICE)
        # Mixed precision on CUDA only; on CPU the autocast context is disabled.
        use_amp = DEVICE == "cuda"

        with torch.no_grad():
            # Computing the VLM conditioning under no_grad avoids recording an
            # autograd graph through the Janus forward pass.
            contexts, token_mask = utils.get_input_image_embeddings_and_masks(
                batch_input_images=[tmp_path],
                vl_chat_processor=_vl_chat_processor,
                vl_gpt=_vl_gpt,
                device=DEVICE,
                question=question,  # the prompt is forwarded to the context generator
                num_image_tokens=576,
                output_tokens=576,
                accelerator=None,
                cached_input_ids=cached_input_ids,
            )

            # 3. Encode the input image into the VAE latent space.
            input_moments = _autoencoder(input_tensor, fn="encode_moments")
            input_latent = _autoencoder.sample(input_moments)

            # 4. The network's text-encoder pass turns the VLM contexts into
            #    z_x0, which becomes the solver's starting point x_T. Only the
            #    *shape* of z_gaussian is used below.
            z_gaussian = torch.randn(1, *_config.z_shape, device=DEVICE)
            with torch.autocast(device_type=DEVICE, enabled=use_amp):
                z_x0, _, _ = _nnet(
                    contexts, text_encoder=True, shape=z_gaussian.shape,
                    mask=token_mask, use_cross_atten_mask=use_cross_atten_mask,
                )
            z_init = z_x0.reshape(z_gaussian.shape)

            # 5. Integrate with the Euler flow-matching solver, passing the
            #    input latent as `image_latent`.
            ode_solver = ODEEulerFlowMatchingSolver(
                _nnet, bdv_model_fn=None, step_size_type="step_in_dsigma", guidance_scale=cfg_scale,
            )
            z, _ = ode_solver.sample(
                x_T=z_init, batch_size=1, sample_steps=sample_steps,
                unconditional_guidance_scale=cfg_scale,
                has_null_indicator=hasattr(_config.nnet.model_args, "cfg_indicator"),
                image_latent=input_latent, use_cross_atten_mask=use_cross_atten_mask,
            )

            # 6. Decode the final latent back to pixels.
            with torch.autocast(device_type=DEVICE, enabled=use_amp):
                output = _autoencoder.decode(z)
            output_pil = tensor_to_pil(output[0])
        return output_pil, "✅ Inference thành công!"
    except Exception as e:
        # Top-level UI boundary: report the failure as a status message.
        return None, f"❌ Lỗi: {str(e)}\n\n{traceback.format_exc()}"
    finally:
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; never let it mask the real result.
                pass
# ── Gradio UI ─────────────────────────────────────────────────────────────────
def _flatten_to_rgb(image):
    """Return ``image`` as RGB, compositing any transparency onto white.

    Images with an alpha band ("RGBA", "LA") or palette transparency ("P"
    with a "transparency" entry) are converted to RGBA first, so the alpha
    band can be fetched by name. "LA" has only 2 bands and "P" only 1, so
    indexing ``split()[3]`` would fail for them. All other modes are simply
    converted to RGB.
    """
    has_alpha = image.mode in ("RGBA", "LA") or (
        image.mode == "P" and "transparency" in image.info
    )
    if not has_alpha:
        return image.convert("RGB")
    rgba = image.convert("RGBA")
    background = Image.new("RGB", rgba.size, (255, 255, 255))
    background.paste(rgba, mask=rgba.getchannel("A"))
    return background


def predict_wrapper(image_data, text_prompt, cfg_scale, sample_steps, skip_cross_atten):
    """Gradio click handler: flatten the ImageEditor value and run inference.

    Args:
        image_data: value of the ``gr.ImageEditor`` component (type="pil");
            its "background" and "composite" entries are read.
        text_prompt: editing instruction, forwarded unchanged.
        cfg_scale: slider value, cast to ``float``.
        sample_steps: slider value, cast to ``int``.
        skip_cross_atten: checkbox value, forwarded unchanged.

    Returns:
        ``(None, warning)`` when no background image was uploaded, otherwise
        the ``(image, status)`` pair returned by ``run_inference``.
    """
    if image_data is None or image_data.get("background") is None:
        return None, "⚠️ Vui lòng upload ảnh đầu vào."
    # Prefer the composite (background plus brush strokes) so the painted
    # region reaches the model; fall back to the bare background. Both are
    # flattened the same way so transparency is always rendered as white.
    source = image_data.get("composite")
    if source is None:
        source = image_data["background"]
    pil_input = _flatten_to_rgb(source)
    return run_inference(pil_input, text_prompt, float(cfg_scale), int(sample_steps), skip_cross_atten)
# Gradio UI. Components are created in display order inside the Blocks
# context: the left column holds inputs and controls, the right column holds
# outputs. UI strings are Vietnamese.
with gr.Blocks(title="FlowInOne Demo - Visual Editing") as demo:
    # Header. The note says: the model runs at 256px so results will not be
    # sharp; use the red brush to paint over the area to edit.
    gr.Markdown(
        """
# 🌊 FlowInOne Demo - Visual Editing
**Unifying Multimodal Generation as Image-In Image-Out Flow Matching**
*Lưu ý: Model ở độ phân giải 256px nên kết quả sẽ không sắc nét. Hãy dùng cọ đỏ bôi lên khu vực cần sửa.*
"""
    )
    with gr.Row():
        with gr.Column():
            # Image editor; predict_wrapper reads its "background" and
            # "composite" entries. Label: "Input image (paint red over the
            # area to edit)". Red (#FF0000) is the first brush colour.
            input_img = gr.ImageEditor(
                type="pil",
                label="📥 Ảnh đầu vào (Tô cọ đỏ lên vùng cần sửa)",
                brush=gr.Brush(colors=["#FF0000", "#000000", "#FFFFFF", "#0000FF"])
            )
            # Label: "Text Prompt (e.g. remove the dog)".
            text_prompt = gr.Textbox(
                label="Text Prompt (Ví dụ: remove the dog)", lines=2
            )
            # "Advanced settings", collapsed by default. Slider values are
            # cast to float/int in predict_wrapper.
            with gr.Accordion("⚙️ Cài đặt nâng cao", open=False):
                cfg_scale = gr.Slider(minimum=1.0, maximum=15.0, value=7.5, step=0.5, label="CFG Scale")
                # Label: "Number of sampling steps".
                sample_steps = gr.Slider(minimum=10, maximum=50, value=20, step=5, label="Số bước sampling")
                skip_cross_atten = gr.Checkbox(value=False, label="Skip Cross Attention")
            # Button: "Run Inference".
            run_btn = gr.Button("🚀 Chạy Inference", variant="primary")
        with gr.Column():
            # Labels: "Output image (256x256)" and "Status".
            output_img = gr.Image(type="pil", label="📤 Ảnh đầu ra (256x256)")
            status_txt = gr.Textbox(label="Trạng thái", interactive=False)
    # The input order must match predict_wrapper's positional parameters; its
    # (image, status) return fills the two outputs.
    run_btn.click(
        fn=predict_wrapper,
        inputs=[input_img, text_prompt, cfg_scale, sample_steps, skip_cross_atten],
        outputs=[output_img, status_txt],
    )
if __name__ == "__main__":
    demo.launch()