helms committed on
Commit
f5b2b3e
·
1 Parent(s): 559611d

Add application file

Browse files
Files changed (1) hide show
  1. app.py +153 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from stability_sdk import client
3
+ import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
4
+ from PIL import Image
5
+ import io
6
+ import os
7
+ import warnings
8
+ from dotenv import load_dotenv
9
+
10
+
11
# Commented-out custom Gradio theme, kept for future reference (unused —
# the app currently runs with the default theme).
# theme = gr.themes.Monochrome(
#     primary_hue="indigo",
#     secondary_hue="blue",
#     neutral_hue="slate",
#     radius_size=gr.themes.sizes.radius_sm,
#     font=[gr.themes.GoogleFont("Open Sans"), "ui-sans-serif", "system-ui", "sans-serif"],
# )

# Load credentials from a local .env file into the process environment.
load_dotenv()
SD_KEY = os.getenv("SD_KEY")      # Stability AI API key (used by infer/upscale)
HF_KEY = os.getenv("HF_KEY")      # Hugging Face token for the dataset savers
USERNAME = os.getenv("USERNAME")  # Basic-auth username for demo.launch
PASS = os.getenv("PASS")          # Basic-auth password for demo.launch


# Dataset savers that log flagged (name, prompt, image) rows to private
# Hugging Face datasets — one for generated images, one for upscaled images.
hf_writer_gen = gr.HuggingFaceDatasetSaver(HF_KEY, "helms/master-thesis-generated-images", private=True, separate_dirs=False)
hf_writer_up = gr.HuggingFaceDatasetSaver(HF_KEY, "helms/master-thesis-upscaled-images", private=True, separate_dirs=False)
27
+
28
+
29
def infer(prompt):
    """Generate four 512x512 images from *prompt* via the Stability AI API.

    The resulting PIL images are returned (for the Gradio gallery) and also
    cached on the function attribute ``infer.results`` so that the gallery's
    select handler can look a clicked image up by index.

    NOTE(review): ``infer.results`` is module-global mutable state shared by
    all concurrent users of the app — two users generating at the same time
    can clobber each other's results. Confirm this is acceptable for a
    single-user thesis demo.

    Parameters
    ----------
    prompt : str
        Free-text prompt forwarded to the generation endpoint.

    Returns
    -------
    list[PIL.Image.Image]
        The generated images (may be empty if every artifact was filtered).
    """
    stability_api = client.StabilityInference(
        key=SD_KEY,  # API key reference.
        verbose=True,  # Print debug messages.
        engine="stable-diffusion-v1",  # Set the engine to use for generation.
        # Available engines: stable-diffusion-v1 stable-diffusion-v1-5 stable-diffusion-512-v2-0 stable-diffusion-768-v2-0 stable-inpainting-v1-0 stable-inpainting-512-v2-0
    )
    answers = stability_api.generate(
        prompt=prompt,
        # seed=992446758, # If a seed is provided, the generated image is deterministic
        # as long as all other generation parameters remain the same.
        steps=30,        # Amount of inference steps performed on image generation.
        cfg_scale=7.0,   # How strongly generation is guided to match the prompt.
        width=512,       # Generation width, defaults to 512 if not included.
        height=512,      # Generation height, defaults to 512 if not included.
        samples=4,       # Four candidates so the user can pick one to upscale.
        sampler=generation.SAMPLER_K_DPMPP_2M,  # Denoising sampler.
        # (Available Samplers: ddim, plms, k_euler, k_euler_ancestral, k_heun, k_dpm_2, k_dpm_2_ancestral, k_dpmpp_2s_ancestral, k_lms, k_dpmpp_2m)
    )
    infer.results = []
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                # Filtered artifacts carry no image; warn and keep going so the
                # remaining samples are still returned.
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed."
                    "Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                infer.results.append(Image.open(io.BytesIO(artifact.binary)))

    return infer.results
