import gc
from functools import partial
from os import getenv

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL.Image import Image

from utils import get_pytorch_device, get_torch_dtype, spaces_gpu


@spaces_gpu
def text_to_image(model: str, prompt: str) -> Image:
    """Generate an image from a text prompt using a diffusion model.

    This function uses a diffusion pipeline (e.g., Stable Diffusion, FLUX)
    to generate images from text prompts. The model is loaded, inference is
    performed, and then cleaned up to free GPU memory.

    Args:
        model: Hugging Face model ID to use for text-to-image generation.
        prompt: Text description of the desired image.

    Returns:
        PIL Image object representing the generated image.

    Note:
        - Uses safetensors for secure model loading.
        - Automatically selects the best available device (CUDA/XPU/MPS/CPU).
        - Cleans up model and GPU memory after inference.
    """
    pytorch_device = get_pytorch_device()
    dtype = get_torch_dtype()

    # NOTE(review): `dtype=` requires a recent diffusers release; older
    # versions spell this keyword `torch_dtype=` — confirm the pinned version.
    pipe = DiffusionPipeline.from_pretrained(
        model,
        use_safetensors=True,
        dtype=dtype,
    )
    pipe = pipe.to(pytorch_device)

    # Inference needs no gradients; torch.no_grad() avoids storing them and
    # significantly reduces peak memory during the generation pass.
    with torch.no_grad():
        result = pipe(prompt).images[0]

    # Clean up GPU memory. Order matters: drop our reference, run the garbage
    # collector FIRST so Python objects holding CUDA tensors are actually
    # destroyed, and only then ask PyTorch's caching allocator to release the
    # now-unreferenced blocks back to the driver. (The original flushed the
    # cache before collecting, which left collectable allocations cached.)
    del pipe
    gc.collect()
    if pytorch_device == "cuda":
        torch.cuda.empty_cache()

    return result


def create_text_to_image_tab(model: str):
    """Create the text-to-image generation tab in the Gradio interface.

    This function sets up all UI components for text-to-image generation,
    including input textbox, generate button, and output image display.

    Args:
        model: Hugging Face model ID to use for text-to-image generation.
    """
    gr.Markdown("Generate an image from a text prompt.")
    text_to_image_prompt = gr.Textbox(label="Prompt")
    text_to_image_generate_button = gr.Button("Generate")
    text_to_image_output = gr.Image(label="Image", type="pil")
    # Bind the model ID up front so the click handler only receives the prompt
    # from the UI; the button drives text_to_image and renders its PIL result.
    text_to_image_generate_button.click(
        fn=partial(text_to_image, model),
        inputs=text_to_image_prompt,
        outputs=text_to_image_output,
    )