"""
Gradio Web UI for Infinity-2B GGUF Image Generation

Provides an easy-to-use interface for generating images with the quantized model.
"""

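# Typical invocations ("web_ui.py" stands in for this file's actual name;
# see main() for all flags):
#   python web_ui.py --autoload
#   python web_ui.py --server-name 0.0.0.0 --server-port 7860 --share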
import os
import sys
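
# Set before any tokenizer work to silence the HuggingFace tokenizers
# fork-parallelism warning.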
os.environ["TOKENIZERS_PARALLELISM"] = "false"

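# Make a local Infinity checkout importable when it sits next to this script.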
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
INFINITY_PATH = os.path.join(SCRIPT_DIR, 'Infinity')
if os.path.exists(INFINITY_PATH):
    sys.path.insert(0, INFINITY_PATH)

import time
import argparse
import torch
import numpy as np
import gradio as gr
from PIL import Image

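# Reuse the model loaders and the sampling entry point from the standalone
# generation script.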
from generate_image_2b_q8_gguf import (
    load_t5_tokenizer_from_gguf,
    load_t5_encoder_from_gguf,
    load_infinity_from_gguf,
    load_vae,
    generate_image
)

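# Supported aspect-ratio templates and the per-preset scale schedules that
# drive Infinity's multi-scale generation.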
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w, h_div_w_templates


class ModelCache:
    """Holds loaded models so they are only loaded once per session."""

    def __init__(self):
        self.vae = None
        self.text_tokenizer = None
        self.text_encoder = None
        self.infinity_model = None
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.loaded = False

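# Process-wide cache: Gradio callbacks share this instance across requests.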
model_cache = ModelCache()


def load_models(infinity_gguf_path, t5_gguf_path, vae_path, pn='1M', progress=gr.Progress()):
    """
    Load all models, with progress reported back to the UI.
    """
    global model_cache

    if model_cache.loaded:
        return "✓ Models already loaded!"

    # Load the VAE first; the Infinity loader needs it below.
    progress(0, desc="Loading VAE...")
    model_cache.vae = load_vae(vae_path, vae_type=32, device=model_cache.device)

    progress(0.25, desc="Loading T5 Tokenizer...")
    model_cache.text_tokenizer = load_t5_tokenizer_from_gguf(t5_gguf_path)

    # The text encoder stays on the CPU regardless of the main device.
    progress(0.5, desc="Loading T5 Encoder (on CPU)...")
    model_cache.text_encoder = load_t5_encoder_from_gguf(t5_gguf_path, device='cpu')

    progress(0.75, desc="Loading Infinity-2B from GGUF...")
    model_cache.infinity_model = load_infinity_from_gguf(
        infinity_gguf_path,
        vae=model_cache.vae,
        device=model_cache.device,
        model_type='infinity_2b',
        text_channels=2048,
        pn=pn
    )

    model_cache.loaded = True
    progress(1.0, desc="Complete!")

    return "✓ All models loaded successfully!"


def generate_image_gradio(
    prompt,
    cfg_scale,
    tau,
    seed,
    aspect_ratio,
    pn,
    use_random_seed,
    progress=gr.Progress()
):
    """
    Generate an image, with progress reported back to the UI.
    """
    global model_cache

    if not model_cache.loaded:
        return None, "❌ Please load models first!"

    try:
        if use_random_seed:
            seed = int(np.random.randint(0, 2**31 - 1))

        if seed is not None:
            seed = int(seed)  # defensive cast: the RNG seeders want a plain int
            torch.manual_seed(seed)
            np.random.seed(seed)
            if model_cache.device == 'cuda':
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

        # Snap the requested ratio to the nearest supported template and take
        # the matching multi-scale schedule for the chosen resolution preset.
        h_div_w_template = h_div_w_templates[
            np.argmin(np.abs(h_div_w_templates - aspect_ratio))
        ]
        scale_schedule = dynamic_resolution_h_w[h_div_w_template][pn]['scales']
        scale_schedule = [(1, h, w) for (_, h, w) in scale_schedule]

        progress(0.1, desc="Encoding prompt...")
        start_time = time.time()

        progress(0.3, desc="Generating image (this may take a while)...")

        img_np = generate_image(
            model_cache.infinity_model,
            model_cache.vae,
            model_cache.text_tokenizer,
            model_cache.text_encoder,
            prompt,
            cfg_scale=cfg_scale,
            tau=tau,
            seed=seed,
            scale_schedule=scale_schedule,
            vae_type=32,
            device=model_cache.device
        )

        progress(0.9, desc="Converting to PIL Image...")

        # The sampler returns an HxWx3 BGR tensor; move it to the CPU and
        # reverse the channel order for PIL.
        img_np = img_np.cpu().numpy()
        img_rgb = img_np[:, :, ::-1]
        pil_image = Image.fromarray(img_rgb.astype(np.uint8))

        elapsed_time = time.time() - start_time

        h, w = img_np.shape[:2]

        info = f"""✓ Generation complete!

**Time**: {elapsed_time:.2f}s
**Resolution**: {w}x{h}
**Seed**: {seed}
**CFG Scale**: {cfg_scale}
**Tau**: {tau}
**Aspect Ratio**: {aspect_ratio:.2f}
**PN**: {pn}"""

        progress(1.0, desc="Done!")

        return pil_image, info

    except Exception as e:
        import traceback
        error_msg = f"❌ Error during generation:\n{str(e)}\n\n{traceback.format_exc()}"
        return None, error_msg


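# UI layout: model paths and a load button on top, then generation settings
# on the left and the output image on the right.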
def create_ui():
    """
    Create the Gradio UI.
    """
    with gr.Blocks(title="Infinity-2B GGUF Generator") as demo:
        gr.Markdown("# 🎨 Infinity-2B GGUF Image Generator")

        # Model paths and load controls.
        with gr.Row():
            infinity_gguf = gr.Textbox(
                label="Infinity-2B GGUF",
                value="infinity_2b_reg_Q8_0.gguf",
                scale=2
            )
            t5_gguf = gr.Textbox(
                label="T5 GGUF",
                value="flan-t5-xl-encoder-Q8_0.gguf",
                scale=2
            )
            vae_path = gr.Textbox(
                label="VAE Checkpoint",
                value="Infinity/infinity_vae_d32_reg.pth",
                scale=2
            )
            pn_load = gr.Dropdown(
                label="Resolution Preset",
                choices=['0.06M', '0.25M', '1M'],
                value='1M',
                scale=1
            )
            load_btn = gr.Button("🚀 Load Models", variant="primary", scale=1)

        load_status = gr.Textbox(label="Status", interactive=False, show_label=False)

        with gr.Row():
            # Left column: generation settings.
            with gr.Column(scale=1):
                gr.Markdown("### Generation Settings")

                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    value="an astronaut riding a horse on the moon",
                    lines=3
                )

                with gr.Row():
                    cfg_scale = gr.Slider(
                        minimum=1.0,
                        maximum=10.0,
                        value=3.0,
                        step=0.5,
                        label="CFG Scale",
                        info="Higher = stronger prompt adherence"
                    )
                    tau = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.5,
                        step=0.05,
                        label="Tau (Temperature)",
                        info="Lower = more deterministic"
                    )

                with gr.Row():
                    aspect_ratio = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Aspect Ratio (H/W)",
                        info="1.0 = square, >1.0 = portrait, <1.0 = landscape"
                    )
                    pn = gr.Dropdown(
                        label="Resolution Preset",
                        choices=['0.06M', '0.25M', '1M'],
                        value='1M',
                        info="Higher = better quality but slower"
                    )

                with gr.Row():
                    seed = gr.Number(
                        label="Seed",
                        value=42,
                        precision=0,
                        info="For reproducible results"
                    )
                    use_random_seed = gr.Checkbox(
                        label="Random Seed",
                        value=False,
                        info="Generate a random seed each time"
                    )

                generate_btn = gr.Button("✨ Generate Image", variant="primary", size="lg")

            # Right column: output.
            with gr.Column(scale=1):
                output_image = gr.Image(
                    label="Generated Image",
                    type="pil",
                    height=600
                )
                output_info = gr.Markdown("Generate an image to see details here.")

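        # Wire the buttons to the callbacks defined above.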
        load_btn.click(
            fn=load_models,
            inputs=[infinity_gguf, t5_gguf, vae_path, pn_load],
            outputs=[load_status]
        )

        generate_btn.click(
            fn=generate_image_gradio,
            inputs=[prompt, cfg_scale, tau, seed, aspect_ratio, pn, use_random_seed],
            outputs=[output_image, output_info]
        )

    return demo


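# Entry point: parse CLI flags, optionally pre-load the models, then serve.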
def main():
    parser = argparse.ArgumentParser(description='Infinity-2B GGUF Gradio Web UI')
    parser.add_argument('--share', action='store_true', help='Create a public share link')
    parser.add_argument('--server-name', type=str, default='127.0.0.1', help='Server name')
    parser.add_argument('--server-port', type=int, default=7860, help='Server port')
    parser.add_argument('--autoload', action='store_true', help='Auto-load models on startup')
    parser.add_argument('--infinity-gguf', type=str, default='infinity_2b_reg_Q8_0.gguf')
    parser.add_argument('--t5-gguf', type=str, default='flan-t5-xl-encoder-Q8_0.gguf')
    parser.add_argument('--vae-path', type=str, default='Infinity/infinity_vae_d32_reg.pth')

    args = parser.parse_args()

    if args.autoload:
        print("Auto-loading models...")
        load_models(args.infinity_gguf, args.t5_gguf, args.vae_path)

    demo = create_ui()

    print("\n" + "=" * 70)
    print("Starting Infinity-2B GGUF Web UI")
    print("=" * 70)
    print(f"Server: http://{args.server_name}:{args.server_port}")
    if args.share:
        print("Creating public share link...")
    print("=" * 70 + "\n")

    demo.launch(
        server_name=args.server_name,
        server_port=args.server_port,
        share=args.share,
        inbrowser=True
    )


if __name__ == '__main__':
    main()