Spaces:
Sleeping
Sleeping
Update core/plot.py
Browse files- core/plot.py +13 -0
core/plot.py
CHANGED
|
@@ -26,4 +26,17 @@ def plot_metrics(result):
|
|
| 26 |
metrics = result['metrics']
|
| 27 |
sns.barplot(x=list(metrics.keys()), y=list(metrics.values()), ax=ax)
|
| 28 |
ax.set_title("Error Metrics")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
return fig
|
|
|
|
| 26 |
metrics = result['metrics']
|
| 27 |
sns.barplot(x=list(metrics.keys()), y=list(metrics.values()), ax=ax)
|
| 28 |
ax.set_title("Error Metrics")
|
| 29 |
+
return fig
|
| 30 |
+
|
| 31 |
+
def plot_loss_curve(result):
|
| 32 |
+
train_losses = result['train_loss']
|
| 33 |
+
val_losses = result.get('val_loss', [])
|
| 34 |
+
fig, ax = plt.subplots()
|
| 35 |
+
ax.plot(train_losses, label='Train Loss')
|
| 36 |
+
if val_losses:
|
| 37 |
+
ax.plot(val_losses, label='Validation Loss')
|
| 38 |
+
ax.set_title('Loss Curve')
|
| 39 |
+
ax.set_xlabel('Epoch')
|
| 40 |
+
ax.set_ylabel('Loss')
|
| 41 |
+
ax.legend()
|
| 42 |
return fig
|