# Hugging Face Spaces app: Gradio web UI for a LoRA-fine-tuned Stable Diffusion model.
# (Scraped page banner reported "Runtime error" — original Spaces status text removed.)
# Standard library
import os
import random

# Third-party
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel
from PIL import Image
class LoRAWebInterface:
    """Gradio web UI for generating images with a Stable Diffusion pipeline,
    optionally augmented with LoRA adapter weights loaded via PEFT.

    Attributes:
        device: torch.device actually used ("cuda" when available, else "cpu").
        lora_path: Directory the LoRA weights were (or would be) loaded from.
        pipeline: The diffusers StableDiffusionPipeline instance.
        lora_loaded: True only when the PEFT adapter was applied successfully.
    """

    def __init__(self, base_model="runwayml/stable-diffusion-v1-5", lora_path="models/lora_model"):
        """Load the base pipeline and, when present, the LoRA adapter.

        Args:
            base_model: Hugging Face model id (or local path) of the base
                Stable Diffusion checkpoint.
            lora_path: Directory containing PEFT LoRA weights; skipped with a
                console message when it does not exist.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lora_path = lora_path
        print("Loading models...")
        # fp16 halves memory on GPU; CPU inference requires fp32.
        # Safety checker disabled deliberately by the original author.
        self.pipeline = StableDiffusionPipeline.from_pretrained(
            base_model,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            safety_checker=None,
            requires_safety_checker=False,
        )
        # Wrap the UNet with the LoRA adapter when weights are available;
        # a load failure degrades to the base model rather than crashing.
        if os.path.exists(lora_path):
            print(f"Loading LoRA model from {lora_path}")
            try:
                self.pipeline.unet = PeftModel.from_pretrained(self.pipeline.unet, lora_path)
                self.lora_loaded = True
            except Exception as e:
                print(f"Error loading LoRA: {e}")
                self.lora_loaded = False
        else:
            print("No LoRA model found, using base model")
            self.lora_loaded = False
        self.pipeline.to(self.device)
        # xformers is optional; fall back quietly when it is not installed.
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit.
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
        except Exception:
            pass
        print("Model loaded successfully!")

    def generate_image(self, prompt, negative_prompt, num_steps, guidance_scale,
                       width, height, seed, use_random_seed):
        """Run the diffusion pipeline and return (PIL image, status string).

        Gradio delivers slider/number values as floats, hence the int()/float()
        casts. On failure a blank placeholder image plus the error message is
        returned so the UI's image component never breaks.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Text to steer away from (may be empty).
            num_steps: Inference step count (slider value).
            guidance_scale: Classifier-free guidance strength (slider value).
            width, height: Output resolution in pixels (slider values).
            seed: RNG seed; values < 0 mean "do not seed".
            use_random_seed: When True, overrides `seed` with a fresh random one.

        Returns:
            (PIL.Image.Image, str) — generated (or blank placeholder) image and
            a human-readable status message.
        """
        if use_random_seed:
            seed = random.randint(0, 999999)
        # FIX: seed via an explicit Generator bound to the target device and
        # pass it to the pipeline, instead of mutating global torch RNG state
        # with torch.manual_seed() — reproducibility no longer depends on
        # unrelated torch RNG consumers.
        generator = None
        if seed is not None and seed >= 0:
            generator = torch.Generator(device=self.device).manual_seed(int(seed))
        try:
            with torch.autocast(self.device.type):
                image = self.pipeline(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    num_inference_steps=int(num_steps),
                    # FIX: coerce like the other slider values for consistency.
                    guidance_scale=float(guidance_scale),
                    width=int(width),
                    height=int(height),
                    generator=generator,
                ).images[0]
            return image, f"✅ Generated successfully! Seed: {seed}"
        except Exception as e:
            error_msg = f"❌ Error generating image: {str(e)}"
            print(error_msg)
            # FIX: size the placeholder to the requested resolution
            # (was hard-coded to 512x512 regardless of the sliders).
            blank_image = Image.new('RGB', (int(width), int(height)), color='white')
            return blank_image, error_msg

    def create_interface(self):
        """Build and return the Gradio Blocks interface (does not launch it)."""
        with gr.Blocks(title="LoRA Image Generator", theme=gr.themes.Soft()) as interface:
            gr.Markdown("# 🎨 LoRA Image Generator")
            gr.Markdown(f"**Model Status:** {'✅ LoRA model loaded' if self.lora_loaded else '⚠️ Using base model only'}")
            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want to generate...",
                        value="a beautiful artistic composition",
                        lines=3
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt (Optional)",
                        placeholder="Things you don't want in the image...",
                        value="blurry, low quality, distorted",
                        lines=2
                    )
                    with gr.Row():
                        num_steps = gr.Slider(
                            minimum=10,
                            maximum=100,
                            value=50,
                            step=5,
                            label="Inference Steps"
                        )
                        guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance Scale"
                        )
                    # SD 1.5 works best at multiples of 64; sliders enforce that.
                    with gr.Row():
                        width = gr.Slider(
                            minimum=256,
                            maximum=1024,
                            value=512,
                            step=64,
                            label="Width"
                        )
                        height = gr.Slider(
                            minimum=256,
                            maximum=1024,
                            value=512,
                            step=64,
                            label="Height"
                        )
                    with gr.Row():
                        seed = gr.Number(
                            label="Seed (-1 for random)",
                            value=-1,
                            precision=0
                        )
                        use_random_seed = gr.Checkbox(
                            label="Use Random Seed",
                            value=True
                        )
                    generate_btn = gr.Button("🎨 Generate Image", variant="primary")
                with gr.Column(scale=1):
                    # Output
                    output_image = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=512
                    )
                    status_text = gr.Textbox(
                        label="Status",
                        interactive=False,
                        lines=2
                    )
            # Example prompts
            gr.Markdown("## 💡 Example Prompts")
            example_prompts = [
                "a serene landscape in artistic style",
                "abstract flowing patterns with vibrant colors",
                "geometric composition with soft lighting",
                "organic forms inspired by nature",
                "minimalist design with elegant curves"
            ]
            examples = gr.Examples(
                examples=[[prompt] for prompt in example_prompts],
                inputs=[prompt],
                label="Click an example to try:"
            )
            # Event handlers
            generate_btn.click(
                fn=self.generate_image,
                inputs=[prompt, negative_prompt, num_steps, guidance_scale,
                        width, height, seed, use_random_seed],
                outputs=[output_image, status_text]
            )
            # Auto-disable seed input when random is selected
            use_random_seed.change(
                fn=lambda x: gr.update(interactive=not x),
                inputs=[use_random_seed],
                outputs=[seed]
            )
        return interface

    def launch(self, share=False, server_port=7860):
        """Build the interface and start the Gradio server (opens a browser tab)."""
        interface = self.create_interface()
        interface.launch(
            share=share,
            server_port=server_port,
            inbrowser=True
        )
def main():
    """Parse command-line options and launch the LoRA web interface."""
    import argparse

    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--lora_path", default="models/lora_model", help="Path to LoRA model")
    arg_parser.add_argument("--share", action="store_true", help="Create public link")
    arg_parser.add_argument("--port", type=int, default=7860, help="Server port")
    options = arg_parser.parse_args()

    # Build the interface around the requested LoRA weights, then serve it.
    web_app = LoRAWebInterface(lora_path=options.lora_path)
    web_app.launch(share=options.share, server_port=options.port)


if __name__ == "__main__":
    main()