"""Gradio front end for Kandinsky 5.0 Lite T2I (SFT) text-to-image generation."""

import gradio as gr
import torch

from model_handler import ModelHandler
from utils import get_random_seed

# Initialize the model handler at import time so the model is loaded once,
# when the app starts, rather than on the first request.
model_handler = ModelHandler()


def generate(
    prompt,
    negative_prompt,
    width,
    height,
    steps,
    guidance_scale,
    seed,
    progress=gr.Progress(),
):
    """Run text-to-image inference through ``model_handler.infer``.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        width: Output width from the slider; coerced to ``int``.
        height: Output height from the slider; coerced to ``int``.
        steps: Number of inference steps from the slider; coerced to ``int``.
        guidance_scale: Guidance scale from the slider; coerced to ``float``.
        seed: Seed from the number box. A negative value or an empty field
            (``None``) means "pick a random seed via ``get_random_seed``".
        progress: Gradio progress tracker, forwarded as ``progress_callback``.

    Returns:
        A ``(image, seed)`` tuple. ``image`` is whatever ``model_handler.infer``
        returns, and ``seed`` is the integer seed actually used.

    Raises:
        gr.Error: If inference fails; Gradio shows the message in the UI.
    """
    # A cleared gr.Number field arrives as None; treat it like -1 instead of
    # crashing on ``None < 0`` outside the error handler below.
    if seed is None or seed < 0:
        seed = get_random_seed()
    # Gradio number/slider components can deliver floats (e.g. 25.0); sizes,
    # step counts and seeds must be integers for the pipeline.
    seed = int(seed)

    try:
        image = model_handler.infer(
            prompt=prompt,
            negative_prompt=negative_prompt,
            width=int(width),
            height=int(height),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance_scale),
            seed=seed,
            progress_callback=progress,
        )
    except gr.Error:
        # Already a user-facing error; re-raise it without a second prefix.
        raise
    except Exception as e:
        raise gr.Error(f"Generation failed: {e}") from e
    return image, seed


# CSS for custom styling
css = """
.container { max-width: 900px; margin: auto; }
.header { text-align: center; margin-bottom: 20px; }
.header h1 { font-size: 2.5rem; font-weight: bold; color: #333; }
.header p { font-size: 1.1rem; color: #666; }
.footer { text-align: center; margin-top: 20px; font-size: 0.9rem; }
"""

# Create the Gradio Interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    with gr.Column(elem_classes="container"):
        # Header
        with gr.Column(elem_classes="header"):
            gr.Markdown(
                "# Kandinsky 5.0 Lite T2I (SFT)\n### Text-to-Image Generation"
            )
            gr.Markdown(
                "[Built with anycoder](https://huggingface.co/spaces/akhaliq/anycoder)"
            )

        # Status info for hardware
        device_info = (
            "Running on **GPU** 🚀"
            if torch.cuda.is_available()
            else "Running on **CPU** ⚠️ (Inference will be slow)"
        )
        gr.Markdown(device_info)

        with gr.Row():
            # Left Column: Inputs
            with gr.Column(scale=1):
                prompt = gr.Textbox(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=3,
                    autofocus=True,
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    placeholder="Low quality, bad anatomy, blurry...",
                    lines=2,
                    value="low quality, bad anatomy, worst quality, deformed, disfigured",
                )

                with gr.Accordion("Advanced Settings", open=False):
                    with gr.Row():
                        width = gr.Slider(
                            label="Width", minimum=256, maximum=1024, step=64, value=1024
                        )
                        height = gr.Slider(
                            label="Height", minimum=256, maximum=1024, step=64, value=1024
                        )
                    steps = gr.Slider(
                        label="Inference Steps", minimum=10, maximum=100, step=1, value=25
                    )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=20.0,
                        step=0.5,
                        value=7.5,
                    )
                    with gr.Row():
                        seed = gr.Number(
                            label="Seed",
                            value=-1,
                            precision=0,
                            info="Set to -1 for random",
                        )
                        random_btn = gr.Button(
                            "🎲 Randomize", size="sm", variant="secondary"
                        )

                run_btn = gr.Button("Generate Image", variant="primary", size="lg")

            # Right Column: Output
            with gr.Column(scale=1):
                result_image = gr.Image(
                    label="Generated Image", type="pil", interactive=False
                )
                used_seed = gr.Number(label="Seed Used", interactive=False)

        # Event Handlers
        run_btn.click(
            fn=generate,
            inputs=[prompt, negative_prompt, width, height, steps, guidance_scale, seed],
            outputs=[result_image, used_seed],
        )

        # Helper to randomize seed input visually (-1 means "random" in generate)
        random_btn.click(lambda: -1, outputs=seed)

        # Examples. Each row supplies every input generate() needs (including
        # the seed), so the examples stay valid if running/caching them is
        # ever enabled.
        gr.Examples(
            examples=[
                [
                    "A futuristic cityscape with neon lights and flying cars, cyberpunk style, high detail",
                    "low quality, blurry",
                    1024,
                    1024,
                    25,
                    7.5,
                    -1,
                ],
                [
                    "A cute red panda drinking coffee in a cozy cafe, digital art",
                    "deformed, ugly",
                    1024,
                    1024,
                    25,
                    7.0,
                    -1,
                ],
                [
                    "Portrait of a warrior princess, intricate armor, dramatic lighting, photorealistic",
                    "cartoon, sketch, monochrome",
                    1024,
                    1024,
                    30,
                    8.0,
                    -1,
                ],
            ],
            inputs=[prompt, negative_prompt, width, height, steps, guidance_scale, seed],
            fn=generate,
            outputs=[result_image, used_seed],
            cache_examples=False,
        )

if __name__ == "__main__":
    demo.launch()