rohithb commited on
Commit
7547682
·
1 Parent(s): af33b04

main app.py file.

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import sys
4
+ import torch
5
+ import gradio as gr
6
+ from stable_diffusion import StableDiffusion
7
+ from utils import invert_loss, get_style_embeddings, show_images
8
+
9
+ torch.manual_seed(1)
10
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
+
13
+ def gradio_interface(prompt:str, style_name:str):
14
+ stable_diffuser = StableDiffusion(torch_device)
15
+ #outputs_1 = []
16
+ #outputs_2 = []
17
+ seed_values = [1,2,3,4,5]
18
+ custom_loss_scale = 150.0
19
+ num_styles = len(style_files)
20
+ style_file = style_files[style_name]
21
+ style_token_embedding = get_style_embeddings(style_file) if style_file is not None else None
22
+ this_generated_img_1 = stable_diffuser.generate_image_with_custom_style(prompt,
23
+ style_token_embedding = style_token_embedding,
24
+ random_seed = 42,
25
+ custom_loss_fn = None)
26
+ #outputs_1.append(this_generated_img_1)
27
+ this_generated_img_2 = stable_diffuser.generate_image_with_custom_style(prompt,
28
+ style_token_embedding = style_token_embedding,
29
+ random_seed = 42,
30
+ custom_loss_fn = invert_loss,
31
+ custom_loss_scale = custom_loss_scale)
32
+ #outputs_2.append(this_generated_img_2)
33
+ return this_generated_img_1, this_generated_img_2
34
+
35
+
36
+ style_files = {'style_1': None,
37
+ 'watercolor': 'learned_embeds_watercolor.bin',
38
+ 'strip': 'learned_embeds_strip_style.bin',
39
+ 'oil_paint': 'learned_embeds_oil_paint.bin',
40
+ 'kaleido': 'learned_embeds_kaleido.bin',
41
+ 'doodle': 'learned_embeds_doodle.bin'}
42
+
43
+ torch_device = "cuda" if torch.cuda.is_available() else "cpu"
44
+
45
+ prompt_examples = [["A cat wearing a party hat vivid colors", 'style_1']]
46
+ #["A dog wearing sunglasses on a skateboard"],
47
+ #["An oil painting of a bear playing guitar"]]
48
+ #["A tiger dressed in a hip hop styled clothing"],
49
+ #["An oil painting of a lion eating pizza"]]
50
+
51
+ # Define Interface
52
+ description = 'A Stable Diffusion based Generative AI tool to generate images. Also generates grayscale images.'
53
+ title = 'Image Generation using Stable Diffusion with styles.'
54
+ style_names = ['No_style','watercolor','strip','oil_paint','kaleido','doodle']
55
+ demo = gr.Interface(gradio_interface,
56
+ inputs = [gr.Textbox('A deer crossing the street', label="Text Prompt"),
57
+ gr.Dropdown(style_names,value='style_1',label="Display Style")],
58
+ outputs = [gr.Image(shape=(512, 512),label='Generated Image').style(width=512, height=512),
59
+ gr.Image(shape=(512, 512),label='Generated Images with custom loss').style(width=512, height=512)
60
+ ],
61
+ examples=prompt_examples,
62
+ title = title,
63
+ description = description
64
+ )
65
+ demo.launch(debug=True)