import os
import sys

# Make the repository root importable when running this script directly.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import argparse
import math
import time

import torch
from PIL import Image
from transformers import AutoTokenizer

from model import LLaDAForMultiModalGeneration
from generators.parallel_generator import generate_ti2ti
from utils.generation_utils import setup_seed
from utils.image_utils import (
    preprocess_image,
    decode_vq_to_image,
    calculate_vq_params,
    generate_crop_size_list,
    var_center_crop,
    add_break_line,
    encode_img_with_breaks,
    encode_img_with_paint,
)
from utils.prompt_utils import generate_text_image_to_text_image_prompt

# Special token ids used to lay out the multimodal sequence.
SPECIAL_TOKENS = {
    "mask_token": 126336,
    "newline_token": 126084,
    "image_token_offset": 126356,
    "answer_start": 126354,
    "answer_end": 126355,
    "boi": 126349,
    "eoi": 126350,
    "uncondition": 126351,
}

SYSTEM_PROMPT = (
    "Generate an image applying the following editing instruction based on the original image."
)


def cosine_schedule(t):
    # Cosine noise schedule: cos(t * pi / 2) falls from 1 at t=0 to 0 at t=1.
    return torch.cos(t * math.pi / 2)


def main():
    parser = argparse.ArgumentParser(description="Text+Image to Text+Image inference (TI2TI)")
    parser.add_argument("--checkpoint", type=str, required=True, help="Fine-tuned checkpoint path")
    parser.add_argument("--prompt", type=str, required=True, help="Text prompt for editing")
    parser.add_argument("--image_path", type=str, required=True, help="Input image path")
    parser.add_argument("--height", type=int, default=512, help="Output image height")
    parser.add_argument("--width", type=int, default=512, help="Output image width")
    parser.add_argument("--timesteps", type=int, default=64, help="Number of diffusion timesteps")
    parser.add_argument("--text_steps", type=int, default=256, help="Number of text generation steps")
    parser.add_argument("--text_gen_length", type=int, default=256, help="Maximum text generation length")
    parser.add_argument("--text_block_length", type=int, default=32, help="Text generation block length")
    parser.add_argument("--cfg_scale", type=float, default=2.5, help="CFG scale for text")
    parser.add_argument("--cfg_img", type=float, default=4.0, help="CFG scale for image")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature")
    parser.add_argument("--text_temperature", type=float, default=0.7, help="Text generation temperature")
    parser.add_argument("--seed", type=int, default=0, help="Random seed (0 disables seeding)")
    parser.add_argument("--vae_ckpt", type=str, required=True, help="VAE checkpoint path")
    parser.add_argument("--output_dir", type=str, default="results_ti2ti", help="Output directory")
    parser.add_argument("--remasking", type=str, default="low_confidence",
                        choices=["low_confidence", "random"], help="Remasking strategy")
    parser.add_argument("--painting_mode", type=str, default=None, help="If set, use painting-mode encoding")
    parser.add_argument("--mask_h_ratio", type=float, default=0.5, help="Mask height ratio for painting mode")
    parser.add_argument("--mask_w_ratio", type=float, default=0.5, help="Mask width ratio for painting mode")
    parser.add_argument("--debug_tokens", action="store_true",
                        help="Print token debug info to verify sequence layout")
    args = parser.parse_args()

    MASK = SPECIAL_TOKENS["mask_token"]
    NEW_LINE = SPECIAL_TOKENS["newline_token"]
    BOA = SPECIAL_TOKENS["answer_start"]
    EOA = SPECIAL_TOKENS["answer_end"]
    BOI = SPECIAL_TOKENS["boi"]
    EOI = SPECIAL_TOKENS["eoi"]

    if args.seed != 0:
        setup_seed(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Loading model from {args.checkpoint}...")
    tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True)
    model = LLaDAForMultiModalGeneration.from_pretrained(
        args.checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    config = model.config
    text_vocab_size = getattr(config, "text_vocab_size", 126356)
    codebook_size = getattr(config, "codebook_size", 8192)
    print(f"Vocabulary config: text_vocab_size={text_vocab_size}, codebook_size={codebook_size}")

    print(f"Loading VQ-VAE from {args.vae_ckpt}...")
    from diffusers import VQModel  # deferred import: only needed at inference time

    vqvae = VQModel.from_pretrained(args.vae_ckpt, subfolder="vqvae").to(device)
    # Spatial downsampling factor of the VQ-VAE (a factor of 2 per encoder block).
    vae_scale = 2 ** (len(vqvae.config.block_out_channels) - 1)

    prompt_text = args.prompt
    input_image_path = args.image_path

    print(f"\n{'=' * 80}")
    print("TI2TI Generation")
    print(f"{'=' * 80}")
    print(f"Input image: {input_image_path}")
    print(f"Prompt: {prompt_text}")
    print(f"Output size: {args.height}x{args.width}")
    print(f"{'=' * 80}\n")

    input_prompt, uncon_text = generate_text_image_to_text_image_prompt(prompt_text, SYSTEM_PROMPT)
    print("Conditioning prompt:\n", input_prompt)
    if args.debug_tokens:
        print("Unconditional text prompt (first 200 chars):", uncon_text[:200])

    prompt_ids = tokenizer(input_prompt)["input_ids"]
    uncon_text_ids = tokenizer(uncon_text)["input_ids"]

    # Crop the input image to the nearest supported size, then encode it to VQ
    # tokens (with newline breaks) for conditioning.
    img = Image.open(input_image_path).convert("RGB")
    crop_size_list = generate_crop_size_list((512 // 32) ** 2, 32)
    img = var_center_crop(img, crop_size_list=crop_size_list)
    input_image_width, input_image_height = img.size

    print("Encoding input image for conditioning...")
    input_img_token = encode_img_with_breaks(img, vqvae)

    # Conditioning sequence: the image tokens are spliced in before the final
    # prompt token. The two unconditional variants each drop one modality for
    # classifier-free guidance.
    con_input_list = prompt_ids[:-1] + input_img_token + prompt_ids[-1:]
    uncon_input_text = uncon_text_ids[:-1] + input_img_token + uncon_text_ids[-1:]  # image kept, text dropped
    uncon_input_image = prompt_ids  # text kept, image dropped

    output_image_height = args.height
    output_image_width = args.width
    seq_len, newline_every, token_grid_height, token_grid_width = calculate_vq_params(
        output_image_height, output_image_width, vae_scale
    )

    text_mask_tokens = [MASK] * args.text_gen_length
    if args.painting_mode:
        # Painting mode keeps part of the input image and masks only a region.
        img_mask_token, img_vis = encode_img_with_paint(
            img,
            vqvae=vqvae,
            mask_h_ratio=args.mask_h_ratio,
            mask_w_ratio=args.mask_w_ratio,
            mask_mode=args.painting_mode,
        )
    else:
        img_mask_token = add_break_line(
            [MASK] * seq_len, token_grid_height, token_grid_width, new_number=NEW_LINE
        )

    # Ids appended after the answer (typically empty for an empty string).
    end_token_ids = tokenizer("", add_special_tokens=False).input_ids

    # Answer layout: <BOA> <BOI> image tokens <EOI> text tokens <end>.
    pred_token = [BOA] + [BOI] + img_mask_token + [EOI] + text_mask_tokens + end_token_ids
    code_start = len(con_input_list)        # first generated position (<BOA>)
    image_start = len(con_input_list) + 2   # first image token (skips <BOA> <BOI>)
    image_end = image_start + len(img_mask_token)  # position of <EOI>
    text_start = image_end + 1              # first text token (skips <EOI>)
    text_end = text_start + args.text_gen_length

    full_input_ids = con_input_list + pred_token
    con_input = torch.tensor(full_input_ids, device=device).unsqueeze(0)
    uncon_input_text = torch.tensor(uncon_input_text, device=device).unsqueeze(0)
    uncon_input_image = torch.tensor(uncon_input_image, device=device).unsqueeze(0)

    start_time = time.time()
    if args.seed != 0:
        generator = torch.Generator(device=device).manual_seed(args.seed)
    else:
        generator = None

    output_tokens, generated_text = generate_ti2ti(
        model=model,
        input_ids=con_input,
        text_start=text_start,
        text_end=text_end,
        image_start=image_start,
        seq_len=seq_len,
        newline_every=newline_every,
        text_steps=args.text_steps,
        text_gen_length=args.text_gen_length,
        text_block_length=args.text_block_length,
        timesteps=args.timesteps,
        temperature=args.temperature,
        text_temperature=args.text_temperature,
        cfg_scale=args.cfg_scale,
        cfg_img=args.cfg_img,
        uncon_text=uncon_input_text,
        uncon_image=uncon_input_image,
        tokenizer=tokenizer,
        remasking=args.remasking,
        noise_schedule=cosine_schedule,
        generator=generator,
        text_vocab_size=text_vocab_size,
        codebook_size=codebook_size,
    )
    elapsed_time = time.time() - start_time

    print(f"\n{'=' * 80}")
    print("Generated thinking/text output:")
    print(f"{'=' * 80}")
    print(generated_text)
    print(f"{'=' * 80}\n")

    print(f"Converting {len(output_tokens)} VQ tokens to tensor...")
    output_tokens_tensor = torch.tensor(output_tokens, dtype=torch.long, device=device).unsqueeze(0)
    print(f"VQ tokens range: [{min(output_tokens)}, {max(output_tokens)}]")

    # Build an output filename from the first ten prompt words, keeping only
    # alphanumerics, underscores, and hyphens.
    words = (prompt_text or "").split()
    filename = "_".join(words[:10])
    filename = "".join(c for c in filename if c.isalnum() or c in ("_", "-"))
    filename = f"{filename}_{output_image_height}x{output_image_width}_t{args.timesteps}_cfg{args.cfg_scale}_ti2ti.png"
    save_path = os.path.join(args.output_dir, filename)

    print("Decoding image...")
    out_img = decode_vq_to_image(
        output_tokens_tensor,
        save_path,
        vae_ckpt=args.vae_ckpt,
        image_height=output_image_height,
        image_width=output_image_width,
        vqvae=vqvae,
    )

    # Save the input and output side by side for quick visual comparison.
    w1, h1 = img.size
    w2, h2 = out_img.size
    canvas = Image.new("RGB", (w1 + w2, max(h1, h2)), "white")
    canvas.paste(img, (0, 0))
    canvas.paste(out_img, (w1, 0))
    concat_path = save_path.replace(".png", "_concat.png")
    canvas.save(concat_path)

    text_path = save_path.replace(".png", "_thinking.txt")
    with open(text_path, "w", encoding="utf-8") as f:
        f.write(f"{generated_text}\n")

    print(f"\n[✓] Image saved to: {concat_path}")
    print(f"[✓] Text saved to: {text_path}")
    print(f"[✓] Total time: {elapsed_time:.2f}s")


if __name__ == "__main__":
    main()
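# Example invocation (a sketch: the script filename and the checkpoint/image
# paths below are illustrative, not part of the repository; all flags match
# the argparse definitions above):
#
#   python inference_ti2ti.py \
#       --checkpoint /path/to/finetuned_llada \
#       --vae_ckpt /path/to/vqvae_repo \
#       --prompt "Replace the sky with a sunset" \
#       --image_path ./input.png \
#       --height 512 --width 512 \
#       --timesteps 64 --cfg_scale 2.5 --cfg_img 4.0 --seed 42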