Upload 30 files
Browse files- .gitattributes +1 -0
- README.txt +6 -0
- algae-filelist.txt +3 -0
- ckpt.pt +3 -0
- contam-filelist.txt +12 -0
- filelist.txt +1 -0
- generated_prompts_algae1.txt_headed.fa +0 -0
- generated_prompts_algae2.txt_headed.fa +0 -0
- generated_prompts_algae3.txt_headed.fa +0 -0
- generated_prompts_archa1.txt_headed.fa +0 -0
- generated_prompts_archa2.txt_headed.fa +0 -0
- generated_prompts_archa3.txt_headed.fa +0 -0
- generated_prompts_bact1.txt_headed.fa +0 -0
- generated_prompts_bact2.txt_headed.fa +0 -0
- generated_prompts_bact3.txt_headed.fa +0 -0
- generated_prompts_fungi1.txt_headed.fa +0 -0
- generated_prompts_fungi2.txt_headed.fa +0 -0
- generated_prompts_fungi3.txt_headed.fa +0 -0
- generated_prompts_virus1.txt_headed.fa +0 -0
- generated_prompts_virus2.txt_headed.fa +0 -0
- generated_prompts_virus3.txt_headed.fa +0 -0
- infer_TI-inc-algaGPT.py +179 -0
- la4sr_sp2.sif +3 -0
- la4sr_sp2.sif.md5 +1 -0
- llm-metrics-two-files.py +475 -0
- meta.pkl +3 -0
- model.py +330 -0
- run_la4sr_TI-inc-algaGPT.sh +144 -0
- run_la4sr_loop.sbatch +32 -0
- slurm-10718799.out +193 -0
- targz.sh +9 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
la4sr_sp2.sif filter=lfs diff=lfs merge=lfs -text
|
README.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
to run, list your fasta files in the filelist.txt files and submit the .sbatch script, or just run run_la4sr_TI-inc-algaGPT.sh if no scheduler is available
|
| 2 |
+
|
| 3 |
+
alternatively, you can run the inference script (for raw outputs from next-token generation) and the model metrics script separately
|
| 4 |
+
|
| 5 |
+
expected outputs from default run are in results-archive
|
| 6 |
+
|
algae-filelist.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generated_prompts_algae1.txt_headed.fa
|
| 2 |
+
generated_prompts_algae2.txt_headed.fa
|
| 3 |
+
generated_prompts_algae3.txt_headed.fa
|
ckpt.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c95eb4fa7488a9f311deddda7a687d261bcc17169b76eb75dca856587b959a67
|
| 3 |
+
size 1037880346
|
contam-filelist.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generated_prompts_archa1.txt_headed.fa
|
| 2 |
+
generated_prompts_archa2.txt_headed.fa
|
| 3 |
+
generated_prompts_archa3.txt_headed.fa
|
| 4 |
+
generated_prompts_bact1.txt_headed.fa
|
| 5 |
+
generated_prompts_bact2.txt_headed.fa
|
| 6 |
+
generated_prompts_bact3.txt_headed.fa
|
| 7 |
+
generated_prompts_fungi1.txt_headed.fa
|
| 8 |
+
generated_prompts_fungi2.txt_headed.fa
|
| 9 |
+
generated_prompts_fungi3.txt_headed.fa
|
| 10 |
+
generated_prompts_virus1.txt_headed.fa
|
| 11 |
+
generated_prompts_virus2.txt_headed.fa
|
| 12 |
+
generated_prompts_virus3.txt_headed.fa
|
filelist.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
chlorophyta_chloroplast_proteins.fasta
|
generated_prompts_algae1.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_algae2.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_algae3.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_archa1.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_archa2.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_archa3.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_bact1.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_bact2.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_bact3.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_fungi1.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_fungi2.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_fungi3.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_virus1.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_virus2.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
generated_prompts_virus3.txt_headed.fa
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
infer_TI-inc-algaGPT.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
la4sr_infer_fasta2tsv.py — drop‑in replacement for the older LA4SR
|
| 4 |
+
inference utilities used by **run_la4sr_TI-inc.sh**
|
| 5 |
+
|
| 6 |
+
Changes vs. the legacy “Sample from a trained model” script
|
| 7 |
+
----------------------------------------------------------
|
| 8 |
+
1. **Parses a FASTA file** and feeds the collapsed sequence to the model.
|
| 9 |
+
2. **Generates up to 14 tokens** per record (defaults identical to old code).
|
| 10 |
+
3. **Emits a TSV** with columns: record_id, sequence, model_output.
|
| 11 |
+
4. CLI knobs match the wrapper (temperature, top‑k, etc.) plus `-o`.
|
| 12 |
+
|
| 13 |
+
Python ≥3.6 compatible (removed `from __future__ import annotations`).
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os, sys, argparse, pickle, random
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
# Python < 3.7 compatibility: provide a fallback for contextlib.nullcontext
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Import nullcontext when available (Python >= 3.7); otherwise define a
# minimal drop-in replacement so the autocast/no-op context used later in
# this script also works on Python 3.6.
try:
    from contextlib import nullcontext  # Python ≥3.7
except ImportError:  # Python 3.6 and older
    class _NullContext:
        # Minimal stand-in for contextlib.nullcontext: enter/exit do nothing
        # and __enter__ yields the wrapped `result` (default None).
        def __init__(self, result=None):
            self.result = result
        def __enter__(self):
            return self.result
        def __exit__(self, *exc):
            # Returning False propagates any in-flight exception unchanged.
            return False
    nullcontext = _NullContext
|
| 31 |
+
from typing import Iterator, Tuple
|
| 32 |
+
|
| 33 |
+
import torch, tiktoken
|
| 34 |
+
from model import GPTConfig, GPT
|
| 35 |
+
|
| 36 |
+
###############################################################################
|
| 37 |
+
# FASTA reader #
|
| 38 |
+
###############################################################################
|
| 39 |
+
|
| 40 |
+
def stream_fasta(path: str) -> Iterator[Tuple[str, str]]:
    """Yield (record_id, sequence) tuples from a FASTA file.

    Wrapped sequence lines are collapsed into a single string; blank lines
    are ignored. The record id is the first whitespace-delimited token after
    the '>' (the description, if any, is dropped).

    Arguments:
        path: Path to the FASTA file to read.

    Yields:
        (record_id, sequence) for each record, in file order.
    """
    header, seq_chunks = None, []
    with open(path) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue
            if line.startswith('>'):
                if header is not None:
                    yield header, ''.join(seq_chunks)
                # Bug fix: a bare '>' header line (no id) used to raise
                # IndexError via `line[1:].split()[0]`; fall back to ''.
                fields = line[1:].split()
                header = fields[0] if fields else ''
                seq_chunks = []
            else:
                seq_chunks.append(line)
    # Flush the final record (a FASTA file has no trailing sentinel).
    if header is not None:
        yield header, ''.join(seq_chunks)
