multimodalart HF Staff commited on
Commit
d3c463a
·
1 Parent(s): aa8aa22

Add Nucleus-Image generation app with grouped_mm polyfill

Browse files

Gradio app for NucleusAI/Nucleus-Image on ZeroGPU with text KV cache support.
Includes polyfill for torch.nn.functional.grouped_mm using the aten kernel.

Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -6,6 +6,12 @@ import spaces
6
  import torch
7
  from diffusers import DiffusionPipeline
8
 
 
 
 
 
 
 
9
  MODEL_NAME = "NucleusAI/Nucleus-Image"
10
  MAX_SEED = np.iinfo(np.int32).max
11
 
@@ -89,14 +95,13 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
89
  """
90
  )
91
 
92
- with gr.Row():
93
- prompt = gr.Textbox(
94
- label="Prompt",
95
- placeholder="Describe the image you want to generate...",
96
- lines=3,
97
- scale=4,
98
- )
99
- run_btn = gr.Button("Generate", variant="primary", scale=1)
100
 
101
  result = gr.Image(label="Result", show_label=False, format="png")
102
 
@@ -111,7 +116,7 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
111
  label="Inference Steps", minimum=10, maximum=80, step=1, value=50
112
  )
113
  guidance_scale = gr.Slider(
114
- label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.5, value=8.0
115
  )
116
  with gr.Row():
117
  seed = gr.Slider(
 
6
  import torch
7
  from diffusers import DiffusionPipeline
8
 
9
+ # Polyfill: expose torch.ops.aten._grouped_mm as F.grouped_mm if not already present
10
+ if not hasattr(torch.nn.functional, "grouped_mm"):
11
+ def _grouped_mm(input, mat2, *, offs=None, bias=None, out_dtype=None):
12
+ return torch.ops.aten._grouped_mm(input, mat2, offs=offs, bias=bias, out_dtype=out_dtype)
13
+ torch.nn.functional.grouped_mm = _grouped_mm
14
+
15
  MODEL_NAME = "NucleusAI/Nucleus-Image"
16
  MAX_SEED = np.iinfo(np.int32).max
17
 
 
95
  """
96
  )
97
 
98
+ prompt = gr.Textbox(
99
+ label="Prompt",
100
+ placeholder="Describe the image you want to generate...",
101
+ lines=3,
102
+ scale=4,
103
+ )
104
+ run_btn = gr.Button("Generate", variant="primary", scale=1)
 
105
 
106
  result = gr.Image(label="Result", show_label=False, format="png")
107
 
 
116
  label="Inference Steps", minimum=10, maximum=80, step=1, value=50
117
  )
118
  guidance_scale = gr.Slider(
119
+ label="Guidance Scale", minimum=1.0, maximum=15.0, step=0.5, value=3.5
120
  )
121
  with gr.Row():
122
  seed = gr.Slider(