Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -181,28 +181,30 @@ def benchmark(Batch, Heads, N, D_head, provider):
|
|
| 181 |
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
| 182 |
|
| 183 |
|
| 184 |
-
|
| 185 |
-
@spaces.GPU(duration=
|
| 186 |
-
def
|
| 187 |
-
#
|
| 188 |
-
|
| 189 |
-
if not os.path.exists(
|
| 190 |
-
os.makedirs(
|
| 191 |
-
|
| 192 |
-
# Run
|
| 193 |
-
#
|
| 194 |
-
bench_flash_attention.run(save_path=
|
| 195 |
|
| 196 |
-
#
|
| 197 |
-
images = [os.path.join(
|
| 198 |
return images
|
| 199 |
|
| 200 |
-
#
|
| 201 |
with gr.Blocks() as demo:
|
| 202 |
-
gr.Markdown("# Triton
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
|
|
|
|
|
|
| 207 |
|
| 208 |
demo.launch()
|
|
|
|
| 181 |
return tflops(ms), tflops(max_ms), tflops(min_ms)
|
| 182 |
|
| 183 |
|
| 184 |
+
# 2. --- WRAP THE RUN COMMAND IN A DECORATED FUNCTION ---
|
| 185 |
+
@spaces.GPU(duration=150)  # High duration for Triton compilation + Benchmarking
def start_benchmarking():
    """Run the flash-attention benchmark and return the generated plot files.

    Ensures the ``./plots`` directory exists, invokes
    ``bench_flash_attention.run`` with that directory as ``save_path``
    (and ``print_data=True``), then collects every ``.png`` file found there.

    Returns:
        list[str]: Paths (``save_path`` joined with the file name) of all
        ``.png`` files in the plot directory, sorted by file name so the
        result order is stable across runs.
    """
    # Plots are written to a dedicated directory instead of the working dir.
    save_path = "./plots"
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(save_path, exist_ok=True)

    # NOTE: bench_flash_attention must be defined earlier in this module.
    bench_flash_attention.run(save_path=save_path, print_data=True)

    # os.listdir() returns entries in arbitrary order; sort for a
    # deterministic result. NOTE(review): .png files left over from
    # earlier runs in the same directory are returned as well.
    images = [
        os.path.join(save_path, name)
        for name in sorted(os.listdir(save_path))
        if name.endswith(".png")
    ]
    return images
|
| 199 |
|
| 200 |
+
# 3. --- CREATE THE GRADIO GUI TO KEEP THE SPACE ALIVE ---
|
| 201 |
with gr.Blocks() as demo:
    # Page heading followed by a one-line usage hint.
    gr.Markdown(value="# Triton Attention Benchmark")
    gr.Markdown(
        value="Click the button below to trigger the ZeroGPU and run the Triton benchmark."
    )

    # Trigger button and the gallery that displays the returned image paths.
    run_btn = gr.Button(
        value="Run Benchmark (H100/H200)",
        variant="primary",
    )
    plot_gallery = gr.Gallery(
        label="Generated Performance Plots",
        columns=2,
    )

    # The callback takes no inputs; its return value fills the gallery.
    run_btn.click(
        fn=start_benchmarking,
        inputs=None,
        outputs=plot_gallery,
    )
|
| 209 |
|
| 210 |
demo.launch()
|