| import os |
| import json |
|
|
def create_graph(lora_path, lora_name):
    """Plot learning rate and loss against epoch for a LoRA training run.

    Reads ``<lora_path>/training_graph.json`` -- expected to be a JSON list of
    records carrying ``epoch``, ``learning_rate`` and ``loss`` keys -- and
    writes a dual-y-axis line chart to ``<lora_path>/training_graph.png``.
    Outcome (saved / missing input / matplotlib unavailable) is reported via
    ``print``; the function always returns ``None``.

    Args:
        lora_path: Directory containing ``training_graph.json``; the PNG is
            written into the same directory.
        lora_name: Name shown in the chart title.

    Raises:
        OSError: if the JSON file cannot be read or the PNG cannot be written.
        json.JSONDecodeError: if ``training_graph.json`` is not valid JSON.
    """
    # matplotlib is an optional dependency: without it, report and bail out.
    # Only the imports are guarded so unrelated ImportErrors are not masked.
    try:
        import matplotlib.pyplot as plt
        from matplotlib.ticker import ScalarFormatter
    except ImportError:
        print("matplotlib is not installed. Please install matplotlib to create PNG graphs")
        return

    graph_json_path = os.path.join(lora_path, 'training_graph.json')
    graph_image_path = os.path.join(lora_path, 'training_graph.png')

    if not os.path.exists(graph_json_path):
        print(f"File 'training_graph.json' does not exist in the {lora_path}")
        return

    with open(graph_json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Skip records that lack any plotted key instead of crashing with KeyError.
    required_keys = ('epoch', 'learning_rate', 'loss')
    records = [item for item in data if all(key in item for key in required_keys)]

    epochs = [item['epoch'] for item in records]
    learning_rates = [item['learning_rate'] for item in records]
    losses = [item['loss'] for item in records]

    fig, ax1 = plt.subplots(figsize=(10, 6))
    try:
        # Learning rate on the left y-axis, drawn in blue.
        ax1.plot(epochs, learning_rates, 'b-', label='Learning Rate')
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Learning Rate', color='b')
        ax1.tick_params('y', colors='b')
        # scilimits=(0, 0) forces scientific notation at every magnitude.
        ax1.yaxis.set_major_formatter(ScalarFormatter(useMathText=True))
        ax1.ticklabel_format(style='sci', axis='y', scilimits=(0, 0))
        ax1.grid(True)

        # Loss on a second y-axis that shares the x-axis, drawn in red.
        ax2 = ax1.twinx()
        ax2.plot(epochs, losses, 'r-', label='Loss')
        ax2.set_ylabel('Loss', color='r')
        ax2.tick_params('y', colors='r')

        # Each axes only knows its own lines; merge them into one legend.
        lines, labels = ax1.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax2.legend(lines + lines2, labels + labels2, loc='best')

        # Use explicit objects rather than pyplot's implicit "current" state.
        ax1.set_title(f'{lora_name} LR and Loss vs Epoch')
        fig.savefig(graph_image_path)
    finally:
        # pyplot keeps every figure alive until closed; release it even on error.
        plt.close(fig)

    print(f"Graph saved in {graph_image_path}")