| import os | |
| import json | |
| import pandas as pd | |
| import yaml | |
| import seaborn as sns | |
| import matplotlib.pyplot as plt | |
| from inference import get_latest_checkpoint | |
def process_loss(loss, final_loss):
    """Append one log record's epoch and metrics to the accumulator lists.

    Args:
        loss: A single log record (mapping). Its ``"epoch"`` value is
            truncated to ``int``. A record without ``"epoch"`` is recorded
            as epoch 0, the same default ``loss_function`` applies when it
            filters records, instead of raising ``KeyError``.
        final_loss: Accumulator mapping of name -> list, mutated in place.
            It must contain an ``"epoch"`` list; metric lists are optional.

    Returns:
        None. All results are written into ``final_loss``.
    """
    epoch = int(loss.get("epoch", 0))
    final_loss["epoch"].append(epoch)

    for key in ("loss", "eval_loss", "eval_rouge1", "eval_rouge2"):
        # A metric is recorded only when the record reports it and the
        # accumulator tracks it; anything else is skipped on purpose.
        if key in loss and key in final_loss:
            final_loss[key].append(loss[key])
def loss_function(losses):
    """Collect epoch-boundary metrics from a sequence of log records.

    Only records whose ``"epoch"`` value is a whole number are kept (a
    record without ``"epoch"`` is treated as epoch 0 by the filter). Each
    kept record is passed to ``process_loss``, which appends its epoch and
    whichever metrics it carries.

    Args:
        losses: Iterable of log-record mappings.

    Returns:
        A dict with the keys ``"epoch"``, ``"loss"``, ``"eval_loss"``,
        ``"eval_rouge1"`` and ``"eval_rouge2"``, each mapping to a list.
        ``"epoch"`` holds each distinct epoch once, in the order it was
        first seen; the metric lists hold values in log order.
    """
    final_loss = {
        "epoch": [],
        "loss": [],
        "eval_loss": [],
        "eval_rouge1": [],
        "eval_rouge2": [],
    }

    for loss_steps in losses:
        if float(loss_steps.get("epoch", 0)).is_integer():
            process_loss(loss_steps, final_loss)

    # Several records can share one epoch, so the epoch list is
    # de-duplicated. dict.fromkeys keeps first-seen order; a set() does not
    # guarantee any order and could shuffle the epochs relative to the
    # metric lists.
    # NOTE(review): the metric lists are not aligned or padded here; if they
    # end up with different lengths, building a DataFrame from this dict
    # will fail -- confirm the logs always carry one value per epoch.
    final_loss["epoch"] = list(dict.fromkeys(final_loss["epoch"]))
    return final_loss
def plot_loss(data, output_dir):
    """Plot every metric in *data* against epoch and save the chart as a PNG.

    Args:
        data: Mapping convertible to a ``pandas.DataFrame``; it must have an
            ``"epoch"`` column. Every other column is drawn as one line.
        output_dir: Directory that receives ``metrics_vs_epoch.png``. It is
            created if it does not exist.

    Returns:
        None. The figure is written to disk and then closed.
    """
    df = pd.DataFrame(data)
    # Long format: one row per (epoch, metric, value) so seaborn can colour
    # the lines by metric.
    df_melted = pd.melt(df, id_vars=['epoch'], var_name='metric', value_name='value')

    # savefig does not create missing directories. An empty output_dir means
    # the current directory, which needs no creation.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=(10, 6))
    try:
        sns.lineplot(data=df_melted, x='epoch', y='value', hue='metric', marker='o')
        plt.legend(title='Metric')
        plt.xlabel('Epoch')
        plt.ylabel('Value')
        plt.title('Metrics vs Epoch')
        plt.savefig(os.path.join(output_dir, 'metrics_vs_epoch.png'))
    finally:
        # pyplot keeps every figure registered until it is closed; release it
        # even if plotting or saving fails.
        plt.close(fig)
if __name__ == "__main__":
    # Script entry point: read the config, locate the latest checkpoint's
    # trainer_state.json, extract the metrics and write the plot.
    with open("config.yaml", "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)

    # SECURITY(review): PROJECT_DIR is evaluated as a Python expression, so
    # whatever config.yaml contains here is executed. Only run this with a
    # trusted config; consider storing a plain path instead.
    PROJECT_DIR = eval(config["SENTENCE_COMPRESSION"]["PROJECT_DIR"])
    checkpoint_dir = config["SENTENCE_COMPRESSION"]["INFERENCE"]["MODEL_PATH"]

    latest_checkpoint = get_latest_checkpoint(os.path.join(PROJECT_DIR, checkpoint_dir))
    logfile_dir = os.path.join(PROJECT_DIR, checkpoint_dir, latest_checkpoint)
    logfile_path = os.path.join(logfile_dir, "trainer_state.json")

    # Context managers close both files promptly instead of leaving the
    # handles to the garbage collector.
    with open(logfile_path, encoding="utf-8") as log_file:
        logs = json.load(log_file)

    final_loss = loss_function(logs["log_history"])

    output_dir = config["SENTENCE_COMPRESSION"]["OUTPUT"]["RESULT"]
    os.makedirs(output_dir, exist_ok=True)
    plot_loss(final_loss, output_dir)