Spaces:
Build error
Build error
| """ | |
| Gradio demo for text customization with Calligrapher (the reference is uploaded by the user). | |
| """ | |
| import os | |
| import json | |
| import gradio as gr | |
| import spaces | |
| import numpy as np | |
| from datetime import datetime | |
| import torch | |
| from PIL import Image | |
| from pipeline_calligrapher import CalligrapherPipeline | |
| from models.calligrapher import Calligrapher | |
| from models.transformer_flux_inpainting import FluxTransformer2DModel | |
| from utils import process_gradio_source, get_bbox_from_mask, crop_image_from_bb, \ | |
| resize_img_and_pad, generate_context_reference_image | |
| from huggingface_hub import snapshot_download | |
| # Global settings. | |
| with open(os.path.join(os.path.dirname(__file__), 'path_dict.json'), 'r') as f: | |
| path_dict = json.load(f) | |
| SAVE_DIR = path_dict['gradio_save_dir'] | |
| os.environ["GRADIO_TEMP_DIR"] = path_dict['gradio_temp_dir'] | |
| os.environ['TMPDIR'] = path_dict['gradio_temp_dir'] | |
| # Function of loading pre-trained models. | |
| def load_models(): | |
| base_model_path = snapshot_download("black-forest-labs/FLUX.1-Fill-dev") | |
| image_encoder_path = snapshot_download("google/siglip-so400m-patch14-384") | |
| calligrapher_path = snapshot_download("Calligrapher2025/Calligrapher",allow_patterns="*.bin") | |
| calligrapher_path+="/calligrapher.bin" | |
| transformer = FluxTransformer2DModel.from_pretrained(base_model_path, subfolder="transformer", | |
| torch_dtype=torch.bfloat16) | |
| pipe = CalligrapherPipeline.from_pretrained(base_model_path, transformer=transformer, | |
| torch_dtype=torch.bfloat16).to("cuda") | |
| model = Calligrapher(pipe, image_encoder_path, calligrapher_path, device="cuda", num_tokens=128) | |
| return model | |
| model = load_models() | |
| def process_and_generate(editor_component, reference_image, prompt, height, width, | |
| scale, steps=50, seed=42, use_context=True, num_images=1): | |
| print('Begin processing!') | |
| # Job directory. | |
| job_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") | |
| job_dir = os.path.join(SAVE_DIR, job_name) | |
| os.makedirs(job_dir, exist_ok=True) | |
| # Get source, mask, and cropped images from gr.ImageEditor. | |
| source_image, mask_image, cropped_image = process_gradio_source(editor_component) | |
| source_image.save(os.path.join(job_dir, 'source_image.png')) | |
| mask_image.save(os.path.join(job_dir, 'mask_image.png')) | |
| cropped_image.save(os.path.join(job_dir, 'cropped_image.png')) | |
| # Resize source and mask. | |
| source_image = source_image.resize((width, height)) | |
| mask_image = mask_image.resize((width, height), Image.NEAREST) | |
| mask_np = np.array(mask_image) | |
| mask_np[mask_np > 0] = 255 | |
| mask_image = Image.fromarray(mask_np.astype(np.uint8)) | |
| if reference_image is None: | |
| # If self-inpaint (no input ref): (1) get bounding box from the mask and (2) perform cropping to get the ref image. | |
| tl, br = get_bbox_from_mask(mask_image) | |
| # Convert irregularly shaped masks into rectangles. | |
| reference_image = crop_image_from_bb(source_image, tl, br) | |
| # Raw reference image before resizing. | |
| reference_image.save(os.path.join(job_dir, 'reference_image_raw.png')) | |
| reference_image_to_encoder = resize_img_and_pad(reference_image, target_size=(512, 512)) | |
| reference_image_to_encoder.save(os.path.join(job_dir, 'reference_to_encoder.png')) | |
| reference_context = generate_context_reference_image(reference_image, width) | |
| if use_context: | |
| # Concat the context on the top of the input masked image in the pixel space. | |
| source_with_context = Image.new(source_image.mode, (width, reference_context.size[1] + height)) | |
| source_with_context.paste(reference_context, (0, 0)) | |
| source_with_context.paste(source_image, (0, reference_context.size[1])) | |
| # Concat the zero mask on the top of the mask image. | |
| mask_with_context = Image.new(mask_image.mode, | |
| (mask_image.size[0], reference_context.size[1] + mask_image.size[0]), color=0) | |
| mask_with_context.paste(mask_image, (0, reference_context.size[1])) | |
| source_image = source_with_context | |
| mask_image = mask_with_context | |
| all_generated_images = [] | |
| for i in range(num_images): | |
| res = model.generate( | |
| image=source_image, | |
| mask_image=mask_image, | |
| ref_image=reference_image_to_encoder, | |
| prompt=prompt, | |
| scale=scale, | |
| num_inference_steps=steps, | |
| width=source_image.size[0], | |
| height=source_image.size[1], | |
| seed=seed + i, | |
| )[0] | |
| if use_context: | |
| res_vis = res.crop((0, reference_context.size[1], res.width, res.height)) # remove context | |
| mask_vis = mask_image.crop( | |
| (0, reference_context.size[1], mask_image.width, mask_image.height)) # remove context mask | |
| else: | |
| res_vis = res | |
| mask_vis = mask_image | |
| res_vis.save(os.path.join(job_dir, f'result_{i}.png')) | |
| all_generated_images.append((res_vis, f"Generating {i + 1} (Seed: {seed + i})")) | |
| return mask_vis, reference_image_to_encoder, all_generated_images | |
| # Main gradio codes. | |
| with gr.Blocks(theme="default", css=".image-editor img {max-width: 70%; height: 70%;}") as demo: | |
| gr.Markdown( | |
| """ | |
| # ๐๏ธ Calligrapher: Freestyle Text Image Customization | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=3): | |
| gr.Markdown("### ๐จ Image Editing Panel") | |
| editor_component = gr.ImageEditor( | |
| label="Upload or Draw", | |
| type="pil", | |
| brush=gr.Brush(colors=["#FFFFFF"], default_size=30, color_mode="fixed"), | |
| layers=True, | |
| interactive=True, | |
| ) | |
| gr.Markdown("### ๐ค Output Result") | |
| gallery = gr.Gallery(label="๐ผ๏ธ Result Gallery") | |
| gr.Markdown( | |
| """<br> | |
| ### โจUser Tips: | |
| 1. **Speed vs Quality Trade-off.** Use fewer steps (e.g., 10-step which takes ~4s/image on a single A6000 GPU) for faster generation, but quality may be lower. | |
| 2. **Inpaint Position Freedom.** Inpainting positions are flexible - they don't necessarily need to match the original text locations in the input image. | |
| 3. **Iterative Editing.** Drag outputs from the gallery to the Image Editing Panel (clean the Editing Panel first) for quick refinements. | |
| 4. **Mask Optimization.** Adjust mask size/aspect ratio to match your desired content. The model tends to fill the masks, and harmonizes the generation with background in terms of color and lighting. | |
| 5. **Reference Image Tip.** White-background references improve style consistency - the encoder also considers background context of the given reference image. | |
| 6. **Resolution Balance.** Very high-resolution generation sometimes triggers spelling errors. 512/768px are recommended considering the model is trained under the resolution of 512. | |
| """ | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown("### โ๏ธSettings") | |
| reference_image = gr.Image( | |
| label="๐งฉ Reference Image (skip this if self-reference)", | |
| sources=["upload"], | |
| type="pil", | |
| ) | |
| prompt = gr.Textbox( | |
| label="๐ Prompt", | |
| placeholder="The text is 'Image'...", | |
| value="The text is 'Image'." | |
| ) | |
| with gr.Accordion("๐ง Additional Settings", open=True): | |
| with gr.Row(): | |
| height = gr.Number(label="Height", value=512, precision=0) | |
| width = gr.Number(label="Width", value=512, precision=0) | |
| scale = gr.Slider(0.0, 2.0, 1.0, step=0.1, value=1.0, label="๐๏ธ Strength") | |
| steps = gr.Slider(1, 100, 50, step=1, label="๐ Steps") | |
| with gr.Row(): | |
| seed = gr.Number(label="๐ฒ Seed", value=56, precision=0) | |
| use_context = gr.Checkbox(value=True, label="๐ Use Context", interactive=True) | |
| num_images = gr.Slider(1, 16, 2, step=1, label="๐ผ๏ธ Sample Amount") | |
| run_btn = gr.Button("๐ Run", variant="primary") | |
| mask_output = gr.Image(label="๐ฉ Mask Demo") | |
| reference_demo = gr.Image(label="๐งฉ Reference Demo") | |
| # Run button event. | |
| run_btn.click( | |
| fn=process_and_generate, | |
| inputs=[ | |
| editor_component, | |
| reference_image, | |
| prompt, | |
| height, | |
| width, | |
| scale, | |
| steps, | |
| seed, | |
| use_context, | |
| num_images | |
| ], | |
| outputs=[ | |
| mask_output, | |
| reference_demo, | |
| gallery | |
| ] | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(share=True) | |