|
|
import re
import sys

import matplotlib.pyplot as plt
import numpy as np
|
|
|
|
|
|
|
|
|
# Matplotlib global settings so the Chinese title/labels below can render:
#   * 'font.sans-serif' -> use the SimHei font for sans-serif text;
#   * 'axes.unicode_minus' -> draw minus signs with the ASCII hyphen instead of
#     U+2212 (presumably because SimHei lacks that glyph — confirm).
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
|
|
|
|
|
|
|
|
|
# One log record: "Epoch: <int>" followed later on the same line by
# "Loss: <number>". The loss group accepts integers and scientific notation
# so that e.g. "Loss: 2.5e-03" is read as 0.0025 instead of being cut to 2.5.
# Compiled once here rather than looked up for every line of the log.
_LOG_LINE_RE = re.compile(
    r'Epoch:\s*(\d+)/*.*Loss:\s*(\d+(?:\.\d+)?(?:[eE][-+]?\d+)?)'
)


def parse_log_file(file_path):
    """Extract (epoch, loss) pairs from a training log file.

    Every line that contains ``Epoch: <int>`` followed later on the same line
    by ``Loss: <number>`` contributes one entry; all other lines are ignored.
    Because the gap between the two fields is matched greedily, a line with
    several ``Loss:`` fields yields the value of the last one.

    Args:
        file_path: Path of the log file. It is read as UTF-8; undecodable
            bytes are replaced so a stray byte does not abort the whole parse.

    Returns:
        A tuple ``(epochs, losses)`` of two equally long lists, in file order:
        epoch numbers as ``int`` and loss values as ``float``. Both lists are
        empty when no line matches.

    Raises:
        FileNotFoundError: If *file_path* does not exist.
    """
    epochs = []
    losses = []
    with open(file_path, 'r', encoding='utf-8', errors='replace') as f:
        for line in f:
            match = _LOG_LINE_RE.search(line)
            if match is None:
                continue
            epochs.append(int(match.group(1)))
            losses.append(float(match.group(2)))
    return epochs, losses
|
|
|
|
|
|
|
|
|
# Log to plot: the first command-line argument if given, otherwise the
# original hard-coded path (which only exists on the author's machine).
DEFAULT_LOG_FILE = "E:\\智能语音处理系统\\Noise-suppression-and-speech-recognition-systems-master\\WaveRNNModel\\checkpoints\\ljspeech_lsa_smooth_attention.tacotron\\log_test.txt"
log_file = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_LOG_FILE

try:
    epochs_read, losses = parse_log_file(log_file)
except FileNotFoundError:
    print(f"错误:文件 {log_file} 不存在,请检查路径!", file=sys.stderr)
    sys.exit(1)  # non-zero status so callers/shells can detect the failure
except Exception as e:  # top-level script boundary: report any other failure
    print(f"解析文件时出错: {e}", file=sys.stderr)
    sys.exit(1)

if not losses:
    # Without this check the plotting code fails later on min()/argmin of an
    # empty sequence with an unhelpful traceback.
    print(f"错误:文件 {log_file} 中未找到任何 Epoch/Loss 记录!", file=sys.stderr)
    sys.exit(1)

print(epochs_read)

# NOTE(review): the x axis uses the 0-based position of each matched log line,
# not the epoch numbers parsed from the log (those are only printed above).
# Presumably deliberate, e.g. for logs with several lines per epoch — confirm.
epochs = np.arange(len(epochs_read))

print(epochs)
|
|
|
|
|
|
|
|
|
# Draw the training-loss curve (x = `epochs`, y = `losses`) and mark its minimum.
plt.figure(figsize=(10, 6))
plt.plot(epochs, losses, 'b-', linewidth=2, label='训练损失')

plt.title('训练损失随轮次变化曲线', fontsize=14)
plt.xlabel('训练轮次 (Epoch)', fontsize=12)
plt.ylabel('损失值 (Loss)', fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.legend()

# Only annotate when there is data: argmin of an empty sequence raises.
if len(losses) > 0:
    min_idx = int(np.argmin(losses))  # first occurrence, like list.index(min(...))
    min_loss = losses[min_idx]
    plt.annotate(
        # '.4g' keeps small losses readable ('.3f' would show 0.0004 as 0.000).
        f'最低损失: {min_loss:.4g}',
        xy=(epochs[min_idx], min_loss),
        # Offset the label in screen points instead of data units: a fixed
        # "-3 epochs / +0.1 loss" offset lands far from the point, or outside
        # the axes, when losses are much smaller than 1.
        xytext=(-40, 30),
        textcoords='offset points',
        arrowprops=dict(arrowstyle='->', color='red'),
        fontsize=10,
        color='red',
    )

plt.tight_layout()
plt.show()