File size: 1,870 Bytes
5ad7ffe
 
 
 
5f17680
5ad7ffe
 
5f17680
5ad7ffe
 
 
 
 
 
 
 
 
5f17680
 
 
 
 
5ad7ffe
 
 
 
 
5f17680
5ad7ffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5f17680
 
 
 
 
5ad7ffe
 
 
5f17680
5ad7ffe
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import gradio as gr
import torch
import uuid
import peft
from PIL import Image
from diffusers import AutoPipelineForText2Image, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline
from peft import PeftModel, PeftConfig

# Define global variables
model_id = "stabilityai/stable-diffusion-xl-base-1.0"
lora_models = {}   # LoRA file name -> path on disk
trigger_word = {}  # LoRA file name -> prompt phrase offered for that LoRA

# Load the pretrained SDXL pipeline in half precision and move it to the GPU.
pipe = AutoPipelineForText2Image.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
pipe.to("cuda")

# Attach the default LoRA through diffusers' own loader. The previous PEFT
# wiring could not run: SDXL pipelines expose no `.model` attribute,
# `PeftConfig.from_pretrained` expects an adapter directory containing
# `adapter_config.json` (not a bare .safetensors file), and
# `PeftModel.from_pretrained` takes a model path, not a config object.
# The adapter is loaded unfused; `set_lora_model` unfuses and unloads
# before loading a different one.
base_model = pipe.unet  # retained so code that still references `base_model` keeps resolving
pipe.load_lora_weights(os.path.join('lora_weights', 'qwe_cat_long.safetensors'))

# Build the LoRA menu from every entry in lora_weights/ (except .gitignore).
# The trigger phrase is the file-name prefix before the first underscore plus
# a fixed description. The context manager closes the directory iterator.
with os.scandir('lora_weights') as entries:
    for entry in entries:
        if entry.name != '.gitignore':
            lora_models[entry.name] = entry.path
            trigger_word[entry.name] = entry.name.split('_')[0] + ' cat bright white fur'

# Define helper functions
def save_img(image_list, prompt, results_folder='results/'):
    """Save each listed image as a JPEG, with the prompt in a sibling .txt file.

    Every image gets a fresh UUID stem, so ``<uuid>.jpg`` and ``<uuid>.txt``
    are written side by side in ``results_folder``.

    Args:
        image_list: Sequence of indexable items whose ``item[0]`` is a path
            PIL can open (presumably gradio Gallery ``(filepath, caption)``
            pairs -- TODO confirm against the UI wiring). ``None`` or an
            empty sequence saves nothing.
        prompt: Text written verbatim to each image's .txt file.
        results_folder: Output directory, created if it does not exist.
    """
    os.makedirs(results_folder, exist_ok=True)
    for item in image_list or ():
        stem = os.path.join(results_folder, str(uuid.uuid4()))
        # `with` closes the file handle that Image.open keeps open for lazy loading.
        with Image.open(item[0]) as img:
            # JPEG cannot store alpha/palette modes; convert so save() does not raise.
            if img.mode not in ('RGB', 'L', 'CMYK'):
                img = img.convert('RGB')
            img.save(f"{stem}.jpg")
        # Explicit UTF-8 so non-ASCII prompts do not depend on the platform locale.
        with open(f"{stem}.txt", "w", encoding="utf-8") as file:
            file.write(prompt)

def set_lora_model(lora_name, lora_scale):
    """Replace the pipeline's active LoRA with ``lora_name``, fused at ``lora_scale``.

    Args:
        lora_name: Key into ``lora_models`` (an entry name under ``lora_weights/``).
        lora_scale: Strength passed to ``pipe.fuse_lora(lora_scale=...)``.

    Returns:
        The trigger phrase registered for ``lora_name`` in ``trigger_word``.

    Raises:
        KeyError: If ``lora_name`` is not a known LoRA. This is raised before
            the currently loaded LoRA is removed.
    """
    # Resolve the path first so an unknown name does not leave the pipe with no LoRA.
    lora_path = lora_models[lora_name]
    print(lora_path)

    # Restore base weights, then drop the old adapter. No positional args:
    # newer diffusers take a component list as the first positional parameter,
    # so the former `unfuse_lora(True)` raises a TypeError there.
    pipe.unfuse_lora()
    pipe.unload_lora_weights()

    # Load via diffusers rather than PEFT directly. SDXL pipelines have no
    # `.model`, and PeftConfig cannot read a bare .safetensors file. The scale
    # is applied at fuse time; `lora_scale` is not a PeftConfig field, so
    # setting it there had no effect.
    pipe.load_lora_weights(lora_path)
    pipe.fuse_lora(lora_scale=lora_scale)
    print('Model swapped')
    return trigger_word[lora_name]

# ...

# NOTE(review): `main` is not defined anywhere in this file as shown, so
# running the script directly raises NameError on the call below. Define
# `main` (e.g. building and launching the Gradio UI) before relying on it.
if __name__ == "__main__":
    main()