65
+
66
def upscale(image):
    """Upscale *image* to 1024px wide using the ESRGAN engine of the Stability API.

    Parameters
    ----------
    image : str
        Path to the image file to upscale (Gradio passes the hidden
        ``selected`` component's value as a file path).

    Returns
    -------
    PIL.Image.Image | None
        The upscaled image, or ``None`` when the API returned no image
        artifact (e.g. the safety filter was tripped).
    """
    # Set up our connection to the API.
    stability_api = client.StabilityInference(
        key=SD_KEY,  # API Key reference.
        upscale_engine="esrgan-v1-x2plus",  # The name of the upscaling model we want to use.
        # Available Upscaling Engines: esrgan-v1-x2plus, stable-diffusion-x4-latent-upscaler
        verbose=True,  # Print debug messages.
    )
    # Load the local file selected in the gallery as the upscale input.
    img = Image.open(image)

    answers = stability_api.upscale(
        init_image=img,  # Pass our image to the API and call the upscaling process.
        width=1024,      # Desired output width.
        # prompt=... / seed=... / steps=... / cfg_scale=... are only meaningful
        # for the `stable-diffusion-x4-latent-upscaler` engine.
    )

    # BUGFIX: `big_img` was previously only bound inside the loop, so an empty
    # or fully-filtered response raised UnboundLocalError at the return below.
    # Default to None so the caller gets an empty output instead of a crash.
    big_img = None
    for resp in answers:
        for artifact in resp.artifacts:
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed."
                    "Please submit a different image and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                big_img = Image.open(io.BytesIO(artifact.binary))

    return big_img
100
+
101
+
102
# Earlier single-function gr.Interface version, kept for reference
# (superseded by the Blocks layout below).
# demo = gr.Interface(
#     infer,
#     gr.Textbox(label="Input", lines=5),
#     gr.Gallery(label="Output")
# )

with gr.Blocks(title="Master Thesis Image Generator") as demo:
    gr.Markdown("Welcome to this demo app for the Master thesis of Fabian Helms, TU Dortmund")
    with gr.Row():
        with gr.Column():
            # Inputs: participant name, text prompt, and the generate trigger.
            user = gr.Textbox(placeholder="Please enter your name.", label="Name")
            inp = gr.Textbox(placeholder="Enter your prompt here.", lines=3, label="Prompt")
            btn_gen = gr.Button("Generate Images", variant="primary", size="lg")
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            # 2x2 gallery of generated candidates; `selected` is a hidden Image
            # (type="filepath") holding the gallery item the user clicked.
            out_gen = gr.Gallery(label="Generated Images", columns=2, rows=2, container=True, preview=False, height="2", allow_preview=True)
            selected = gr.Image(label="selected img", type="filepath", interactive=False, visible=False)#.style(height=500, width=500)
        with gr.Column(scale=1):
            with gr.Row():
                gr.Markdown(value="Please select the image you want to upscale.", visible=True)
                btn_upscale = gr.Button("Upscale selected image", variant="secondary", size="sm")
            with gr.Row():
                out_upscale = gr.Image(label="Upscaled Image", interactive=False)#.style(height=500, width=500)
    #btn_flag = gr.Button("Flag Image", variant="secondary", size="sm")

    # Register which components each dataset saver logs, plus local staging dirs.
    hf_writer_gen.setup([user, inp, out_gen], ".temp_gen")
    hf_writer_up.setup([user, inp, out_upscale], ".temp_upscale")
    def select_img(evt: gr.SelectData):
        # Map the clicked gallery index back to the PIL image cached by infer().
        # NOTE(review): relies on module-global `infer.results` — concurrent
        # users may clobber each other's cache; confirm acceptable for the demo.
        sel_img = infer.results[evt.index]
        #hf_writer.flag([inp, user, out_gen])
        #sel_img.save("check.png")
        return sel_img

    # NOTE(review): `gen` is not wired to any event below and appears unused;
    # generation goes through `infer` directly.
    def gen(prompt):
        res_gen = infer(prompt)
        return out_gen.update(value=res_gen)

    # Generate on Enter or on button click, then log (name, prompt, gallery)
    # to the generated-images dataset via the .then() chained callback.
    inp.submit(fn=infer, inputs=inp, outputs=out_gen
               ).then(lambda *args: hf_writer_gen.flag(args), inputs=[user, inp, out_gen], outputs=None, preprocess=False)
    btn_gen.click(fn=infer, inputs=inp, outputs=out_gen
                  ).then(lambda *args: hf_writer_gen.flag(args), inputs=[user, inp, out_gen], outputs=None, preprocess=False)

    # Upscale the currently selected image, then log it to the upscaled dataset.
    btn_upscale.click(fn=upscale, inputs=selected, outputs=out_upscale
                      ).then(lambda *args: hf_writer_up.flag(args), inputs=[user, inp, out_upscale], outputs=None, preprocess=False)

    # Clicking a gallery item routes through select_img into `selected`.
    out_gen.select(select_img, outputs=selected)
    #btn_flag.click(lambda *args: hf_writer.flag(args), inputs=[user, inp, out_gen], outputs=None, preprocess=False)


# Launch behind simple basic-auth; the public API docs page is disabled.
demo.launch(show_api=False, auth=(USERNAME, PASS))