gstaff commited on
Commit
d8e142e
·
1 Parent(s): 62a8602

Replace DALLE-mini with Stable Diffusion pipeline.

Browse files
Files changed (2) hide show
  1. app.py +32 -23
  2. requirements.txt +0 -0
app.py CHANGED
@@ -12,9 +12,10 @@ from fastai.callback.core import Callback
12
  from fastai.learner import *
13
  from fastai.torch_core import TitledStr
14
  from html2image import Html2Image
15
- from min_dalle import MinDalle
16
  from torch import tensor, Tensor, float16, float32
17
  from torch.distributions import Transform
 
18
 
19
  # These utility functions need to be in main (or otherwise where created) because fastai loads from that module, see:
20
  # https://docs.fast.ai/learner.html#load_learner
@@ -50,34 +51,41 @@ class DropOutput(Callback):
50
 
51
  # initialize only once
52
  # Takes about 2 minutes (126 seconds) to generate an image in Huggingface spaces on CPU
53
- model = MinDalle(
54
- models_root='./pretrained',
55
- dtype=float32,
56
- device='cpu',
57
- is_mega=True,
58
- is_reusable=True
59
- )
 
 
 
 
60
 
61
 
62
  def gen_image(prompt):
 
63
  # See https://huggingface.co/spaces/pootow/min-dalle/blob/main/app.py
64
  # Hugging Space faces seems to run out of memory if grads are not disabled
65
- torch.set_grad_enabled(False)
66
  print(f'RUNNING gen_image with prompt: {prompt}')
67
- images = model.generate_images(
68
- text=prompt,
69
- seed=-1,
70
- grid_size=1, # grid size above 2 causes out of memory on 12 GB 3080Ti; grid size 2 gives 4 images
71
- is_seamless=False,
72
- temperature=1,
73
- top_k=256,
74
- supercondition_factor=16,
75
- is_verbose=True
76
- )
 
77
  print('COMPLETED GENERATION')
78
- images = images.to('cpu').numpy()
79
- images = images.astype(np.uint8)
80
- return Image.fromarray(images[0])
 
81
 
82
 
83
  gpu = False
@@ -326,7 +334,8 @@ x = gr.components.Textbox()
326
  iface = gr.Interface(title="MonsterGen", theme="default", description=app_description, fn=run, inputs=[input_box],
327
  outputs=[output_monster_card, output_text_box, output_monster_image, output_monster_html])
328
  iface.launch()
329
- # TODO: Add examples
 
330
  # API works, assuming query takes no longer than 30 seconds (504 gateway timeout)
331
  # Looks like API page improvements are in progress: https://github.com/gradio-app/gradio/issues/1325
332
  # Example code below:
 
12
  from fastai.learner import *
13
  from fastai.torch_core import TitledStr
14
  from html2image import Html2Image
15
+ # from min_dalle import MinDalle
16
  from torch import tensor, Tensor, float16, float32
17
  from torch.distributions import Transform
18
+ from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
19
 
20
  # These utility functions need to be in main (or otherwise where created) because fastai loads from that module, see:
21
  # https://docs.fast.ai/learner.html#load_learner
 
51
 
52
  # initialize only once
53
  # Takes about 2 minutes (126 seconds) to generate an image in Huggingface spaces on CPU
54
+ # NOTE as of 2022-11-13 min-dalle is broken, switch to using a stable diffusion model for images
55
+ # model = MinDalle(
56
+ # models_root='./pretrained',
57
+ # dtype=float32,
58
+ # device='cpu',
59
+ # is_mega=True,
60
+ # is_reusable=True
61
+ # )
62
+ pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", custom_pipeline="stable_diffusion_mega", torch_dtype=torch.float16, revision="fp16")
63
+
64
+ # pipeline.to("cuda")
65
 
66
 
67
  def gen_image(prompt):
68
+ prompt = f"{prompt}, fantasy painting by Greg Rutkowski"
69
  # See https://huggingface.co/spaces/pootow/min-dalle/blob/main/app.py
70
  # Hugging Space faces seems to run out of memory if grads are not disabled
71
+ # torch.set_grad_enabled(False)
72
  print(f'RUNNING gen_image with prompt: {prompt}')
73
+ images = pipeline.text2img(prompt, width=512, height=512).images
74
+ # images = model.generate_images(
75
+ # text=prompt,
76
+ # seed=-1,
77
+ # grid_size=1, # grid size above 2 causes out of memory on 12 GB 3080Ti; grid size 2 gives 4 images
78
+ # is_seamless=False,
79
+ # temperature=1,
80
+ # top_k=256,
81
+ # supercondition_factor=16,
82
+ # is_verbose=True
83
+ # )
84
  print('COMPLETED GENERATION')
85
+ # images = images.to('cpu').numpy()
86
+ # images = images.astype(np.uint8)
87
+ # return Image.fromarray(images[0])
88
+ return images[0]
89
 
90
 
91
  gpu = False
 
334
  iface = gr.Interface(title="MonsterGen", theme="default", description=app_description, fn=run, inputs=[input_box],
335
  outputs=[output_monster_card, output_text_box, output_monster_image, output_monster_html])
336
  iface.launch()
337
+ # TODO: Add examples, larger language model?, document process, log silences, "Passives" => "Traits", log timestamps
338
+ # Fine tune dalle-mini? https://blog.paperspace.com/dalle-mini/
339
  # API works, assuming query takes no longer than 30 seconds (504 gateway timeout)
340
  # Looks like API page improvements are in progress: https://github.com/gradio-app/gradio/issues/1325
341
  # Example code below:
requirements.txt CHANGED
Binary files a/requirements.txt and b/requirements.txt differ