|
| 57 |
+
|
| 58 |
+
###############################################################################
|
| 59 |
+
# argument parsing #
|
| 60 |
+
###############################################################################
|
| 61 |
+
|
| 62 |
+
def get_cli() -> argparse.Namespace:
    """Build the command-line interface and parse sys.argv.

    Returns:
        argparse.Namespace with model/runtime options, generation knobs,
        and the input/output paths.
    """
    parser = argparse.ArgumentParser(description="LA4SR FASTA→TSV inference script")
    add = parser.add_argument

    # --- model / runtime ---
    add('--init_from', default='resume',
        choices=['resume', 'gpt2', 'gpt2-medium', 'gpt2-large'],
        help='Model source; "resume" = local ckpt.pt')
    add('--out_dir', default='out',
        help='Directory with ckpt.pt if --init_from resume')
    add('--device', default='cuda')
    add('--dtype', default='float16',
        choices=['float32', 'bfloat16', 'float16'])
    add('--seed', type=int, default=1337)
    add('--compile', action='store_true')

    # --- generation knobs ---
    add('--max_new_tokens', type=int, default=14)
    add('--temperature', type=float, default=0.1)
    add('--top_k', type=int, default=10)

    # --- input / output ---
    add('fasta_in', help='Input FASTA')
    add('-o', '--tsv_out',
        help='Output TSV (default: out-algaGPT/<basename>.tsv)')

    return parser.parse_args()
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
# Script body: parse CLI args, seed RNGs, load the model, build the
# encoder/decoder, then stream the FASTA and write one TSV row per record.
# ---------------------------------------------------------------------------
args = get_cli()

###############################################################################
#                 reproducibility & autocast context                          #
###############################################################################

# Seed every RNG in play so repeated runs produce identical generations.
torch.manual_seed(args.seed)
random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

device_type = 'cuda' if 'cuda' in args.device else 'cpu'
ptdtype_map = {'float32': torch.float32,
               'bfloat16': torch.bfloat16,
               'float16': torch.float16}
ptdtype = ptdtype_map[args.dtype]
# CPU runs get a no-op context; GPU runs use mixed-precision autocast.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

###############################################################################
#                              model loading                                  #
###############################################################################

