# Spaces: Running on Zero
import random

# NOTE(review): third-party imports keep their original relative order; in
# particular `spaces` stays ahead of `torch`. Presumably ZeroGPU relies on
# that ordering -- confirm before re-sorting.
import gradio as gr
import spaces
import torch
from diffusers import HiDreamImagePipeline
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
import numpy as np
# Data type used for every set of weights loaded below.
dtype = torch.bfloat16
# Use CPU for model loading to avoid CUDA initialization at import time.
device = "cpu"

# Hub repository ids, named once so that the error messages below always
# report the repository that is actually being loaded.
LLAMA_MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"
HIDREAM_MODEL_ID = "HiDream-ai/HiDream-I1-Dev"

# Load the Llama tokenizer and text encoder; both are handed to the pipeline
# as its `tokenizer_4` / `text_encoder_4` components.
try:
    tokenizer_4 = PreTrainedTokenizerFast.from_pretrained(LLAMA_MODEL_ID)
    text_encoder_4 = LlamaForCausalLM.from_pretrained(
        LLAMA_MODEL_ID,
        output_hidden_states=True,
        output_attentions=True,
        torch_dtype=dtype,
    ).to(device)
except Exception as e:
    # Add an access hint, and chain the cause so the original traceback
    # is preserved.
    raise RuntimeError(
        f"Failed to load Llama model: {e}. Ensure you have access to "
        f"'{LLAMA_MODEL_ID}' and are logged in via `huggingface-cli login`."
    ) from e

# Load the HiDreamImagePipeline around the text encoder loaded above.
try:
    pipe = HiDreamImagePipeline.from_pretrained(
        HIDREAM_MODEL_ID,
        tokenizer_4=tokenizer_4,
        text_encoder_4=text_encoder_4,
        torch_dtype=dtype,
    ).to(device)
    # Offload to CPU; the pipeline then manages GPU placement automatically.
    pipe.enable_model_cpu_offload()
except Exception as e:
    # The message names HIDREAM_MODEL_ID, i.e. the repository really requested
    # (the previous text pointed at a different "-Full" checkpoint).
    raise RuntimeError(
        f"Failed to load HiDreamImagePipeline: {e}. Ensure you have access to "
        f"'{HIDREAM_MODEL_ID}'."
    ) from e
