Upload utils.py
Browse files
utils.py
CHANGED
|
@@ -207,6 +207,7 @@ def evaluate_model(model, test_loader, device):
|
|
| 207 |
|
| 208 |
# Visualization
|
| 209 |
import matplotlib.cm as cm
|
|
|
|
| 210 |
def plot_metrics(test_f1_scores, input_types, n_train=None, flag=0):
|
| 211 |
"""
|
| 212 |
Plots the F1-score over epochs or number of training samples.
|
|
@@ -222,11 +223,14 @@ def plot_metrics(test_f1_scores, input_types, n_train=None, flag=0):
|
|
| 222 |
markers = ['o', 's', 'D', '^', 'v', 'P', '*', 'X', 'h'] # Different markers for curves
|
| 223 |
|
| 224 |
for r in range(test_f1_scores.shape[0]):
|
| 225 |
-
color = colors(r / (test_f1_scores.shape[0] - 1)) # Normalize color index
|
| 226 |
marker = markers[r % len(markers)] # Cycle through markers
|
| 227 |
if flag == 0:
|
| 228 |
-
|
| 229 |
-
|
|
|
|
|
|
|
|
|
|
| 230 |
else:
|
| 231 |
plt.plot(n_train, test_f1_scores[r], linewidth=2, marker=marker, markersize=6, markeredgewidth=1.5,
|
| 232 |
markeredgecolor=color, markerfacecolor='none', color=color, label=f"{input_types[r]}")
|
|
|
|
| 207 |
|
| 208 |
# Visualization
|
| 209 |
import matplotlib.cm as cm
|
| 210 |
+
|
| 211 |
def plot_metrics(test_f1_scores, input_types, n_train=None, flag=0):
|
| 212 |
"""
|
| 213 |
Plots the F1-score over epochs or number of training samples.
|
|
|
|
| 223 |
markers = ['o', 's', 'D', '^', 'v', 'P', '*', 'X', 'h'] # Different markers for curves
|
| 224 |
|
| 225 |
for r in range(test_f1_scores.shape[0]):
|
| 226 |
+
color = colors(0.5 if test_f1_scores.shape[0] == 1 else r / (test_f1_scores.shape[0] - 1)) # Normalize color index
|
| 227 |
marker = markers[r % len(markers)] # Cycle through markers
|
| 228 |
if flag == 0:
|
| 229 |
+
if test_f1_scores.shape[0] == 1:
|
| 230 |
+
plt.plot(test_f1_scores[r], linewidth=2, color=color, label=f"{input_types[r]}")
|
| 231 |
+
else:
|
| 232 |
+
plt.plot(test_f1_scores[r], linewidth=2, marker=marker, markersize=5, markeredgewidth=1.5,
|
| 233 |
+
markeredgecolor=color, color=color, label=f"{input_types[r]}")
|
| 234 |
else:
|
| 235 |
plt.plot(n_train, test_f1_scores[r], linewidth=2, marker=marker, markersize=6, markeredgewidth=1.5,
|
| 236 |
markeredgecolor=color, markerfacecolor='none', color=color, label=f"{input_types[r]}")
|