Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -77,8 +77,10 @@ def plot_pens(tflpos_card, utilization, num_gps, training_days):
|
|
| 77 |
plt.axvline(ns[best], color='red')
|
| 78 |
plt.xlabel('model size')
|
| 79 |
plt.ylabel('loss')
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
return
|
| 82 |
|
| 83 |
|
| 84 |
if __name__ == "__main__":
|
|
@@ -86,12 +88,12 @@ if __name__ == "__main__":
|
|
| 86 |
fn=plot_pens,
|
| 87 |
inputs=[
|
| 88 |
gr.Textbox(label="TFLOP/s pre Card",value="40"),
|
| 89 |
-
gr.Slider(label="
|
| 90 |
-
gr.Textbox(label="Number of cards"
|
| 91 |
-
gr.Textbox(label="Training Days"
|
| 92 |
],
|
| 93 |
outputs=[
|
| 94 |
-
gr.
|
| 95 |
gr.Label(label="Total Compute Budget"),
|
| 96 |
gr.Label(label="Estimated Final Loss"),
|
| 97 |
gr.Label(label="Optimal Model Size"),
|
|
@@ -100,5 +102,5 @@ if __name__ == "__main__":
|
|
| 100 |
title="Compute-Optimal Model Estimator",
|
| 101 |
description=description,
|
| 102 |
article=article,
|
| 103 |
-
live=
|
| 104 |
).launch()
|
|
|
|
| 77 |
plt.axvline(ns[best], color='red')
|
| 78 |
plt.xlabel('model size')
|
| 79 |
plt.ylabel('loss')
|
| 80 |
+
fig.savefig("/tmp/tmp.jpg")
|
| 81 |
+
plt.close()
|
| 82 |
|
| 83 |
+
return "/tmp/tmp.jpg", c, round(losses[best], 3), best_model_size, best_dataset_size
|
| 84 |
|
| 85 |
|
| 86 |
if __name__ == "__main__":
|
|
|
|
| 88 |
fn=plot_pens,
|
| 89 |
inputs=[
|
| 90 |
gr.Textbox(label="TFLOP/s pre Card",value="40"),
|
| 91 |
+
gr.Slider(label="GPU Utilization", minimum=0, maximum=1, step=0.01,value=0.25),
|
| 92 |
+
gr.Textbox(label="Number of cards"),
|
| 93 |
+
gr.Textbox(label="Training Days")
|
| 94 |
],
|
| 95 |
outputs=[
|
| 96 |
+
gr.Image(label="Estimated Loss"),
|
| 97 |
gr.Label(label="Total Compute Budget"),
|
| 98 |
gr.Label(label="Estimated Final Loss"),
|
| 99 |
gr.Label(label="Optimal Model Size"),
|
|
|
|
| 102 |
title="Compute-Optimal Model Estimator",
|
| 103 |
description=description,
|
| 104 |
article=article,
|
| 105 |
+
live=True
|
| 106 |
).launch()
|