1inkusFace committed on
Commit
a0eb807
·
verified ·
1 Parent(s): 7ea2c56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -76
app.py CHANGED
@@ -47,46 +47,6 @@ from image_gen_aux import UpscaleWithModel
47
 
48
  from diffusers.models.attention_processor import AttnProcessor2_0
49
  from diffusers.models.attention_processor import Attention
50
- from kernels import get_kernel
51
- vllm_flash_attn3 = get_kernel("kernels-community/vllm-flash-attn3")
52
-
53
- class FlashAttentionProcessor(AttnProcessor2_0):
54
- """
55
- A custom attention processor that uses a pre-compiled Flash Attention 3 kernel.
56
- It inherits from AttnProcessor2_0, which is compatible with PyTorch 2.x attention.
57
- """
58
- def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, **kwargs):
59
- # The 'attn' argument is the parent Attention module, giving access to its parameters.
60
- # The implementation from the kernels library expects query, key, and value in a
61
- # specific format (Batch, Sequence, Heads, Dim_Head), so we must reshape accordingly.
62
-
63
- query = attn.to_q(hidden_states)
64
- encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
65
- key = attn.to_k(encoder_hidden_states)
66
- value = attn.to_v(encoder_hidden_states)
67
-
68
- scale = attn.scale
69
- query = query * scale
70
-
71
- b, t, c = query.shape
72
- h = attn.heads
73
- d = c // h
74
-
75
- # Reshape to (Batch, Heads, Sequence, Dim_Head) for the Flash Attention kernel
76
- q_reshaped = query.reshape(b, t, h, d).permute(0, 2, 1, 3)
77
- k_reshaped = key.reshape(b, t, h, d).permute(0, 2, 1, 3)
78
- v_reshaped = value.reshape(b, t, h, d).permute(0, 2, 1, 3)
79
- out_reshaped = torch.empty_like(q_reshaped)
80
-
81
- # Call the pre-compiled kernel
82
- vllm_flash_attn3.attention(q_reshaped, k_reshaped, v_reshaped, out_reshaped)
83
-
84
- # Reshape output back to (Batch, Sequence, Heads * Dim_Head)
85
- out = out_reshaped.permute(0, 2, 1, 3).reshape(b, t, c)
86
-
87
- out = attn.to_out(out)
88
- return out
89
-
90
 
91
 
92
  # --- GCS Configuration ---
@@ -123,49 +83,15 @@ def upload_to_gcs(image_object, filename):
123
 
124
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
125
 
126
- import torch.export
127
-
128
  @spaces.GPU(duration=120)
129
  def compile_transformer():
130
  with spaces.aoti_capture(pipe.transformer) as call:
131
- # This run captures the structure of the inputs in call.args and call.kwargs
132
- pipe(
133
- "A majestic, ancient Egyptian Sphinx stands sentinel in a large, clear pool under a bright, golden desert sun. Around its weathered stone base, several sleek, playful dolphins gracefully navigate the turquoise waters. The surrounding environment features lush, exotic papyrus plants and distant pyramids under a cloudless sky, conveying a sense of timeless wonder and serene majesty."
134
- )
135
-
136
- # --- START OF CHANGE ---
137
-
138
- dynamic_shapes = {
139
- # Give the two different sequence lengths unique names
140
- "hidden_states": {
141
- 0: torch.export.Dim("batch_size"),
142
- 1: torch.export.Dim("image_sequence_length"), # <-- Unique name
143
- },
144
- "encoder_hidden_states": {
145
- 0: torch.export.Dim("batch_size"),
146
- 1: torch.export.Dim("text_sequence_length"), # <-- Unique name
147
- },
148
-
149
- # The rest remains the same
150
- "pooled_projections": {
151
- 0: torch.export.Dim("batch_size"),
152
- },
153
- "timestep": {
154
- 0: torch.export.Dim("batch_size"),
155
- },
156
- "joint_attention_kwargs": None,
157
- "return_dict": None,
158
- }
159
-
160
- # --- END OF CHANGE ---
161
-
162
  exported = torch.export.export(
163
  pipe.transformer,
164
  args=call.args,
165
  kwargs=call.kwargs,
166
- dynamic_shapes=dynamic_shapes,
167
  )
168
-
169
  return spaces.aoti_compile(exported)
170
 
171
  def load_model():
@@ -185,7 +111,7 @@ def load_model():
185
  upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(device)
186
  return pipe, upscaler_2
187
 
188
- #fa_processor = FlashAttentionProcessor()
189
 
190
  pipe, upscaler_2 = load_model()
191
 
 
47
 
48
  from diffusers.models.attention_processor import AttnProcessor2_0
49
  from diffusers.models.attention_processor import Attention
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  # --- GCS Configuration ---
 
83
 
84
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
85
 
 
 
86
  @spaces.GPU(duration=120)
87
  def compile_transformer():
88
  with spaces.aoti_capture(pipe.transformer) as call:
89
+ pipe("A majestic, ancient Egyptian Sphinx stands sentinel in a large, clear pool under a bright, golden desert sun. Around its weathered stone base, several sleek, playful dolphins gracefully navigate the turquoise waters. The surrounding environment features lush, exotic papyrus plants and distant pyramids under a cloudless sky, conveying a sense of timeless wonder and serene majesty.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  exported = torch.export.export(
91
  pipe.transformer,
92
  args=call.args,
93
  kwargs=call.kwargs,
 
94
  )
 
95
  return spaces.aoti_compile(exported)
96
 
97
  def load_model():
 
111
  upscaler_2 = UpscaleWithModel.from_pretrained("Kim2091/ClearRealityV1").to(device)
112
  return pipe, upscaler_2
113
 
114
+ fa_processor = FlashAttentionProcessor()
115
 
116
  pipe, upscaler_2 = load_model()
117