ACT-images / models /web_interface.py
DPM1987's picture
Upload folder using huggingface_hub
df83e9f verified
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel
import os
from PIL import Image
import random
class LoRAWebInterface:
    """Gradio front end for a Stable Diffusion pipeline with an optional LoRA adapter.

    On construction the base pipeline is loaded, a PEFT adapter is applied to
    its UNet when ``lora_path`` exists, and the pipeline is moved to CUDA when
    available (CPU otherwise).
    """

    def __init__(self, base_model="runwayml/stable-diffusion-v1-5", lora_path="models/lora_model"):
        """Load the base pipeline and, if present, the LoRA adapter.

        Args:
            base_model: Model id or local path passed to
                ``StableDiffusionPipeline.from_pretrained``.
            lora_path: Directory holding a PEFT adapter for the UNet. If it does
                not exist, or fails to load, the base model is used unchanged
                and ``self.lora_loaded`` is False.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lora_path = lora_path
        print("Loading models...")

        # Half precision only when running on CUDA; full precision on CPU.
        dtype = torch.float16 if self.device.type == "cuda" else torch.float32
        self.pipeline = StableDiffusionPipeline.from_pretrained(
            base_model,
            torch_dtype=dtype,
            safety_checker=None,
            requires_safety_checker=False,
        )

        # Wrap the UNet with the LoRA adapter if one exists; on any load error,
        # keep the plain base UNet so the app still starts.
        self.lora_loaded = False
        if os.path.exists(lora_path):
            print(f"Loading LoRA model from {lora_path}")
            try:
                self.pipeline.unet = PeftModel.from_pretrained(self.pipeline.unet, lora_path)
                self.lora_loaded = True
            except Exception as e:
                print(f"Error loading LoRA: {e}")
        else:
            print("No LoRA model found, using base model")

        self.pipeline.to(self.device)

        # Memory-efficient attention is a best-effort optimization: xformers may
        # be missing or unsupported here, in which case default attention is kept.
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still propagate.
        try:
            self.pipeline.enable_xformers_memory_efficient_attention()
        except Exception as e:
            print(f"xformers memory-efficient attention not enabled: {e}")

        print("Model loaded successfully!")

    def generate_image(self, prompt, negative_prompt, num_steps, guidance_scale,
                       width, height, seed, use_random_seed):
        """Run the pipeline once and return ``(image, status_message)``.

        Args:
            prompt: Text prompt.
            negative_prompt: Text describing what to avoid.
            num_steps: Number of inference steps (coerced to int).
            guidance_scale: Guidance scale passed straight to the pipeline.
            width: Output width in pixels (coerced to int).
            height: Output height in pixels (coerced to int).
            seed: Noise seed. ``None`` or a negative value means "pick a random
                seed", matching the "-1 for random" label in the UI.
            use_random_seed: When true, ``seed`` is ignored and a random one is drawn.

        Returns:
            The generated image plus a status string reporting the seed actually
            used, so every result can be reproduced. If generation raises, a
            white 512x512 placeholder and the error message are returned instead.
        """
        if use_random_seed or seed is None or int(seed) < 0:
            seed = random.randint(0, 999999)
        seed = int(seed)
        # A dedicated generator (rather than torch.manual_seed) leaves torch's
        # global RNG untouched; a CPU generator keeps a seed's noise identical
        # whether the pipeline runs on CPU or CUDA.
        generator = torch.Generator(device="cpu").manual_seed(seed)

        try:
            # No torch.autocast here: the pipeline already runs in the dtype it
            # was loaded with, and CPU autocast would downcast it to bfloat16.
            image = self.pipeline(
                prompt=prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=int(num_steps),
                guidance_scale=guidance_scale,
                width=int(width),
                height=int(height),
                generator=generator,
            ).images[0]
            return image, f"✅ Generated successfully! Seed: {seed}"
        except Exception as e:
            error_msg = f"❌ Error generating image: {str(e)}"
            print(error_msg)
            # Return a blank image on error so the output component still renders.
            blank_image = Image.new('RGB', (512, 512), color='white')
            return blank_image, error_msg

    def create_interface(self):
        """Build and return the Gradio Blocks UI wired to ``generate_image``."""
        with gr.Blocks(title="LoRA Image Generator", theme=gr.themes.Soft()) as interface:
            gr.Markdown("# 🎨 LoRA Image Generator")
            gr.Markdown(f"**Model Status:** {'✅ LoRA model loaded' if self.lora_loaded else '⚠️ Using base model only'}")

            with gr.Row():
                with gr.Column(scale=1):
                    # Input controls
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Describe the image you want to generate...",
                        value="a beautiful artistic composition",
                        lines=3,
                    )
                    negative_prompt = gr.Textbox(
                        label="Negative Prompt (Optional)",
                        placeholder="Things you don't want in the image...",
                        value="blurry, low quality, distorted",
                        lines=2,
                    )
                    with gr.Row():
                        num_steps = gr.Slider(
                            minimum=10,
                            maximum=100,
                            value=50,
                            step=5,
                            label="Inference Steps",
                        )
                        guidance_scale = gr.Slider(
                            minimum=1.0,
                            maximum=20.0,
                            value=7.5,
                            step=0.5,
                            label="Guidance Scale",
                        )
                    with gr.Row():
                        # Step of 64 from 256 keeps sizes divisible by 8.
                        width = gr.Slider(
                            minimum=256,
                            maximum=1024,
                            value=512,
                            step=64,
                            label="Width",
                        )
                        height = gr.Slider(
                            minimum=256,
                            maximum=1024,
                            value=512,
                            step=64,
                            label="Height",
                        )
                    with gr.Row():
                        # Starts non-interactive because "Use Random Seed" starts
                        # checked; the change handler below keeps the two in sync.
                        seed = gr.Number(
                            label="Seed (-1 for random)",
                            value=-1,
                            precision=0,
                            interactive=False,
                        )
                        use_random_seed = gr.Checkbox(
                            label="Use Random Seed",
                            value=True,
                        )
                    generate_btn = gr.Button("🎨 Generate Image", variant="primary")

                with gr.Column(scale=1):
                    # Output
                    output_image = gr.Image(
                        label="Generated Image",
                        type="pil",
                        height=512,
                    )
                    status_text = gr.Textbox(
                        label="Status",
                        interactive=False,
                        lines=2,
                    )

            # Example prompts
            gr.Markdown("## 💡 Example Prompts")
            example_prompts = [
                "a serene landscape in artistic style",
                "abstract flowing patterns with vibrant colors",
                "geometric composition with soft lighting",
                "organic forms inspired by nature",
                "minimalist design with elegant curves",
            ]
            gr.Examples(
                examples=[[text] for text in example_prompts],
                inputs=[prompt],
                label="Click an example to try:",
            )

            # Event handlers
            generate_btn.click(
                fn=self.generate_image,
                inputs=[prompt, negative_prompt, num_steps, guidance_scale,
                        width, height, seed, use_random_seed],
                outputs=[output_image, status_text],
            )
            # Auto-disable seed input when random is selected
            use_random_seed.change(
                fn=lambda x: gr.update(interactive=not x),
                inputs=[use_random_seed],
                outputs=[seed],
            )

        return interface

    def launch(self, share=False, server_port=7860):
        """Build the UI and serve it, opening a local browser tab.

        Args:
            share: Forwarded to Gradio's ``launch(share=...)``.
            server_port: Port to bind the server to.
        """
        interface = self.create_interface()
        interface.launch(
            share=share,
            server_port=server_port,
            inbrowser=True,
        )
def main():
    """Parse command-line options, build the interface, and start the server."""
    import argparse

    parser = argparse.ArgumentParser()
    # Same default as LoRAWebInterface's base_model; keep the two in sync.
    parser.add_argument("--base_model", default="runwayml/stable-diffusion-v1-5",
                        help="Model id or local path of the base Stable Diffusion model")
    parser.add_argument("--lora_path", default="models/lora_model", help="Path to LoRA model")
    parser.add_argument("--share", action="store_true", help="Create public link")
    parser.add_argument("--port", type=int, default=7860, help="Server port")
    args = parser.parse_args()

    # Create and launch interface
    app = LoRAWebInterface(base_model=args.base_model, lora_path=args.lora_path)
    app.launch(share=args.share, server_port=args.port)
if __name__ == "__main__":
    # Run as a script: parse CLI flags and start the web UI.
    main()