| import os
|
| import json
|
|
|
def create_graph(lora_path, lora_name):
    """Plot learning rate and loss against epoch for a LoRA training run.

    Reads ``training_graph.json`` from *lora_path* and writes
    ``training_graph.png`` into the same directory. The JSON is read as a
    list of dict-like records; only records that contain all of ``epoch``,
    ``learning_rate`` and ``loss`` are plotted, any others are skipped.

    Args:
        lora_path: Directory that contains ``training_graph.json``.
        lora_name: Label used in the graph title.

    Returns:
        None. Outcomes (saved path, missing file, missing matplotlib) are
        reported with ``print``; matplotlib is an optional dependency.
    """
    try:
        # Object-oriented Figure API instead of pyplot: no global figure
        # registry (repeated calls don't accumulate open figures) and no
        # dependence on a GUI backend or on running in the main thread.
        from matplotlib.figure import Figure
        from matplotlib.ticker import ScalarFormatter
    except ImportError:
        print("matplotlib is not installed. Please install matplotlib to create PNG graphs")
        return

    peft_model_path = f'{lora_path}/training_graph.json'
    image_model_path = f'{lora_path}/training_graph.png'

    if not os.path.exists(peft_model_path):
        print(f"File 'training_graph.json' does not exist in the {lora_path}")
        return

    with open(peft_model_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Skip records lacking any plotted field instead of dying on KeyError.
    required = ('epoch', 'learning_rate', 'loss')
    points = [item for item in data if all(key in item for key in required)]
    if not points:
        print(f"No epoch/learning_rate/loss entries found in {peft_model_path}")
        return

    x = [item['epoch'] for item in points]
    y1 = [item['learning_rate'] for item in points]
    y2 = [item['loss'] for item in points]

    fig = Figure(figsize=(10, 6))
    ax1 = fig.subplots()

    # Left axis: learning rate (blue).
    ax1.plot(x, y1, 'b-', label='Learning Rate')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Learning Rate', color='b')
    ax1.tick_params('y', colors='b')

    # Right axis sharing the same x: loss (red).
    ax2 = ax1.twinx()
    ax2.plot(x, y2, 'r-', label='Loss')
    ax2.set_ylabel('Loss', color='r')
    ax2.tick_params('y', colors='r')

    # Learning rates are tiny; force scientific notation on the left axis.
    ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
    ax1.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))

    ax1.grid(True)

    # Merge both axes' entries into a single legend.
    lines, labels = ax1.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax2.legend(lines + lines2, labels + labels2, loc='best')

    ax1.set_title(f'{lora_name} LR and Loss vs Epoch')

    fig.savefig(image_model_path)

    print(f"Graph saved in {image_model_path}")