File size: 1,715 Bytes
f1b29bc
99748b4
f1b29bc
36893e0
 
 
 
 
 
 
 
 
 
 
 
72a3fae
8aad493
f1b29bc
36893e0
 
 
 
99748b4
 
 
 
 
 
 
 
 
36893e0
 
 
99748b4
 
36893e0
 
 
 
 
 
1336e83
36893e0
 
 
 
 
 
 
 
 
 
 
 
 
 
f1b29bc
 
 
 
 
 
 
 
 
99748b4
f1b29bc
 
 
 
 
 
 
 
36893e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
import pickle

import streamlit as st

import torch
from torch import autocast
from diffusers import StableDiffusionPipeline


st.set_page_config(layout="wide")

st.title('Play with Stable-Diffusion v1-4')

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Hugging Face token taken from the "StableDiffusion" environment variable.
# Only needed by the from_pretrained() loading path described in _load_pipeline.
auth_token = os.environ.get("StableDiffusion")


def _resource_cache(func):
	"""Wrap ``func`` in Streamlit's process-wide resource cache, if available.

	Streamlit re-executes this whole script on every interaction, so without a
	cache the pipeline would be reloaded on every form submit. Newer Streamlit
	versions provide ``st.cache_resource``; older ones ``st.experimental_singleton``.
	If neither exists, ``func`` is returned unchanged (uncached, as before).
	"""
	cache = getattr(st, "cache_resource", None) or getattr(st, "experimental_singleton", None)
	return cache(func) if cache is not None else func


@_resource_cache
def _load_pipeline(path, target_device):
	"""Unpickle the Stable Diffusion pipeline stored at ``path`` and move it to ``target_device``.

	WARNING: ``pickle.load`` can execute arbitrary code; ``path`` must point to a
	trusted, locally produced file.

	The alternative (network) loader, which downloads the weights instead of
	reading the pickle, is::

		StableDiffusionPipeline.from_pretrained(
			model_id,
			revision="fp16",
			torch_dtype=torch.float16,
			use_auth_token=auth_token,
		)
	"""
	with open(path, 'rb') as model_file:
		loaded = pickle.load(model_file)
	return loaded.to(target_device)


with st.spinner(
	text='Loading...'
):
	pipe = _load_pipeline('model/stable-diffusion.bin', device)


def infer(prompt, samples=2, steps=30, scale=7.5, seed=25):
	"""Generate ``samples`` images for ``prompt`` with the module-level ``pipe``.

	Args:
		prompt: Text prompt; it is repeated ``samples`` times to form one batch.
		samples: Number of images to generate.
		steps: Number of denoising steps (passed as ``num_inference_steps``).
		scale: Guidance scale (passed as ``guidance_scale``).
		seed: Seed for the torch generator on ``device``; the same arguments
			reproduce the same images.

	Returns:
		A list of the images produced by the pipeline (presumably PIL images —
		depends on the pickled diffusers version).
	"""
	generator = torch.Generator(device=device).manual_seed(seed)

	# Mixed precision is only used on CUDA. Disabling it explicitly on CPU avoids
	# torch warning (and silently disabling) autocast("cuda") on every call.
	with autocast("cuda", enabled=(device == "cuda")):
		output = pipe(
			[prompt] * samples,
			num_inference_steps=steps,
			guidance_scale=scale,
			generator=generator,
		)

	# Early diffusers releases return a dict keyed by "sample"; later releases
	# return an output object exposing ``.images`` (the "sample" key was removed).
	images = output.images if hasattr(output, "images") else output["sample"]
	return list(images)


with st.form(key='new'):

	prompt = st.text_area(label='Enter prompt')

	with st.expander(label='Expand parameters'):
		# Create the columns *inside* the expander: a widget is rendered in the
		# container it is called on, so columns created outside the expander
		# would put the sliders outside it and leave the expander empty.
		col1, col2, col3 = st.columns(3)

		# Number of images generated per prompt.
		n_samples = col1.select_slider(
			label='Num images',
			options=range(1, 5),
			value=1
		)

		# Number of denoising steps passed to infer().
		steps = col2.select_slider(
			label='Steps',
			options=range(1, 101),
			value=40
		)

		# Guidance scale passed to infer() (integer values only here).
		scale = col3.select_slider(
			label='Guidance Scale',
			options=range(1, 21),
			value=7
		)

	st.form_submit_button()

# Only run generation for a non-blank prompt; otherwise ask for one.
if prompt.strip():
	images = infer(
		prompt,
		samples=n_samples,
		steps=steps,
		scale=scale
	)
	# When st.image receives a list of images it needs one caption per image;
	# a single string is treated as one caption and raises for n_samples > 1.
	st.image(
		images,
		caption=['result'] * len(images)
	)
else:
	st.warning('Enter prompt.')