aromidvar1355 commited on
Commit
7ab37b4
·
verified ·
1 Parent(s): b4debce

Update core/plot.py

Browse files
Files changed (1) hide show
  1. core/plot.py +13 -0
core/plot.py CHANGED
@@ -26,4 +26,17 @@ def plot_metrics(result):
26
  metrics = result['metrics']
27
  sns.barplot(x=list(metrics.keys()), y=list(metrics.values()), ax=ax)
28
  ax.set_title("Error Metrics")
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  return fig
 
26
  metrics = result['metrics']
27
  sns.barplot(x=list(metrics.keys()), y=list(metrics.values()), ax=ax)
28
  ax.set_title("Error Metrics")
29
+ return fig
30
+
31
+ def plot_loss_curve(result):
32
+ train_losses = result['train_loss']
33
+ val_losses = result.get('val_loss', [])
34
+ fig, ax = plt.subplots()
35
+ ax.plot(train_losses, label='Train Loss')
36
+ if val_losses:
37
+ ax.plot(val_losses, label='Validation Loss')
38
+ ax.set_title('Loss Curve')
39
+ ax.set_xlabel('Epoch')
40
+ ax.set_ylabel('Loss')
41
+ ax.legend()
42
  return fig