# Upper bounds shared by the seed logic and the UI sliders.
# Largest value of a signed 32-bit integer; used as the maximum seed.
MAX_SEED = np.iinfo(np.int32).max
# Largest width/height (in pixels) offered by the size sliders.
MAX_IMAGE_SIZE = 2048
# Inference function with GPU access.
# `spaces.GPU` is what requests a GPU for the duration of the call on a
# ZeroGPU Space; the body below needs one because it creates a "cuda" generator.
@spaces.GPU
def infer(
    prompt,
    negative_prompt="",
    seed=42,
    randomize_seed=False,
    width=1024,
    height=1024,
    num_inference_steps=28,
    guidance_scale=3.5,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the module-level ``pipe``.

    Args:
        prompt: Text prompt passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.
        seed: Seed for the random generator; ignored when
            ``randomize_seed`` is true.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        num_inference_steps: Number of inference steps, forwarded to the
            pipeline.
        guidance_scale: Guidance scale, forwarded to the pipeline.
        progress: Gradio progress tracker. It is not referenced in the body;
            presumably Gradio uses it to mirror tqdm progress -- confirm
            against the Gradio docs.

    Returns:
        A ``(image, seed)`` tuple: the first generated PIL image and the
        seed that was actually used.
    """
    try:
        if randomize_seed:
            seed = random.randint(0, MAX_SEED)
        # UI/API clients may deliver numeric inputs as floats, while
        # `manual_seed` only accepts integers.
        seed = int(seed)
        generator = torch.Generator("cuda").manual_seed(seed)
        # Generate the image; offloading handles device placement.
        result = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=int(height),
            width=int(width),
            num_inference_steps=int(num_inference_steps),
            guidance_scale=guidance_scale,
            generator=generator,
            output_type="pil",
        )
        return result.images[0], seed
    finally:
        # Clear cached GPU memory whether generation succeeded or failed.
        torch.cuda.empty_cache()
# Example inputs offered in the UI; each row is [prompt, negative prompt].
examples = [
    ["A cat holding a sign that says \"Hi-Dreams.ai\".", ""],
    ["A futuristic cityscape with flying cars.", "blurry, low quality"],
    ["A serene landscape with mountains and a lake.", ""],
]
# CSS styling injected into the Gradio page: centres the main column and
# themes the generate button.
# NOTE(review): the hover transform was malformed (`translateY2px)`); it is
# repaired with the missing parenthesis only. If the intended effect was a
# lift on hover, the value should be negative -- confirm.
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
.generate-btn {
    background: linear-gradient(90deg, #4B79A1 0%, #283E51 100%) !important;
    border: none !important;
    color: white !important;
}
.generate-btn:hover {
    transform: translateY(2px);
    box-shadow: 0 5px 15px rgba(0,0,0,0.2);
}
"""
# Create the Gradio interface.
# NOTE(review): the nesting below is reconstructed from the order of the
# `with` statements; spots where the original placement is ambiguous are
# marked individually.
with gr.Blocks(css=css) as app:
    gr.HTML("<center><h1>HiDreamImage Generator</h1></center>")
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            # Left column: prompts, advanced settings and the generate button.
            with gr.Column():
                with gr.Row():
                    text_prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter a prompt here",
                        lines=3,
                        elem_id="prompt-text-input"
                    )
                with gr.Row():
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt",
                        placeholder="Enter what to avoid (optional)",
                        lines=2
                    )
                with gr.Row():
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            width = gr.Slider(
                                label="Width",
                                value=1024,
                                minimum=64,
                                maximum=MAX_IMAGE_SIZE,
                                step=8
                            )
                            height = gr.Slider(
                                label="Height",
                                value=1024,
                                minimum=64,
                                maximum=MAX_IMAGE_SIZE,
                                step=8
                            )
                        with gr.Row():
                            steps = gr.Slider(
                                label="Inference Steps",
                                value=28,
                                minimum=1,
                                maximum=100,
                                step=1
                            )
                            cfg = gr.Slider(
                                label="Guidance Scale",
                                value=3.5,
                                minimum=1,
                                maximum=20,
                                step=0.5
                            )
                        with gr.Row():
                            seed = gr.Slider(
                                label="Seed",
                                value=42,
                                minimum=0,
                                maximum=MAX_SEED,
                                step=1
                            )
                            randomize_seed = gr.Checkbox(
                                label="Randomize Seed",
                                value=True
                            )
                with gr.Row():
                    text_button = gr.Button(
                        "✨ Generate Image",
                        variant='primary',
                        elem_classes=["generate-btn"]
                    )
            # Right column: generated image and the seed that produced it.
            with gr.Column():
                with gr.Row():
                    image_output = gr.Image(
                        type="pil",
                        label="Generated Image",
                        elem_id="gallery"
                    )
                # NOTE(review): placed under the image row; it may originally
                # have been inside that row -- confirm.
                seed_output = gr.Textbox(
                    label="Seed Used",
                    interactive=False
                )
        # NOTE(review): assumed to sit below the two-column row, inside the
        # main container -- confirm.
        with gr.Column():
            gr.Examples(
                examples=examples,
                inputs=[text_prompt, negative_prompt],
            )
    # Connect the button and textbox submit to the inference function.
    # The input order matches the positional parameters of `infer`.
    gr.on(
        triggers=[text_button.click, text_prompt.submit],
        fn=infer,
        inputs=[text_prompt, negative_prompt, seed, randomize_seed, width, height, steps, cfg],
        outputs=[image_output, seed_output]
    )

# Start the app, requesting a public share link.
app.launch(share=True)