trinath3 committed on
Commit
8d83efa
·
verified ·
1 Parent(s): 932ccfc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -19
app.py CHANGED
@@ -181,28 +181,30 @@ def benchmark(Batch, Heads, N, D_head, provider):
181
  return tflops(ms), tflops(max_ms), tflops(min_ms)
182
 
183
 
184
-
185
- @spaces.GPU(duration=180) # Triton benchmarks can take a minute
186
- def run_benchmark():
187
- # Ensure we are in a clean directory for images
188
- output_dir = "./plots"
189
- if not os.path.exists(output_dir):
190
- os.makedirs(output_dir)
191
-
192
- # Run the triton benchmark
193
- # This will generate several .png files in the save_path
194
- bench_flash_attention.run(save_path=output_dir, print_data=True)
195
 
196
- # Collect the generated images
197
- images = [os.path.join(output_dir, f) for f in os.listdir(output_dir) if f.endswith('.png')]
198
  return images
199
 
200
- # Gradio Interface
201
  with gr.Blocks() as demo:
202
- gr.Markdown("# Triton Fused Attention Benchmark on ZeroGPU")
203
- run_btn = gr.Button("Run Benchmark")
204
- out_gallery = gr.Gallery(label="Performance Plots")
205
-
206
- run_btn.click(fn=run_benchmark, outputs=out_gallery)
 
 
207
 
208
  demo.launch()
 
181
  return tflops(ms), tflops(max_ms), tflops(min_ms)
182
 
183
 
184
+ # 2. --- WRAP THE RUN COMMAND IN A DECORATED FUNCTION ---
185
+ @spaces.GPU(duration=150) # High duration for Triton compilation + Benchmarking
186
+ def start_benchmarking():
187
+ # Triton saves plots to the current directory by default
188
+ save_path = "./plots"
189
+ if not os.path.exists(save_path):
190
+ os.makedirs(save_path)
191
+
192
+ # Run your original benchmark function
193
+ # Note: Ensure bench_flash_attention is defined above this
194
+ bench_flash_attention.run(save_path=save_path, print_data=True)
195
 
196
+ # Find the .png files generated by Triton
197
+ images = [os.path.join(save_path, f) for f in os.listdir(save_path) if f.endswith('.png')]
198
  return images
199
 
200
+ # 3. --- CREATE THE GRADIO GUI TO KEEP THE SPACE ALIVE ---
201
  with gr.Blocks() as demo:
202
+ gr.Markdown("# Triton Attention Benchmark")
203
+ gr.Markdown("Click the button below to trigger the ZeroGPU and run the Triton benchmark.")
204
+
205
+ run_btn = gr.Button("Run Benchmark (H100/H200)", variant="primary")
206
+ plot_gallery = gr.Gallery(label="Generated Performance Plots", columns=2)
207
+
208
+ run_btn.click(fn=start_benchmarking, outputs=plot_gallery)
209
 
210
  demo.launch()