"""Gradio web UI for Stable Diffusion with an optional PEFT LoRA adapter on the UNet."""

import argparse
import contextlib
import os
import random

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel
from PIL import Image

# Inclusive upper bound for automatically drawn seeds.
MAX_RANDOM_SEED = 999999

# NOTE(review): this Hub repo id may no longer be hosted; override with
# --base_model if loading fails.
DEFAULT_BASE_MODEL = "runwayml/stable-diffusion-v1-5"
DEFAULT_LORA_PATH = "models/lora_model"


class LoRAWebInterface:
    """Load a Stable Diffusion pipeline (optionally LoRA-adapted) and serve it via Gradio."""

    def __init__(self, base_model=DEFAULT_BASE_MODEL, lora_path=DEFAULT_LORA_PATH):
        """Load the base pipeline and, if available, the LoRA adapter.

        Args:
            base_model: Model id or local path passed to
                ``StableDiffusionPipeline.from_pretrained``.
            lora_path: Directory holding a PEFT adapter for the UNet. If it
                does not exist, or loading it fails, the base model is used and
                ``self.lora_loaded`` stays False.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lora_path = lora_path
        self.lora_loaded = False

        print("Loading models...")

        # Half precision on GPU, full precision on CPU.
        self.pipeline = StableDiffusionPipeline.from_pretrained(
            base_model,
            torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            safety_checker=None,
            requires_safety_checker=False,
        )

        if os.path.exists(lora_path):
            print(f"Loading LoRA model from {lora_path}")
            try:
                self.pipeline.unet = PeftModel.from_pretrained(self.pipeline.unet, lora_path)
                self.lora_loaded = True
            except Exception as e:
                # Deliberate best-effort: keep serving the base model.
                print(f"Error loading LoRA: {e}")
        else:
            print("No LoRA model found, using base model")

        self.pipeline.to(self.device)

        # xformers is optional; if it cannot be enabled, the default attention
        # implementation is kept. Report why instead of silently ignoring it.
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"xformers memory-efficient attention not enabled: {e}")

        print("Model loaded successfully!")

    @staticmethod
    def _resolve_seed(seed, use_random_seed):
        """Return the integer seed to use for one generation request.

        A fresh seed in ``[0, MAX_RANDOM_SEED]`` is drawn when ``use_random_seed``
        is set, or when ``seed`` is missing or negative (the UI labels -1 as
        "random"). This way the seed reported back can always reproduce the image.
        """
        if use_random_seed or seed is None or seed < 0:
            return random.randint(0, MAX_RANDOM_SEED)
        return int(seed)

    def generate_image(self, prompt, negative_prompt, num_steps, guidance_scale,
                       width, height, seed, use_random_seed):
        """Generate one image with the given parameters.

        Args:
            prompt: Text prompt.
            negative_prompt: Text describing what to avoid.
            num_steps: Number of denoising steps (converted to int).
            guidance_scale: Classifier-free guidance scale.
            width, height: Output size in pixels (converted to int).
            seed: Requested seed; negative or None means "pick one".
            use_random_seed: If True, ignore ``seed`` and draw a random one.

        Returns:
            ``(image, status_message)``. On failure the image is a blank
            512x512 white placeholder and the message describes the error.
        """
        seed = self._resolve_seed(seed, use_random_seed)

        try:
            # A per-call generator seeds only this run instead of mutating
            # torch's global RNG state.
            generator = torch.Generator(device=self.device).manual_seed(seed)

            # CPU autocast defaults to bfloat16, which would conflict with the
            # float32 pipeline used on CPU, so autocast is only used on CUDA.
            if self.device.type == "cuda":
                precision_ctx = torch.autocast("cuda")
            else:
                precision_ctx = contextlib.nullcontext()

            with precision_ctx:
                image = self.pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=int(num_steps),
                    guidance_scale=guidance_scale,
                    width=int(width),
                    height=int(height),
                    generator=generator,
                ).images[0]

            return image, f"✅ Generated successfully! Seed: {seed}"

        except Exception as e:
            error_msg = f"❌ Error generating image: {str(e)}"
            print(error_msg)
            # Return a blank image so the output component still has content.
            blank_image = Image.new('RGB', (512, 512), color='white')
            return blank_image, error_msg

    def create_interface(self):
        """Build and return the Gradio Blocks UI wired to ``generate_image``."""
        with gr.Blocks(title="LoRA Image Generator", theme=gr.themes.Soft()) as interface:
            gr.Markdown("# 🎨 LoRA Image Generator")
            gr.Markdown(f"**Model Status:** {'✅ LoRA model loaded' if self.lora_loaded else '⚠️ Using base model only'}")

            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want to generate...",
                        value="a beautiful artistic composition",
                        lines=3,
                    )

                    negative_prompt = gr.Textbox(
                        label="Negative Prompt (Optional)",
                        placeholder="Things you don't want in the image...",
                        value="blurry, low quality, distorted",
                        lines=2,
                    )

                    with gr.Row():
                        num_steps = gr.Slider(
                            minimum=10, maximum=100, value=50, step=5,
                            label="Inference Steps",
                        )
                        guidance_scale = gr.Slider(
                            minimum=1.0, maximum=20.0, value=7.5, step=0.5,
                            label="Guidance Scale",
                        )

                    with gr.Row():
                        width = gr.Slider(
                            minimum=256, maximum=1024, value=512, step=64,
                            label="Width",
                        )
                        height = gr.Slider(
                            minimum=256, maximum=1024, value=512, step=64,
                            label="Height",
                        )

                    with gr.Row():
                        # Starts disabled to match the checkbox default (True);
                        # the change handler below keeps them in sync.
                        seed = gr.Number(
                            label="Seed (-1 for random)",
                            value=-1,
                            precision=0,
                            interactive=False,
                        )
                        use_random_seed = gr.Checkbox(
                            label="Use Random Seed",
                            value=True,
                        )

                    generate_btn = gr.Button("🎨 Generate Image", variant="primary")

                with gr.Column(scale=1):
                    # Output
                    output_image = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=512,
                    )

                    status_text = gr.Textbox(
                        label="Status",
                        interactive=False,
                        lines=2,
                    )

            # Example prompts
            gr.Markdown("## 💡 Example Prompts")
            example_prompts = [
                "a serene landscape in artistic style",
                "abstract flowing patterns with vibrant colors",
                "geometric composition with soft lighting",
                "organic forms inspired by nature",
                "minimalist design with elegant curves",
            ]

            gr.Examples(
                examples=[[text] for text in example_prompts],
                inputs=[prompt],
                label="Click an example to try:",
            )

            # Event handlers
            generate_btn.click(
                fn=self.generate_image,
                inputs=[prompt, negative_prompt, num_steps, guidance_scale,
                        width, height, seed, use_random_seed],
                outputs=[output_image, status_text],
            )

            # Disable the seed input while random seeding is selected.
            use_random_seed.change(
                fn=lambda x: gr.update(interactive=not x),
                inputs=[use_random_seed],
                outputs=[seed],
            )

        return interface

    def launch(self, share=False, server_port=7860):
        """Build the interface and start the Gradio server, opening a browser tab."""
        interface = self.create_interface()
        interface.launch(
            share=share,
            server_port=server_port,
            inbrowser=True,
        )


def main():
    """Parse command-line arguments, build the interface, and launch it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--base_model", default=DEFAULT_BASE_MODEL,
                        help="Base Stable Diffusion model id or path")
    parser.add_argument("--lora_path", default=DEFAULT_LORA_PATH, help="Path to LoRA model")
    parser.add_argument("--share", action="store_true", help="Create public link")
    parser.add_argument("--port", type=int, default=7860, help="Server port")

    args = parser.parse_args()

    # Create and launch interface
    interface = LoRAWebInterface(base_model=args.base_model, lora_path=args.lora_path)
    interface.launch(share=args.share, server_port=args.port)


if __name__ == "__main__":
    main()