mochifz committed on
Commit
a752c0b
·
verified ·
1 Parent(s): b47bf2f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import gradio as gr
2
  import torch, random, time
3
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
4
- device = "cuda" if torch.cuda.is_available() else "cpu"
5
  translations = {
6
  'en': {
7
- 'model': 'Model Path',
8
  'loading': 'Loading',
9
  'input': 'Input Image',
10
  'prompt': 'Prompt',
@@ -18,7 +18,7 @@ translations = {
18
  'seed': 'Seed',
19
  },
20
  'zh': {
21
- 'model': '模型路径',
22
  'loading': '载入',
23
  'input': '输入图像',
24
  'prompt': '提示',
@@ -37,7 +37,7 @@ def generate_new_seed():
37
  return random.randint(1, 2147483647)
38
  def update_language(new_language):
39
  return [
40
- gr.Textbox.update(placeholder=translations[new_language]['model']),
41
  gr.Button.update(value=translations[new_language]['loading']),
42
  gr.Image.update(label=translations[new_language]['input']),
43
  gr.Textbox.update(placeholder=translations[new_language]['prompt']),
@@ -56,7 +56,7 @@ img2img = None
56
  def Generate(image_input, prompt, negative_prompt, strength, guidance_scale, num_inference_steps, width, height, seed):
57
  if seed == -1:
58
  seed = generate_new_seed()
59
- generator = torch.Generator(device).manual_seed(int(seed))
60
  global text2img, img2img
61
  start_time = time.time()
62
  if image_input is None:
@@ -65,20 +65,25 @@ def Generate(image_input, prompt, negative_prompt, strength, guidance_scale, num
65
  image = img2img(image=image_input, strength=0.75, prompt=prompt, negative_prompt=negative_prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, width=width, height=height, num_images_per_prompt=1, generator=generator).images[0]
66
  minutes, seconds = divmod(round(time.time() - start_time), 60)
67
  return image, f"{minutes:02d}:{seconds:02d}"
68
- def Loading(model):
69
  global text2img, img2img
70
- if device == "cuda":
71
- text2img = StableDiffusionPipeline.from_pretrained(model, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to(device)
 
 
 
 
 
72
  text2img.enable_xformers_memory_efficient_attention()
73
  text2img.vae.enable_xformers_memory_efficient_attention()
74
- else:
75
- text2img = StableDiffusionPipeline.from_pretrained(model, use_safetensors=True).to(device)
76
  text2img.safety_checker = None
77
- img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
78
- return model
79
  with gr.Blocks() as demo:
80
  with gr.Row():
81
- model = gr.Textbox(value="nota-ai/bk-sdm-tiny-2m", label=translations[language]['model'])
 
 
82
  loading = gr.Button(translations[language]['loading'])
83
  set_language = gr.Dropdown(list(translations.keys()), label="Language", value=language)
84
  with gr.Row():
@@ -104,6 +109,6 @@ with gr.Blocks() as demo:
104
  text_output = gr.Textbox(label="time")
105
  set_seed.click(generate_new_seed, None, seed)
106
  generate.click(Generate, [image_input, prompt, negative_prompt, strength, guidance_scale, num_inference_steps, width, height, seed], [image_output, text_output])
107
- loading.click(Loading, model, model)
108
- set_language.change(update_language, set_language, [model, loading, image_input, prompt, negative_prompt, generate, strength, guidance_scale, num_inference_steps, width, height, seed])
109
  demo.queue().launch()
 
1
  import gradio as gr
2
  import torch, random, time
3
  from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
4
+ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
5
  translations = {
6
  'en': {
7
+ 'model_name': 'Model Path',
8
  'loading': 'Loading',
9
  'input': 'Input Image',
10
  'prompt': 'Prompt',
 
18
  'seed': 'Seed',
19
  },
20
  'zh': {
21
+ 'model_name': '模型路径',
22
  'loading': '载入',
23
  'input': '输入图像',
24
  'prompt': '提示',
 
37
  return random.randint(1, 2147483647)
38
  def update_language(new_language):
39
  return [
40
+ gr.Textbox.update(placeholder=translations[new_language]['model_name']),
41
  gr.Button.update(value=translations[new_language]['loading']),
42
  gr.Image.update(label=translations[new_language]['input']),
43
  gr.Textbox.update(placeholder=translations[new_language]['prompt']),
 
56
  def Generate(image_input, prompt, negative_prompt, strength, guidance_scale, num_inference_steps, width, height, seed):
57
  if seed == -1:
58
  seed = generate_new_seed()
59
+ generator = torch.Generator().manual_seed(int(seed))
60
  global text2img, img2img
61
  start_time = time.time()
62
  if image_input is None:
 
65
  image = img2img(image=image_input, strength=0.75, prompt=prompt, negative_prompt=negative_prompt, guidance_scale=guidance_scale, num_inference_steps=num_inference_steps, width=width, height=height, num_images_per_prompt=1, generator=generator).images[0]
66
  minutes, seconds = divmod(round(time.time() - start_time), 60)
67
  return image, f"{minutes:02d}:{seconds:02d}"
68
def Loading(model_name, is_xl, is_cuda):
    """Load the text-to-image and image-to-image pipelines for *model_name*.

    Rebinds the module-level ``text2img`` / ``img2img`` globals and returns
    ``model_name`` unchanged (so the Gradio click handler can echo it back
    into the textbox).

    Parameters:
        model_name: HF Hub id or local path of the diffusion model.
        is_xl:     True to load the SDXL pipeline classes instead of SD 1.x.
        is_cuda:   True to load on GPU with fp16 weights and xformers.
    """
    global text2img, img2img
    # BUG FIX: the original keyed device, fp16 loading, and xformers on
    # `is_xl`, leaving the `is_cuda` parameter unused. Device placement and
    # the fp16/xformers fast path depend on having CUDA, not on the model
    # family, so they must follow `is_cuda`.
    device = "cuda" if is_cuda else "cpu"
    # `is_xl` only selects which pipeline architecture to instantiate.
    pipeline_class = StableDiffusionXLPipeline if is_xl else StableDiffusionPipeline
    if is_cuda:
        # fp16 variant + xformers attention are CUDA-only optimizations.
        text2img = pipeline_class.from_pretrained(model_name, torch_dtype=torch.float16, variant="fp16", use_safetensors=True).to(device)
        text2img.enable_xformers_memory_efficient_attention()
        text2img.vae.enable_xformers_memory_efficient_attention()
    else:
        text2img = pipeline_class.from_pretrained(model_name, use_safetensors=True).to(device)
    # Disable the NSFW checker (original behavior); harmless no-op for SDXL,
    # which has no safety checker component.
    text2img.safety_checker = None
    # Share weights/components with the img2img pipeline to avoid a second load.
    img2img = (StableDiffusionXLImg2ImgPipeline if is_xl else StableDiffusionImg2ImgPipeline)(**text2img.components)
    return model_name
82
  with gr.Blocks() as demo:
83
  with gr.Row():
84
+ model_name = gr.Textbox(value="nota-ai/bk-sdm-tiny-2m", label=translations[language]['model_name'])
85
+ is_xl = gr.Checkbox(label="SDXL")
86
+ is_cuda = gr.Checkbox(label="cuda", value=torch.cuda.is_available())
87
  loading = gr.Button(translations[language]['loading'])
88
  set_language = gr.Dropdown(list(translations.keys()), label="Language", value=language)
89
  with gr.Row():
 
109
  text_output = gr.Textbox(label="time")
110
  set_seed.click(generate_new_seed, None, seed)
111
  generate.click(Generate, [image_input, prompt, negative_prompt, strength, guidance_scale, num_inference_steps, width, height, seed], [image_output, text_output])
112
+ loading.click(Loading, [model_name, is_xl, is_cuda], model_name)
113
+ set_language.change(update_language, set_language, [model_name, loading, image_input, prompt, negative_prompt, generate, strength, guidance_scale, num_inference_steps, width, height, seed])
114
  demo.queue().launch()