if args.init_from == 'resume':
    # Load a locally trained checkpoint from --out_dir.
    ckpt_path = os.path.join(args.out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=args.device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    # strip DDP prefixes if present
    state_dict = {k.replace('_orig_mod.',''):v for k,v in checkpoint['model'].items()}
    model.load_state_dict(state_dict)
else:
    # Pull pretrained GPT-2 weights by name (dropout disabled for inference).
    model = GPT.from_pretrained(args.init_from, dict(dropout=0.0))

model.to(args.device).eval()
if args.compile:
    model = torch.compile(model)

###############################################################################
#                        encoding / decoding setup                            #
###############################################################################

# `locals().get('checkpoint', {})` guards against `checkpoint` being
# undefined when --init_from is not 'resume'.
if args.init_from == 'resume' and 'config' in locals().get('checkpoint',{}):
    cfg = checkpoint['config']
    meta_path = os.path.join('data', cfg.get('dataset',''), 'meta.pkl')
else:
    meta_path = ''
# ------------------------------------------------------------------
# Fallback: meta.pkl next to ckpt.pt / in --out_dir
# ------------------------------------------------------------------
if (not meta_path or not os.path.exists(meta_path)) and args.out_dir:
    alt_meta = os.path.join(args.out_dir, 'meta.pkl')
    if os.path.exists(alt_meta):
        meta_path = alt_meta

if meta_path and os.path.exists(meta_path):
    # Character-level vocabulary from training metadata (stoi/itos maps).
    with open(meta_path,'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']
    #encode = lambda s: [stoi.get(c, stoi['<unk>']) for c in s]
    UNK_ID = stoi.get('<unk>', 0)  # fall back to 0 if not present
    encode = lambda s: [stoi.get(c, UNK_ID) for c in s]
    decode = lambda l: ''.join(itos[i] for i in l)

else:
    # No meta.pkl found: assume GPT-2 BPE tokenization.
    enc = tiktoken.get_encoding('gpt2')
    encode = lambda s: enc.encode(s, allowed_special={""})
    decode = lambda l: enc.decode(l)

###############################################################################
#                               output path                                   #
###############################################################################

os.makedirs('out-algaGPT', exist_ok=True)
outfile = args.tsv_out or os.path.join('out-algaGPT', f"{os.path.splitext(os.path.basename(args.fasta_in))[0]}.tsv")

###############################################################################
#                             generation loop                                 #
###############################################################################

with open(outfile,'w') as tsv, torch.no_grad(), ctx:
    tsv.write('record_id\tsequence\tmodel_output\n')
    for rid, seq in stream_fasta(args.fasta_in):
        if not seq:
            print(f"[WARN] empty sequence for {rid}; skipping", file=sys.stderr)
            continue
        x = torch.tensor(encode(seq), dtype=torch.long, device=args.device).unsqueeze(0)
        try:
            # NOTE(review): decode(y[0]) includes the prompt tokens as well as
            # the generated continuation — presumably downstream parsing
            # expects the full string; confirm against the metrics script.
            y = model.generate(x, args.max_new_tokens, temperature=args.temperature, top_k=args.top_k)
            cont = decode(y[0].tolist())
        except Exception as e:
            # Best-effort: a failed record is logged and written with an
            # empty model_output column rather than aborting the whole run.
            print(f"[ERR] generation failed on {rid}: {e}", file=sys.stderr)
            cont = ''
        tsv.write(f"{rid}\t{seq}\t{cont}\n")

print(f"\n✓ Predictions saved to {outfile}\n")
|
| 179 |
+
|
la4sr_sp2.sif
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b1004d7425ffc698a13ffc537d8c25fd4073ac383d5b583d01971b4616069b7
|
| 3 |
+
size 7201992704
|
la4sr_sp2.sif.md5
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0ce6720b829a4eb28d87ded1301da3ca la4sr_sp2.sif
|
llm-metrics-two-files.py
ADDED
|
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
LLM Classification Metrics Generator for Two-File Analysis
|
| 4 |
+
|
| 5 |
+
This script analyzes the LLM classification results from two separate files:
|
| 6 |
+
- One containing algal sequences (true algal samples)
|
| 7 |
+
- One containing contaminant sequences (true contaminant samples)
|
| 8 |
+
|
| 9 |
+
It extracts the predicted tags and calculates comprehensive metrics.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import re
|
| 13 |
+
import sys
|
| 14 |
+
import argparse
|
| 15 |
+
import numpy as np
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
|
| 18 |
+
from sklearn.metrics import classification_report
|
| 19 |
+
|
| 20 |
+
def _parse_result_file(path, true_label):
    """Parse one LLM-output file whose records all share the same true label.

    A line is predicted 'algal' if it contains '@', 'contaminant' if it
    contains '!' (checked in that order), and 'unknown' otherwise. Blank
    lines and header/non-data lines (starting with '==>' or '(') are
    skipped. The sequence id is the first whitespace-delimited token.

    Arguments:
        path (str): Path to the result file.
        true_label (str): Ground-truth label for every record in this file.

    Returns:
        tuple: (true_labels, predicted_labels, sequence_ids) for this file.
    """
    true_labels, predicted_labels, sequence_ids = [], [], []
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Skip header or non-data lines
            if line.startswith('==>') or line.startswith('('):
                continue
            # Extract sequence ID (first token; stripped lines always match,
            # but keep a fallback for safety).
            seq_id_match = re.match(r'^([^\s]+)', line)
            seq_id = seq_id_match.group(1) if seq_id_match else "unknown_id"
            true_labels.append(true_label)
            sequence_ids.append(seq_id)
            # Determine predicted label based on symbols (@ algal, ! contaminant)
            if '@' in line:
                predicted_labels.append('algal')
            elif '!' in line:
                predicted_labels.append('contaminant')
            else:
                predicted_labels.append('unknown')
    return true_labels, predicted_labels, sequence_ids


def parse_files(algal_file, contaminant_file):
    """
    Parse the algal and contaminant files to extract true and predicted labels.

    The two input files used identical parsing logic duplicated inline; that
    logic now lives once in _parse_result_file.

    Arguments:
        algal_file (str): Path to the file containing algal sequences
        contaminant_file (str): Path to the file containing contaminant sequences

    Returns:
        tuple: (true_labels, predicted_labels, sequence_ids), with all algal
        records first (file order preserved), then all contaminant records.
    """
    t_a, p_a, s_a = _parse_result_file(algal_file, 'algal')
    t_c, p_c, s_c = _parse_result_file(contaminant_file, 'contaminant')
    return t_a + t_c, p_a + p_c, s_a + s_c
|
| 109 |
+
|
| 110 |
+
def calculate_metrics(true_labels, predicted_labels):
    """
    Calculate comprehensive classification metrics.

    Arguments:
        true_labels (list): True class labels ('algal' / 'contaminant')
        predicted_labels (list): Predicted labels ('algal', 'contaminant',
            or 'unknown')

    Returns:
        dict: accuracy, per-class counts, confusion matrix,
        precision/recall/F1/support per class, sklearn classification
        report, macro/weighted F1, and total counts.
    """
    classes = ['algal', 'contaminant']
    label_map = {label: i for i, label in enumerate(classes)}

    # Filter out unknowns for the sklearn-based metrics below.
    known_indices = [i for i, pred in enumerate(predicted_labels) if pred != 'unknown']
    true_known = [true_labels[i] for i in known_indices]
    pred_known = [predicted_labels[i] for i in known_indices]

    # Overall accuracy (unknowns count as wrong). Bug fix: guard against
    # empty input, which previously raised ZeroDivisionError.
    total_n = len(true_labels)
    accuracy = (sum(t == p for t, p in zip(true_labels, predicted_labels)) / total_n) if total_n else 0.0

    if true_known and pred_known:
        # Convert to numeric form for sklearn.
        true_known_numeric = np.array([label_map[label] for label in true_known])
        pred_known_numeric = np.array([label_map[label] for label in pred_known])

        # Precision, recall, and F1 (excluding unknowns).
        precision, recall, f1, support = precision_recall_fscore_support(
            true_known_numeric,
            pred_known_numeric,
            labels=[0, 1],  # algal, contaminant
            zero_division=0
        )

        cm = confusion_matrix(
            true_known_numeric,
            pred_known_numeric,
            labels=[0, 1]
        )

        report = classification_report(
            true_known_numeric,
            pred_known_numeric,
            labels=[0, 1],
            target_names=classes,
            output_dict=True
        )
    else:
        # No usable (non-unknown) predictions: fall back to all-zero metrics.
        precision = recall = f1 = support = [0, 0]
        cm = np.zeros((2, 2))
        report = {}

    # Count occurrences and calculate per-class metrics.
    class_metrics = {}
    for class_name in classes:
        class_indices = [i for i, label in enumerate(true_labels) if label == class_name]
        total = len(class_indices)

        if total == 0:
            class_metrics[class_name] = {
                "total": 0,
                "correct": 0,
                "incorrect": 0,
                "unknown": 0,
                "accuracy": 0,
                "error_rate": 0
            }
            continue

        correct = sum(1 for i in class_indices if predicted_labels[i] == class_name)
        unknown = sum(1 for i in class_indices if predicted_labels[i] == "unknown")
        incorrect = total - correct - unknown

        class_metrics[class_name] = {
            "total": total,
            "correct": correct,
            "incorrect": incorrect,
            "unknown": unknown,
            "accuracy": correct / total if total > 0 else 0,
            "error_rate": (incorrect + unknown) / total if total > 0 else 0
        }

    # Compile all metrics. (Removed dead code: the previous version built
    # unused true_numeric/pred_numeric arrays over the full label lists.)
    metrics = {
        "accuracy": accuracy,
        "class_metrics": class_metrics,
        "confusion_matrix": cm,
        "precision": {classes[i]: precision[i] for i in range(len(classes))},
        "recall": {classes[i]: recall[i] for i in range(len(classes))},
        "f1": {classes[i]: f1[i] for i in range(len(classes))},
        "support": {classes[i]: support[i] for i in range(len(classes))},
        "classification_report": report,
        "macro_f1": np.mean(f1),
        # NB: `f1 * support` is only evaluated when support sums > 0, so the
        # list fallback above is safe here.
        "weighted_f1": np.sum(f1 * support) / np.sum(support) if np.sum(support) > 0 else 0,
        "total_samples": len(true_labels),
        "total_correct": sum(t == p for t, p in zip(true_labels, predicted_labels)),
        "total_unknown": predicted_labels.count("unknown")
    }

    return metrics
|
| 218 |
+
|
| 219 |
+
def display_results(metrics, output_file=None):
    """
    Display comprehensive results and optionally save to file.

    Arguments:
        metrics (dict): Dictionary containing all calculated metrics
            (as produced by calculate_metrics)
        output_file (str, optional): Path to save results to; when given,
            the report is captured and written to this file instead of
            being printed to the terminal.
    """
    # Capture output by redirecting stdout when saving to a file.
    if output_file:
        import io
        output_capture = io.StringIO()
        original_stdout = sys.stdout
        sys.stdout = output_capture
    try:
        # Print header
        print("\n" + "="*60)
        print(" LLM CLASSIFICATION METRICS REPORT")
        print("="*60)

        # Overall metrics
        print("\n=== OVERALL METRICS ===")
        print(f"Total samples: {metrics['total_samples']}")
        print(f"Correctly classified: {metrics['total_correct']} ({metrics['total_correct']/metrics['total_samples']*100:.2f}%)")
        print(f"Unknown predictions: {metrics['total_unknown']} ({metrics['total_unknown']/metrics['total_samples']*100:.2f}%)")
        print(f"Overall accuracy: {metrics['accuracy']:.4f}")
        print(f"Macro F1: {metrics['macro_f1']:.4f}")
        print(f"Weighted F1: {metrics['weighted_f1']:.4f}")

        # Confusion matrix
        cm = metrics["confusion_matrix"]
        # NOTE(review): labels say "Bacterial" but the class is 'contaminant'
        # (bacteria + fungi + virus + archaea) — display label looks wrong;
        # kept as-is to preserve existing output format.
        class_labels = ["Algal", "Bacterial"]

        print("\n=== CONFUSION MATRIX ===")
        print(f"{'':15} | {'Predicted Algal':15} | {'Predicted Bacterial':20}")
        print("-" * 55)
        for i, label in enumerate(class_labels):
            print(f"{label:15} | {int(cm[i][0]):15} | {int(cm[i][1]):20}")

        # Per-class metrics
        print("\n=== PER-CLASS METRICS ===")
        print(f"{'Class':10} | {'Precision':10} | {'Recall':10} | {'F1 Score':10} | {'Support':10}")
        print("-" * 60)
        for class_name in ['algal', 'contaminant']:
            precision = metrics['precision'][class_name]
            recall = metrics['recall'][class_name]
            f1 = metrics['f1'][class_name]
            support = metrics['support'][class_name]
            print(f"{class_name.capitalize():10} | {precision:.4f} | {recall:.4f} | {f1:.4f} | {int(support):10}")

        # Detailed class counts
        print("\n=== DETAILED CLASS COUNTS ===")
        for class_name, class_data in metrics["class_metrics"].items():
            print(f"{class_name.capitalize()} class:")
            print(f"  Total samples: {class_data['total']}")
            if class_data['total'] > 0:
                print(f"  Correctly classified: {class_data['correct']} ({class_data['correct']/class_data['total']*100:.2f}%)")
                print(f"  Incorrectly classified: {class_data['incorrect']} ({class_data['incorrect']/class_data['total']*100:.2f}%)")
                print(f"  Unknown: {class_data['unknown']} ({class_data['unknown']/class_data['total']*100:.2f}%)")
            print()
    finally:
        # Bug fix: always restore stdout, even if formatting raises,
        # so an exception can't leave sys.stdout permanently hijacked.
        if output_file:
            sys.stdout = original_stdout

    # Write the captured report to the requested file.
    if output_file:
        with open(output_file, 'w') as f:
            f.write(output_capture.getvalue())
        print(f"Results saved to {output_file}")
|
| 290 |
+
|
| 291 |
+
def generate_visualizations(metrics, output_prefix=None):
    """
    Generate visualizations of the metrics.

    Three figures are produced:
      1. A confusion-matrix heatmap (Algal vs. Bacterial).
      2. A grouped bar chart of precision / recall / F1 per class.
      3. Side-by-side pie charts of correct / incorrect / unknown counts.

    Arguments:
        metrics (dict): Dictionary containing all calculated metrics.
            Must provide "confusion_matrix" (2x2 array-like), "precision",
            "recall", "f1" (each keyed by 'algal'/'contaminant') and
            "class_metrics" (per-class correct/incorrect/unknown counts).
        output_prefix (str, optional): Prefix for output image files.
            If given, figures are saved as PNGs named
            "<prefix>_confusion_matrix.png", "<prefix>_metrics_by_class.png"
            and "<prefix>_class_distribution.png"; otherwise each figure is
            shown interactively via plt.show().
    """
    # Create confusion matrix heatmap
    plt.figure(figsize=(8, 6))
    cm = metrics["confusion_matrix"]
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    classes = ["Algal", "Bacterial"]
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Add text annotations; switch text color at half the max count so
    # labels stay readable on both light and dark heatmap cells.
    thresh = cm.max() / 2.0
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(int(cm[i, j]), 'd'),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

    if output_prefix:
        plt.savefig(f"{output_prefix}_confusion_matrix.png", dpi=300, bbox_inches='tight')
    else:
        plt.show()

    # Create per-class metrics bar chart
    plt.figure(figsize=(10, 6))

    metrics_names = ['Precision', 'Recall', 'F1-Score']
    x = np.arange(len(metrics_names))
    width = 0.35  # bar width; the two classes share each x position

    algal_values = [metrics['precision']['algal'], metrics['recall']['algal'], metrics['f1']['algal']]
    contaminant_values = [metrics['precision']['contaminant'], metrics['recall']['contaminant'], metrics['f1']['contaminant']]

    plt.bar(x - width/2, algal_values, width, label='Algal')
    plt.bar(x + width/2, contaminant_values, width, label='Bacterial')

    plt.ylabel('Score')
    plt.title('Performance Metrics by Class')
    plt.xticks(x, metrics_names)
    plt.ylim(0, 1.1)  # headroom above 1.0 so full-score bars remain visible
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    if output_prefix:
        plt.savefig(f"{output_prefix}_metrics_by_class.png", dpi=300, bbox_inches='tight')
    else:
        plt.show()

    # Create class distribution pie charts
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Algal class distribution
    algal_data = metrics['class_metrics']['algal']
    algal_labels = ['Correct', 'Incorrect', 'Unknown']
    algal_values = [algal_data['correct'], algal_data['incorrect'], algal_data['unknown']]
    ax1.pie(algal_values, labels=algal_labels, autopct='%1.1f%%', startangle=90)
    ax1.set_title('Algal Class Predictions')

    # Bacterial class distribution (stored under the 'contaminant' key)
    contaminant_data = metrics['class_metrics']['contaminant']
    contaminant_labels = ['Correct', 'Incorrect', 'Unknown']
    contaminant_values = [contaminant_data['correct'], contaminant_data['incorrect'], contaminant_data['unknown']]
    ax2.pie(contaminant_values, labels=contaminant_labels, autopct='%1.1f%%', startangle=90)
    ax2.set_title('Bacterial Class Predictions')

    plt.tight_layout()

    if output_prefix:
        plt.savefig(f"{output_prefix}_class_distribution.png", dpi=300, bbox_inches='tight')
    else:
        plt.show()
def create_misclassified_report(true_labels, predicted_labels, sequence_ids, output_file=None):
    """
    Create a report of misclassified sequences.

    The report lists algal sequences predicted as bacterial, bacterial
    (contaminant) sequences predicted as algal, and sequences whose
    prediction is 'unknown'.

    Arguments:
        true_labels (list): List of true class labels ('algal'/'contaminant')
        predicted_labels (list): List of predicted class labels
        sequence_ids (list): List of sequence IDs (parallel to the labels)
        output_file (str, optional): Path to save the report to; if omitted
            the report is printed to stdout instead.
    """
    misclassified = []
    for true, pred, seq_id in zip(true_labels, predicted_labels, sequence_ids):
        if true != pred:
            misclassified.append({
                'id': seq_id,
                'true': true,
                'predicted': pred
            })

    # Start capturing output. stdout is restored in a finally block so an
    # exception while printing cannot leave sys.stdout permanently redirected
    # (the original code only restored it on the success path).
    output_capture = None
    if output_file:
        import io
        output_capture = io.StringIO()
        original_stdout = sys.stdout
        sys.stdout = output_capture

    try:
        # Print header
        print("\n" + "="*60)
        print(" MISCLASSIFIED SEQUENCES REPORT")
        print("="*60)
        total = len(true_labels)
        # Guard against empty input (previously raised ZeroDivisionError).
        pct = (len(misclassified) / total * 100) if total else 0.0
        print(f"\nTotal misclassified: {len(misclassified)} out of {total} ({pct:.2f}%)\n")

        # Print algal sequences misclassified as contaminant
        print("\n--- ALGAL SEQUENCES MISCLASSIFIED AS BACTERIAL ---")
        algal_as_contaminant = [m for m in misclassified if m['true'] == 'algal' and m['predicted'] == 'contaminant']
        for item in algal_as_contaminant:
            print(f"ID: {item['id']}")
        print(f"Total: {len(algal_as_contaminant)}")

        # Print contaminant sequences misclassified as algal
        print("\n--- BACTERIAL SEQUENCES MISCLASSIFIED AS ALGAL ---")
        contaminant_as_algal = [m for m in misclassified if m['true'] == 'contaminant' and m['predicted'] == 'algal']
        for item in contaminant_as_algal:
            print(f"ID: {item['id']}")
        print(f"Total: {len(contaminant_as_algal)}")

        # Print unknown classifications
        print("\n--- SEQUENCES WITH UNKNOWN CLASSIFICATION ---")
        unknown = [m for m in misclassified if m['predicted'] == 'unknown']
        for item in unknown:
            print(f"ID: {item['id']} (True: {item['true']})")
        print(f"Total: {len(unknown)}")
    finally:
        # Restore stdout even if an error occurred above
        if output_file:
            sys.stdout = original_stdout

    # If saving to file, write the captured report to disk
    if output_file:
        with open(output_file, 'w') as f:
            f.write(output_capture.getvalue())

        print(f"Misclassified report saved to {output_file}")
def main():
    """
    Main entry point: parse CLI arguments, compute metrics for the two
    input files, and emit reports / visualizations as requested.

    Returns:
        int: number of misclassified sequences, used as the process exit
        status (so 0 means every sequence was classified correctly).
    """
    parser = argparse.ArgumentParser(description='LLM Classification Metrics Generator for Two-File Analysis')
    parser.add_argument('algal_file', help='Path to the file containing algal sequences')
    parser.add_argument('contaminant_file', help='Path to the file containing contaminant sequences')
    parser.add_argument('-o', '--output', help='Path to save the metrics report')
    parser.add_argument('-m', '--misclassified', help='Path to save the misclassified sequences report')
    parser.add_argument('-v', '--visualize', action='store_true', help='Generate visualizations')
    parser.add_argument('-p', '--prefix', default='llm_metrics', help='Prefix for output files')

    args = parser.parse_args()

    # Parse files and calculate metrics
    true_labels, predicted_labels, sequence_ids = parse_files(args.algal_file, args.contaminant_file)
    metrics = calculate_metrics(true_labels, predicted_labels)

    # Display results.
    # BUG FIX: honour the path supplied via -o/--output. Previously the
    # user-supplied value was ignored and the report was always written to
    # "<prefix>_report.txt", so downstream tooling looking for the -o path
    # never found the file.
    output_file = args.output if args.output else None
    display_results(metrics, output_file)

    # Generate visualizations if requested
    if args.visualize:
        generate_visualizations(metrics, args.prefix)

    # Create misclassified report if requested. argparse stores a string
    # (or None) for -m, so the old "args.misclassified is True" check was
    # dead code and the value was always used directly.
    if args.misclassified:
        create_misclassified_report(true_labels, predicted_labels, sequence_ids, args.misclassified)

    # Return number of misclassifications (for automated testing / exit status)
    misclassifications = sum(t != p for t, p in zip(true_labels, predicted_labels))
    return misclassifications
|
| 474 |
+
if __name__ == "__main__":
    # Exit status is the misclassification count returned by main(),
    # so 0 signals a fully correct classification run.
    sys.exit(main())
|
meta.pkl
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:75049dcc1458196491b007a46f4aa1acef9ca3c27df282f226af572877203752
|
| 3 |
+
size 14858
|
model.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Full definition of a GPT Language Model, all of it in this single file.
|
| 3 |
+
References:
|
| 4 |
+
1) the official GPT-2 TensorFlow implementation released by OpenAI:
|
| 5 |
+
https://github.com/openai/gpt-2/blob/master/src/model.py
|
| 6 |
+
2) huggingface/transformers PyTorch implementation:
|
| 7 |
+
https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
import inspect
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
from torch.nn import functional as F
|
| 17 |
+
|
| 18 |
+
class LayerNorm(nn.Module):
    """Layer normalization with an optional additive bias.

    torch.nn.LayerNorm does not expose bias=False directly, so this thin
    wrapper keeps a learnable per-feature gain and, only when requested,
    a learnable per-feature shift.
    """

    def __init__(self, ndim, bias):
        super().__init__()
        # Gain initialised to 1 for every feature.
        self.weight = nn.Parameter(torch.ones(ndim))
        # Shift initialised to 0, or omitted entirely when bias is falsy.
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        # Normalise over the trailing dimension(s) given by the gain's shape.
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, 1e-5)
|
| 29 |
+
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Uses torch.nn.functional.scaled_dot_product_attention (Flash Attention)
    when available (PyTorch >= 2.0); otherwise falls back to an explicit
    masked-softmax implementation using a precomputed lower-triangular mask.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # (registered as a buffer, not a parameter, so it is saved but not trained)
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        # x: (batch, time, channels) -> returns a tensor of the same shape.
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            # (dropout only applied in training mode)
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention, scaled by 1/sqrt(head_size)
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
|
| 78 |
+
class MLP(nn.Module):
|
| 79 |
+
|
| 80 |
+
def __init__(self, config):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
|
| 83 |
+
self.gelu = nn.GELU()
|
| 84 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
|
| 85 |
+
self.dropout = nn.Dropout(config.dropout)
|
| 86 |
+
|
| 87 |
+
def forward(self, x):
|
| 88 |
+
x = self.c_fc(x)
|
| 89 |
+
x = self.gelu(x)
|
| 90 |
+
x = self.c_proj(x)
|
| 91 |
+
x = self.dropout(x)
|
| 92 |
+
return x
|
| 93 |
+
|
| 94 |
+
class Block(nn.Module):
    """One transformer block: pre-norm causal self-attention followed by a
    pre-norm MLP, each wrapped in a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        # Residual around attention, then residual around the feed-forward.
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
|
| 108 |
+
@dataclass
class GPTConfig:
    """Hyperparameters describing a GPT model's architecture."""
    block_size: int = 1024  # maximum context length in tokens
    vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12       # number of transformer blocks
    n_head: int = 12        # attention heads per block
    n_embd: int = 768       # embedding / hidden dimensionality
    dropout: float = 0.0    # dropout probability (0.0 disables dropout)
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
| 118 |
+
class GPT(nn.Module):
    """GPT language model (nanoGPT-style).

    Composed of token + position embeddings, a stack of pre-norm transformer
    blocks, a final LayerNorm, and a linear language-model head whose weight
    is tied to the token embedding.
    """

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        # Standard GPT-2 init: N(0, 0.02) for Linear/Embedding weights and
        # zeros for Linear biases. Applied recursively via self.apply().
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token indices idx (LongTensor of shape (b, t)).

        Returns (logits, loss): with targets given, logits cover the full
        sequence and loss is the cross-entropy (ignore_index=-1); without
        targets, logits cover only the last position and loss is None.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def crop_block_size(self, block_size):
        """Shrink the model's context window to block_size, in place."""
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        """Build a GPT initialised from a HuggingFace GPT-2 checkpoint
        ('gpt2' .. 'gpt2-xl'). Only 'dropout' may be overridden."""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {} # default to empty dict
        # only dropout can be overridden see more notes below
        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params
            'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        config_args['bias'] = True # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Create an AdamW optimizer: weight decay is applied only to >=2-D
        parameters (matmul weights / embeddings); the fused kernel is used
        on CUDA when this torch build supports it."""
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
|
run_la4sr_TI-inc-algaGPT.sh
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
set -euo pipefail
###############################################################################
# run_la4sr_TI-inc.sh — LA4SR pipeline: model inference (FASTA→TSV) + metrics
###############################################################################

# ---------------------- Configuration ----------------------
# NOTE(review): SCRIPT_DIR is the *invocation* directory, not the directory
# containing this script, so the companion scripts and the .sif image must
# live in the current working directory when this is run — confirm intended.
SCRIPT_DIR="$(pwd)"
INFER_SCRIPT="$SCRIPT_DIR/infer_TI-inc-algaGPT.py" # <— new script
METRICS_SCRIPT="$SCRIPT_DIR/llm-metrics-two-files.py"
SIF="$SCRIPT_DIR/la4sr_sp2.sif"

# cache for HF tokenizers & models
tcache="$SCRIPT_DIR/cache"
mkdir -p "$tcache"

# ---------------------- Usage ----------------------
# Require exactly three positional arguments; print help and exit otherwise.
if [[ $# -ne 3 ]]; then
  cat <<EOF
Usage: $(basename "$0") <model_name|resume> <algal_fasta> <bacterial_fasta>

If you pass the literal word resume the script loads ckpt.pt + meta.pkl
from the current directory. Otherwise the value is forwarded to --init_from
(e.g. GreenGenomicsLab/LA4SR-gpt-neo125-ALMGA-FL).

EOF
  exit 1
fi

MODEL_NAME="$1" # "resume" OR HF repo / local path
algal_fasta="$2"
bact_fasta="$3"

# Output naming: "<algal-stem>_vs_<bacterial-stem>" under results/.
prefix="$(basename "${algal_fasta%.*}")_vs_$(basename "${bact_fasta%.*}")"
mkdir -p results
alg_out="results/${prefix}_algal.tsv"
bac_out="results/${prefix}_bacterial.tsv"
alg_out_tagged="results/${prefix}_algal_tagged.tsv"
bac_out_tagged="results/${prefix}_bacterial_tagged.tsv"
report="results/${prefix}_report.txt"
miscl="results/${prefix}_misclassified.txt"

# ---------------------- Inference ----------------------
+
run_infer () {
  # Run the inference script inside the Singularity image on one FASTA file.
  # Arguments: $1 = input FASTA path, $2 = output TSV path (relative to CWD).
  local fasta=$1 out=$2
  printf '\n→ Inferring %s...\n' "$(basename "$fasta")"

  # Build common args (local: the old global PY_ARGS leaked between calls)
  local -a py_args=( --init_from "$MODEL_NAME" )
  [[ "$MODEL_NAME" == "resume" ]] && py_args+=( --out_dir /workdir )

  # BUG FIX: shell-quote every argument before splicing it into the bash -c
  # command string. The previous unquoted '"${PY_ARGS[*]}"' flattening broke
  # (and was injectable) for values containing spaces or shell metacharacters.
  local quoted_args quoted_out
  printf -v quoted_args '%q ' "${py_args[@]}"
  printf -v quoted_out '%q' "$out"

  singularity exec --nv \
    -B "$fasta:/input.fasta" \
    -B "$(pwd):/workdir" \
    -B "$tcache:$tcache" \
    --env TRANSFORMERS_CACHE="$tcache" \
    "$SIF" \
    bash -c "cd /workdir && python3 infer_TI-inc-algaGPT.py ${quoted_args}/input.fasta -o ${quoted_out}"
}
|
| 62 |
+
|
| 63 |
+
#run_infer () {
|
| 64 |
+
#local fasta=$1 out=$2
|
| 65 |
+
#echo -e "\n→ Inferring $(basename "$fasta")..."
|
| 66 |
+
|
| 67 |
+
#PY_ARGS=( --init_from "$MODEL_NAME" )
|
| 68 |
+
#[[ "$MODEL_NAME" == "resume" ]] && PY_ARGS+=( --out_dir /workdir )
|
| 69 |
+
|
| 70 |
+
# singularity exec --nv \
|
| 71 |
+
# -B "$fasta:/input.fasta" \
|
| 72 |
+
# -B "$(pwd):/workdir" \ # <— whole project goes in
|
| 73 |
+
# -B "$tcache:$tcache" \
|
| 74 |
+
# --env TRANSFORMERS_CACHE="$tcache" \
|
| 75 |
+
# "$SIF" \
|
| 76 |
+
# bash -c "cd /workdir && \
|
| 77 |
+
# python3 "$INFER_SCRIPT" \
|
| 78 |
+
# \"${PY_ARGS[@]}\" /input.fasta -o \"$out\""
|
| 79 |
+
#}
|
| 80 |
+
|
| 81 |
+
#run_infer () {
|
| 82 |
+
# local fasta=$1 out=$2
|
| 83 |
+
#echo -e "\n→ Inferring $(basename "$fasta")..."
|
| 84 |
+
|
| 85 |
+
# build python arg list: --init_from ... [--out_dir PWD]
|
| 86 |
+
#PY_ARGS=( --init_from "$MODEL_NAME" )
|
| 87 |
+
#[[ "$MODEL_NAME" == "resume" ]] && PY_ARGS+=( --out_dir "$SCRIPT_DIR" )
|
| 88 |
+
|
| 89 |
+
#if [[ -f "$SIF" ]]; then
|
| 90 |
+
# singularity exec --nv \
|
| 91 |
+
# -B "$fasta:/input.fasta" \
|
| 92 |
+
# -B "$INFER_SCRIPT:/infer.py" \
|
| 93 |
+
#-B "$(pwd):/workdir" \
|
| 94 |
+
#-B "$tcache:$tcache" \
|
| 95 |
+
#--env TRANSFORMERS_CACHE="$tcache" \
|
| 96 |
+
#"$SIF" \
|
| 97 |
+
#python3 /infer.py "${PY_ARGS[@]}" /input.fasta \
|
| 98 |
+
# -o "/workdir/$out"
|
| 99 |
+
#else
|
| 100 |
+
# TRANSFORMERS_CACHE="$tcache" \
|
| 101 |
+
#python3 "$INFER_SCRIPT" "${PY_ARGS[@]}" "$fasta" \
|
| 102 |
+
# -o "$out"
|
| 103 |
+
#fi
|
| 104 |
+
|
| 105 |
+
#echo " ✔ Wrote $out"
|
| 106 |
+
#}
|
| 107 |
+
|
| 108 |
+
# Classify both input sets: each call mounts the FASTA into the container
# and writes the raw per-sequence predictions to the matching results/ TSV.
run_infer "$algal_fasta" "$alg_out"
run_infer "$bact_fasta" "$bac_out"
|
| 110 |
+
|
| 111 |
+
# ---------------------- Post-process Tags ----------------------
|
| 112 |
+
convert_tags () {
  # Rewrite class labels as single-character tags for the metrics script:
  # every "algae" becomes "@" and every "conta" becomes "!".
  #   $1 - raw prediction TSV to read
  #   $2 - tagged TSV to write
  local src=$1 dst=$2
  echo -e "\n→ Converting 'algae'→@ and 'conta'→! in $(basename "$src")..."
  sed -e 's/algae/@/g' -e 's/conta/!/g' "$src" > "$dst"
  echo " ✔ Wrote $dst"
}
|
| 118 |
+
|
| 119 |
+
# Tag both raw prediction files before scoring; the metrics step consumes
# only the *_tagged.tsv variants.
convert_tags "$alg_out" "$alg_out_tagged"
convert_tags "$bac_out" "$bac_out_tagged"
|
| 121 |
+
|
| 122 |
+
# ---------------------- Metrics ----------------------
# Score the two tagged prediction files inside the container and emit the
# report, misclassification list, and plots under results/.
# NOTE(review): $METRICS_SCRIPT is not defined in this part of the script —
# presumably assigned near the top alongside SIF/SCRIPT_DIR; confirm before
# editing. The inner \"...\" escapes expand the host-side variables before
# the string is handed to the container's bash — intentional, but fragile
# if any of those paths ever contain whitespace.
echo -e "\n→ Generating metrics report..."
singularity exec \
    -B "$METRICS_SCRIPT:/metrics.py" \
    -B "$(pwd):/workdir" \
    "$SIF" \
    bash -c "source /opt/conda/etc/profile.d/conda.sh && \
        conda activate la4sr && cd /workdir && \
        python3 /metrics.py \
            \"$alg_out_tagged\" \"$bac_out_tagged\" \
            -o \"$report\" \
            -m \"$miscl\" \
            -v \
            -p \"results/$prefix\""

# ---------------------- Finished ----------------------
# Human-readable summary of where everything landed.
echo -e '\n🎉 Done! Results in ./results/'
echo " Algal TSV: $alg_out_tagged"
echo " Bact TSV: $bac_out_tagged"
echo " Report: $report"
echo " Misclassified: $miscl"
echo " Plots: results/${prefix}_*.png"
|
| 144 |
+
|
run_la4sr_loop.sbatch
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

#SBATCH -o slurm-logs/arrayJob_%A_%a.out
#SBATCH -e slurm-logs/arrayJob_%A_%a.err
#SBATCH -a 1-12 #5-112 # <-- set to length of the *longer* file
#SBATCH --mem=40G
#SBATCH --time=12:00:00
#SBATCH -p nvidia
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=20

# Fan each SLURM array task out to one (algal, contaminant) FASTA pair,
# wrapping around the shorter file list with modular indexing.
set -euo pipefail

# FIX: fail fast with a clear message when run outside a SLURM array job;
# previously an unset SLURM_ARRAY_TASK_ID made the arithmetic below evaluate
# to 0, producing the invalid address "sed -n 0p" / empty arguments.
: "${SLURM_ARRAY_TASK_ID:?this script must be submitted as a SLURM array job}"

# Get line count of each file
NUM_ALGAE=$(wc -l < algae-filelist.txt)
NUM_CONTAM=$(wc -l < contam-filelist.txt)

# Use raw SLURM task ID
TASK_ID=$SLURM_ARRAY_TASK_ID

# Modulo wrap if needed (reuses the shorter list when the array is longer)
IDX_ALGAE=$(( (TASK_ID - 1) % NUM_ALGAE + 1 ))
IDX_CONTAM=$(( (TASK_ID - 1) % NUM_CONTAM + 1 ))

# Extract lines from files
ALINE=$(sed -n "${IDX_ALGAE}p" algae-filelist.txt)
CLINE=$(sed -n "${IDX_CONTAM}p" contam-filelist.txt)

# FIX: guard against blank lines / truncated lists producing empty arguments
# that the downstream script would silently mangle into bogus output names.
if [[ -z "$ALINE" || -z "$CLINE" ]]; then
  echo "empty file-list entry for array task $TASK_ID" >&2
  exit 1
fi

# Run your classification script
./run_la4sr_TI-inc-algaGPT.sh resume "$ALINE" "$CLINE"

## EXAMPLE:

##./run_la4sr.sh ./test-data/TI-free/AlgalTop10000-10holdout-headed.fa ./test-data/TI-free/BactTop10000-10holdout-headed.fa
|
slurm-10718799.out
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
algae-filelist.txt
|
| 2 |
+
cache/
|
| 3 |
+
ckpt.pt
|
| 4 |
+
contam-filelist.txt
|
| 5 |
+
filelist.txt
|
| 6 |
+
generated_prompts_algae1.txt_headed.fa
|
| 7 |
+
generated_prompts_algae2.txt_headed.fa
|
| 8 |
+
generated_prompts_algae3.txt_headed.fa
|
| 9 |
+
generated_prompts_archa1.txt_headed.fa
|
| 10 |
+
generated_prompts_archa2.txt_headed.fa
|
| 11 |
+
generated_prompts_archa3.txt_headed.fa
|
| 12 |
+
generated_prompts_bact1.txt_headed.fa
|
| 13 |
+
generated_prompts_bact2.txt_headed.fa
|
| 14 |
+
generated_prompts_bact3.txt_headed.fa
|
| 15 |
+
generated_prompts_fungi1.txt_headed.fa
|
| 16 |
+
generated_prompts_fungi2.txt_headed.fa
|
| 17 |
+
generated_prompts_fungi3.txt_headed.fa
|
| 18 |
+
generated_prompts_virus1.txt_headed.fa
|
| 19 |
+
generated_prompts_virus2.txt_headed.fa
|
| 20 |
+
generated_prompts_virus3.txt_headed.fa
|
| 21 |
+
infer_TI-inc-algaGPT.py
|
| 22 |
+
la4sr_sp2.sif
|
| 23 |
+
la4sr_sp2.sif.md5
|
| 24 |
+
llm-metrics-two-files.py
|
| 25 |
+
meta.pkl
|
| 26 |
+
model.py
|
| 27 |
+
out-algaGPT/
|
| 28 |
+
__pycache__/
|
| 29 |
+
__pycache__/model.cpython-312.pyc
|
| 30 |
+
README.txt
|
| 31 |
+
results-archive/
|
| 32 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_misclassified.txt
|
| 33 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_bacterial_tagged.tsv
|
| 34 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_metrics_by_class.png
|
| 35 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_confusion_matrix.png
|
| 36 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_algal_tagged.tsv
|
| 37 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_report.txt
|
| 38 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_algal.tsv
|
| 39 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_algal_tagged.tsv
|
| 40 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_metrics_by_class.png
|
| 41 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_report.txt
|
| 42 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_bacterial.tsv
|
| 43 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_bacterial.tsv
|
| 44 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_algal.tsv
|
| 45 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_metrics_by_class.png
|
| 46 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_bacterial.tsv
|
| 47 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_class_distribution.png
|
| 48 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_metrics_by_class.png
|
| 49 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_misclassified.txt
|
| 50 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_algal.tsv
|
| 51 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_bacterial.tsv
|
| 52 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_metrics_by_class.png
|
| 53 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_class_distribution.png
|
| 54 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_bacterial_tagged.tsv
|
| 55 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_metrics_by_class.png
|
| 56 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_misclassified.txt
|
| 57 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_confusion_matrix.png
|
| 58 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_class_distribution.png
|
| 59 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_algal.tsv
|
| 60 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_confusion_matrix.png
|
| 61 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_bacterial_tagged.tsv
|
| 62 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_bacterial.tsv
|
| 63 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_bacterial_tagged.tsv
|
| 64 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_bacterial.tsv
|
| 65 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_class_distribution.png
|
| 66 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_misclassified.txt
|
| 67 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_report.txt
|
| 68 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_algal_tagged.tsv
|
| 69 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_class_distribution.png
|
| 70 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_class_distribution.png
|
| 71 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_misclassified.txt
|
| 72 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_bacterial_tagged.tsv
|
| 73 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_algal.tsv
|
| 74 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_algal.tsv
|
| 75 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_class_distribution.png
|
| 76 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_algal_tagged.tsv
|
| 77 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_bacterial_tagged.tsv
|
| 78 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_bacterial.tsv
|
| 79 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_report.txt
|
| 80 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_algal.tsv
|
| 81 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_algal.tsv
|
| 82 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_metrics_by_class.png
|
| 83 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_algal.tsv
|
| 84 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_algal.tsv
|
| 85 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_class_distribution.png
|
| 86 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_class_distribution.png
|
| 87 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_bacterial.tsv
|
| 88 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_class_distribution.png
|
| 89 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_algal_tagged.tsv
|
| 90 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_algal_tagged.tsv
|
| 91 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_algal_tagged.tsv
|
| 92 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_algal_tagged.tsv
|
| 93 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_confusion_matrix.png
|
| 94 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_algal.tsv
|
| 95 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_algal_tagged.tsv
|
| 96 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_report.txt
|
| 97 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_report.txt
|
| 98 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_algal.tsv
|
| 99 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_misclassified.txt
|
| 100 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_misclassified.txt
|
| 101 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_report.txt
|
| 102 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_report.txt
|
| 103 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_misclassified.txt
|
| 104 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_report.txt
|
| 105 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_metrics_by_class.png
|
| 106 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_confusion_matrix.png
|
| 107 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_misclassified.txt
|
| 108 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_bacterial.tsv
|
| 109 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_bacterial_tagged.tsv
|
| 110 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_bacterial_tagged.tsv
|
| 111 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_metrics_by_class.png
|
| 112 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_class_distribution.png
|
| 113 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_bacterial.tsv
|
| 114 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_bacterial_tagged.tsv
|
| 115 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_report.txt
|
| 116 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_bacterial_tagged.tsv
|
| 117 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_class_distribution.png
|
| 118 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_confusion_matrix.png
|
| 119 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_confusion_matrix.png
|
| 120 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_algal_tagged.tsv
|
| 121 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_misclassified.txt
|
| 122 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_bacterial.tsv
|
| 123 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_metrics_by_class.png
|
| 124 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_algal_tagged.tsv
|
| 125 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_bacterial_tagged.tsv
|
| 126 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_report.txt
|
| 127 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_bacterial_tagged.tsv
|
| 128 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_confusion_matrix.png
|
| 129 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_metrics_by_class.png
|
| 130 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_metrics_by_class.png
|
| 131 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_confusion_matrix.png
|
| 132 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_confusion_matrix.png
|
| 133 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_misclassified.txt
|
| 134 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_algal_tagged.tsv
|
| 135 |
+
results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_confusion_matrix.png
|
| 136 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_misclassified.txt
|
| 137 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_bacterial.tsv
|
| 138 |
+
results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_report.txt
|
| 139 |
+
results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_confusion_matrix.png
|
| 140 |
+
run_la4sr_loop.sbatch
|
| 141 |
+
run_la4sr_TI-inc-algaGPT.sh
|
| 142 |
+
slurm-10718799.out
|
| 143 |
+
tar: slurm-10718799.out: file changed as we read it
|
| 144 |
+
slurm-logs/
|
| 145 |
+
slurm-logs/arrayJob_10718778_2.err
|
| 146 |
+
slurm-logs/arrayJob_10718764_10.err
|
| 147 |
+
slurm-logs/arrayJob_10718778_8.err
|
| 148 |
+
slurm-logs/arrayJob_10718778_7.err
|
| 149 |
+
slurm-logs/arrayJob_10718764_9.out
|
| 150 |
+
slurm-logs/arrayJob_10718778_11.out
|
| 151 |
+
slurm-logs/arrayJob_10718764_4.err
|
| 152 |
+
slurm-logs/arrayJob_10718778_9.err
|
| 153 |
+
slurm-logs/arrayJob_10718764_10.out
|
| 154 |
+
slurm-logs/arrayJob_10718778_6.err
|
| 155 |
+
slurm-logs/arrayJob_10718778_1.err
|
| 156 |
+
slurm-logs/arrayJob_10718764_11.out
|
| 157 |
+
slurm-logs/arrayJob_10718778_12.out
|
| 158 |
+
slurm-logs/arrayJob_10718778_9.out
|
| 159 |
+
slurm-logs/arrayJob_10718764_11.err
|
| 160 |
+
slurm-logs/arrayJob_10718778_10.out
|
| 161 |
+
slurm-logs/arrayJob_10718764_9.err
|
| 162 |
+
slurm-logs/arrayJob_10718764_2.out
|
| 163 |
+
slurm-logs/arrayJob_10718778_2.out
|
| 164 |
+
slurm-logs/arrayJob_10718778_5.out
|
| 165 |
+
slurm-logs/arrayJob_10718764_12.out
|
| 166 |
+
slurm-logs/arrayJob_10718764_1.err
|
| 167 |
+
slurm-logs/arrayJob_10718764_8.err
|
| 168 |
+
slurm-logs/arrayJob_10718764_3.err
|
| 169 |
+
slurm-logs/arrayJob_10718778_1.out
|
| 170 |
+
slurm-logs/arrayJob_10718764_2.err
|
| 171 |
+
slurm-logs/arrayJob_10718778_10.err
|
| 172 |
+
slurm-logs/arrayJob_10718778_3.err
|
| 173 |
+
slurm-logs/arrayJob_10718778_5.err
|
| 174 |
+
slurm-logs/arrayJob_10718778_8.out
|
| 175 |
+
slurm-logs/arrayJob_10718778_4.out
|
| 176 |
+
slurm-logs/arrayJob_10718764_7.err
|
| 177 |
+
slurm-logs/arrayJob_10718778_3.out
|
| 178 |
+
slurm-logs/arrayJob_10718778_12.err
|
| 179 |
+
slurm-logs/arrayJob_10718764_3.out
|
| 180 |
+
slurm-logs/arrayJob_10718764_6.out
|
| 181 |
+
slurm-logs/arrayJob_10718764_5.err
|
| 182 |
+
slurm-logs/arrayJob_10718778_7.out
|
| 183 |
+
slurm-logs/arrayJob_10718764_7.out
|
| 184 |
+
slurm-logs/arrayJob_10718764_4.out
|
| 185 |
+
slurm-logs/arrayJob_10718764_6.err
|
| 186 |
+
slurm-logs/arrayJob_10718778_4.err
|
| 187 |
+
slurm-logs/arrayJob_10718764_5.out
|
| 188 |
+
slurm-logs/arrayJob_10718778_6.out
|
| 189 |
+
slurm-logs/arrayJob_10718764_1.out
|
| 190 |
+
slurm-logs/arrayJob_10718764_8.out
|
| 191 |
+
slurm-logs/arrayJob_10718764_12.err
|
| 192 |
+
slurm-logs/arrayJob_10718778_11.err
|
| 193 |
+
targz.sh
|
targz.sh
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash

#SBATCH --mem=90GB
#SBATCH --time=96:00:00
#SBATCH --cpus-per-task=12

#tar -zcvf la4sr.tar.gz la4sr_sp* run_la4sr.sh run_la4sr.sbatch run_la4sr_loop.sbatch infer-ByT5tok-attn-fastaParser.py llm-metrics-two-files.py examples/Nannochloris_eucaryotum.AAC.fa.aa.fa examples/GCF_900016285.2_ANSES_11-930-9S.fa

# Archive the whole working directory for distribution.
# FIX: exclude the output archive from its own contents — writing the
# tarball into the directory being archived makes tar try to pack the
# growing file into itself (the committed SLURM log shows tar's
# "file changed as we read it" warning for exactly this reason).
tar -zcvf la4sr-TI-inc-algaGPT.tar.gz --exclude=la4sr-TI-inc-algaGPT.tar.gz *
|
| 9 |
+
|