File size: 4,079 Bytes
adec894
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import tkinter as tk
from tkinter import scrolledtext, messagebox
import torch
from diffusers import DiffusionPipeline
import subprocess
import sys

# Function to install packages needed for the environment
def install_packages():
    """Install the Python packages the diffusion pipeline depends on.

    Runs pip inside the current interpreter (`sys.executable`) and reports
    the outcome to the user through a tkinter message box.
    """
    try:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "transformers", "accelerate", "safetensors"]
        )
    except subprocess.CalledProcessError as exc:
        # Include the pip exit code so the user has something actionable;
        # the original message discarded all failure detail.
        messagebox.showerror("Error", f"Failed to install packages (pip exit code {exc.returncode}).")
    else:
        # Success dialog lives in `else` so a messagebox error can never be
        # misreported as an installation failure.
        messagebox.showinfo("Success", "Packages installed successfully!")

# Function to check and return appropriate device (GPU or iGPU)
def get_device():
    """Return the torch device string to run inference on.

    Returns:
        "cuda" when any CUDA-capable NVIDIA GPU is present, otherwise
        "cpu". Integrated GPUs cannot run CUDA, so they fall through to
        the CPU path.
    """
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(0).lower()
        # MX-class and other CUDA GPUs are handled identically — the split
        # only affects the diagnostic message. (The original duplicated the
        # `return "cuda"` in both branches.)
        if "mx" in device_name:
            print(f"Using NVIDIA MX GPU: {device_name}")
        else:
            print(f"Using CUDA-enabled GPU: {device_name}")
        return "cuda"
    # No CUDA GPU: fall back to CPU (covers iGPU-only machines).
    print("No CUDA device detected. Trying to run with iGPU (integrated GPU).")
    return "cpu"

# Initialize the device (GPU or iGPU)
device = get_device()

# Load the diffusion pipeline and optimize it for the target device
# NOTE(review): this is the same checkpoint as `base` below, so the base
# weights are loaded twice — confirm both copies are actually needed.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", use_safetensors=True, variant="fp16")
# torch.compile trades a one-time compilation cost on the first call for
# faster repeated UNet invocations.
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
# Offload idle model components to CPU to reduce GPU memory pressure.
pipe.enable_model_cpu_offload()

# Load both base & refiner models to the target device (GPU or iGPU)
base = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", variant="fp16", use_safetensors=True
)
base.to(device)

refiner = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-1.0",
    # Reuse the base model's second text encoder and VAE so those weights
    # are not loaded a second time.
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    variant="fp16",
    use_safetensors=True,
)
refiner.to(device)

# Set parameters for inference
n_steps = 40              # total denoising steps shared between base and refiner
high_noise_frac = 0.8     # fraction of steps run on the base model before handoff

def run_inference():
    """Read the prompts from the GUI, run the pipelines, and report status.

    Reads the main prompt and optional negative prompt from the text
    widgets, runs the compiled `pipe`, then the base+refiner expert pair,
    and writes status text into `output_text`.
    """
    main_prompt = main_prompt_text.get("1.0", tk.END).strip()
    # Read the negative-prompt widget once (the original read it three
    # times); an empty box becomes None so the pipeline skips it.
    negative_prompt = negative_prompt_text.get("1.0", tk.END).strip() or None

    # Run the pipeline with the prompts
    result = pipe(main_prompt, negative_prompt=negative_prompt)
    output_text.delete("1.0", tk.END)
    output_text.insert(tk.END, str(result))

    # Run both experts: the base model emits latents, the refiner finishes
    # them from `high_noise_frac` onward.
    image = base(
        prompt=main_prompt,
        num_inference_steps=n_steps,
        denoising_end=high_noise_frac,
        output_type="latent",
    ).images
    image = refiner(
        prompt=main_prompt,
        num_inference_steps=n_steps,
        denoising_start=high_noise_frac,
        image=image,
    ).images[0]

    # Display the image or result as needed
    output_text.insert(tk.END, "Image generated successfully.")

# Create the main window
root = tk.Tk()
root.title("Inference GUI")

# Main prompt input (read by run_inference via module-level name)
tk.Label(root, text="Main Prompt:").pack()
main_prompt_text = scrolledtext.ScrolledText(root, wrap=tk.WORD, width=50, height=10)
main_prompt_text.pack()

# Negative prompt input; an empty box is treated as "no negative prompt"
tk.Label(root, text="Negative Prompt (optional):").pack()
negative_prompt_text = scrolledtext.ScrolledText(root, wrap=tk.WORD, width=50, height=5)
negative_prompt_text.pack()

# Run button — NOTE(review): run_inference executes on the Tk event loop
# thread, so the UI freezes for the duration of inference; consider a
# worker thread if responsiveness matters.
run_button = tk.Button(root, text="Run Inference", command=run_inference)
run_button.pack()

# Install button
install_button = tk.Button(root, text="Install Packages", command=install_packages)
install_button.pack()

# Output text area (status messages written by run_inference)
output_text = scrolledtext.ScrolledText(root, wrap=tk.WORD, width=50, height=10)
output_text.pack()

# Start the GUI event loop (blocks until the window is closed)
root.mainloop()