proteinMPNN_noMHC / plot_curves.py
smares's picture
Upload plot_curves.py with huggingface_hub
267c911 verified
Raw
History Blame Contribute Delete
1.53 kB
#!/usr/bin/env python
import re, matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# Parse per-epoch metrics from the ProteinMPNN training log snapshot.
# A matching line contains, in this order:
#   "epoch: <int> ... train: <num>, valid: <num>, train_acc: <num>, valid_acc: <num>"
# Any line that does not match the pattern is skipped.
ep, tr, va, tra, vaa = [], [], [], [], []  # epoch, train/valid loss, train/valid accuracy
pat = re.compile(r"epoch:\s*(\d+).*?train:\s*([\d.]+),\s*valid:\s*([\d.]+),\s*train_acc:\s*([\d.]+),\s*valid_acc:\s*([\d.]+)")
# Context manager closes the handle deterministically; errors="replace" keeps a
# stray non-UTF-8 byte from aborting the parse (the pattern only needs ASCII).
with open(
    "/global/scratch/users/sergiomar10/TCera/ProteinMPNN/hf_repo/_log_snapshot.txt",
    encoding="utf-8",
    errors="replace",
) as log_file:
    for line in log_file:
        m = pat.search(line)
        if m is None:
            continue
        ep.append(int(m.group(1)))
        tr.append(float(m.group(2)))
        va.append(float(m.group(3)))
        tra.append(float(m.group(4)))
        vaa.append(float(m.group(5)))
print(f"parsed {len(ep)} epochs (last={ep[-1] if ep else None})")
def _save_curve(x, train, valid, ylabel, title, path):
    """Plot a train/validation curve pair against *x* and save it to *path*.

    Args:
        x: x-axis values (epoch numbers).
        train: training-set values, one per entry of *x*.
        valid: validation-set values, one per entry of *x*.
        ylabel: label for the y-axis.
        title: figure title.
        path: output image path; saved at 140 dpi.
    """
    fig, ax = plt.subplots(figsize=(7, 4.5))
    ax.plot(x, train, label="train", lw=2)
    ax.plot(x, valid, label="validation", lw=2)
    ax.set_xlabel("epoch")
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(path, dpi=140)
    # Release the figure so pyplot's global figure registry does not keep it open.
    plt.close(fig)


_OUT_DIR = "/global/scratch/users/sergiomar10/TCera/ProteinMPNN/hf_repo"
# Figure 1: loss (perplexity)
_save_curve(
    ep, tr, va,
    "perplexity (loss)",
    "ProteinMPNN-noMHC — training / validation loss",
    f"{_OUT_DIR}/loss_curve.png",
)
# Figure 2: accuracy (sequence recovery)
_save_curve(
    ep, tra, vaa,
    "sequence-recovery accuracy",
    "ProteinMPNN-noMHC — training / validation accuracy",
    f"{_OUT_DIR}/accuracy_curve.png",
)
print("wrote loss_curve.png and accuracy_curve.png")