GreenGenomicsLab commited on
Commit
261b39f
·
verified ·
1 Parent(s): f61a23d

Upload 30 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ la4sr_sp2.sif filter=lfs diff=lfs merge=lfs -text
README.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ to run, list your fasta files in the filelist.txt files and submit the .sbatch script, or just run run_la4sr_TI-inc-algaGPT.sh if no scheduler is available
2
+
3
+ alternatively, you can run the inference script (for raw outputs from next-token generation) and the model metrics script separately
4
+
5
+ expected outputs from default run are in results-archive
6
+
algae-filelist.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ generated_prompts_algae1.txt_headed.fa
2
+ generated_prompts_algae2.txt_headed.fa
3
+ generated_prompts_algae3.txt_headed.fa
ckpt.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c95eb4fa7488a9f311deddda7a687d261bcc17169b76eb75dca856587b959a67
3
+ size 1037880346
contam-filelist.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generated_prompts_archa1.txt_headed.fa
2
+ generated_prompts_archa2.txt_headed.fa
3
+ generated_prompts_archa3.txt_headed.fa
4
+ generated_prompts_bact1.txt_headed.fa
5
+ generated_prompts_bact2.txt_headed.fa
6
+ generated_prompts_bact3.txt_headed.fa
7
+ generated_prompts_fungi1.txt_headed.fa
8
+ generated_prompts_fungi2.txt_headed.fa
9
+ generated_prompts_fungi3.txt_headed.fa
10
+ generated_prompts_virus1.txt_headed.fa
11
+ generated_prompts_virus2.txt_headed.fa
12
+ generated_prompts_virus3.txt_headed.fa
filelist.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ chlorophyta_chloroplast_proteins.fasta
generated_prompts_algae1.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_algae2.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_algae3.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_archa1.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_archa2.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_archa3.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_bact1.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_bact2.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_bact3.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_fungi1.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_fungi2.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_fungi3.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_virus1.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_virus2.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
generated_prompts_virus3.txt_headed.fa ADDED
The diff for this file is too large to render. See raw diff
 
infer_TI-inc-algaGPT.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ la4sr_infer_fasta2tsv.py — drop‑in replacement for the older LA4SR
4
+ inference utilities used by **run_la4sr_TI-inc.sh**
5
+
6
+ Changes vs. the legacy “Sample from a trained model” script
7
+ ----------------------------------------------------------
8
+ 1. **Parses a FASTA file** and feeds the collapsed sequence to the model.
9
+ 2. **Generates up to 14 tokens** per record (defaults identical to old code).
10
+ 3. **Emits a TSV** with columns: record_id, sequence, model_output.
11
+ 4. CLI knobs match the wrapper (temperature, top‑k, etc.) plus `-o`.
12
+
13
+ Python ≥3.6 compatible (removed `from __future__ import annotations`).
14
+ """
15
+
16
+ import os, sys, argparse, pickle, random
17
+ # ---------------------------------------------------------------------------
18
+ # Python < 3.7 compatibility: provide a fallback for contextlib.nullcontext
19
+ # ---------------------------------------------------------------------------
20
try:
    from contextlib import nullcontext  # available since Python 3.7
except ImportError:  # running on Python 3.6 or older
    class _NullContext:
        """Minimal stand-in for contextlib.nullcontext: a no-op context
        manager whose __enter__ returns a caller-supplied value."""

        def __init__(self, result=None):
            self.result = result

        def __enter__(self):
            return self.result

        def __exit__(self, *exc_info):
            # Never suppress exceptions.
            return False

    nullcontext = _NullContext
31
+ from typing import Iterator, Tuple
32
+
33
+ import torch, tiktoken
34
+ from model import GPTConfig, GPT
35
+
36
+ ###############################################################################
37
+ # FASTA reader #
38
+ ###############################################################################
39
+
40
def stream_fasta(path: str) -> Iterator[Tuple[str, str]]:
    """Yield (record_id, sequence) tuples from a FASTA file.

    Wrapped sequence lines are collapsed into one string; the record id is
    the first whitespace-delimited token after '>'.

    Arguments:
        path: Path to the FASTA file.

    Yields:
        (record_id, sequence) tuples; sequence may be '' for empty records.
    """
    header, seq_chunks = None, []
    with open(path) as fh:
        for line in fh:
            line = line.strip()
            if not line:
                continue  # skip blank lines anywhere in the file
            if line.startswith('>'):
                if header is not None:
                    yield header, ''.join(seq_chunks)
                # Bug fix: a bare '>' header used to raise IndexError on
                # split()[0]; fall back to an empty record id instead.
                fields = line[1:].split()
                header = fields[0] if fields else ''
                seq_chunks = []
            else:
                seq_chunks.append(line)
    if header is not None:
        yield header, ''.join(seq_chunks)
57
+
58
+ ###############################################################################
59
+ # argument parsing #
60
+ ###############################################################################
61
+
62
def get_cli() -> argparse.Namespace:
    """Build and run the command-line parser for FASTA-to-TSV inference."""
    parser = argparse.ArgumentParser(description="LA4SR FASTA→TSV inference script")
    # --- model / runtime -----------------------------------------------
    parser.add_argument('--init_from', default='resume',
                        choices=['resume', 'gpt2', 'gpt2-medium', 'gpt2-large'],
                        help='Model source; "resume" = local ckpt.pt')
    parser.add_argument('--out_dir', default='out',
                        help='Directory with ckpt.pt if --init_from resume')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--dtype', default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument('--seed', type=int, default=1337)
    parser.add_argument('--compile', action='store_true')
    # --- generation knobs ----------------------------------------------
    parser.add_argument('--max_new_tokens', type=int, default=14)
    parser.add_argument('--temperature', type=float, default=0.1)
    parser.add_argument('--top_k', type=int, default=10)
    # --- I/O ------------------------------------------------------------
    parser.add_argument('fasta_in', help='Input FASTA')
    parser.add_argument('-o', '--tsv_out',
                        help='Output TSV (default: out-algaGPT/<basename>.tsv)')
    return parser.parse_args()
83
+
84
args = get_cli()

###############################################################################
#                    reproducibility & autocast context                       #
###############################################################################

# Seed every RNG in play so repeated runs produce identical generations.
torch.manual_seed(args.seed)
random.seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)

device_type = 'cuda' if 'cuda' in args.device else 'cpu'
# Translate the CLI dtype string into the matching torch dtype.
_dtype_by_name = {'float32': torch.float32,
                  'bfloat16': torch.bfloat16,
                  'float16': torch.float16}
ptdtype = _dtype_by_name[args.dtype]
# CPU inference runs un-wrapped; CUDA inference runs under mixed-precision
# autocast at the requested dtype.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
101
+
102
+ ###############################################################################
103
+ # model loading #
104
+ ###############################################################################
105
+
106
if args.init_from == 'resume':
    # Restore a locally trained checkpoint from --out_dir.
    ckpt_path = os.path.join(args.out_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=args.device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    # Strip the '_orig_mod.' prefixes that DDP / torch.compile add to
    # parameter names before loading.
    state_dict = {name.replace('_orig_mod.', ''): tensor
                  for name, tensor in checkpoint['model'].items()}
    model.load_state_dict(state_dict)
else:
    # Otherwise pull one of the pretrained GPT-2 variants.
    model = GPT.from_pretrained(args.init_from, dict(dropout=0.0))

model.to(args.device).eval()
if args.compile:
    model = torch.compile(model)
120
+
121
+ ###############################################################################
122
+ # encoding / decoding setup #
123
+ ###############################################################################
124
+
125
# Prefer the character-level vocab stored with the training dataset
# (meta.pkl); fall back to the GPT-2 BPE tokenizer when none is found.
if args.init_from == 'resume' and 'config' in locals().get('checkpoint', {}):
    cfg = checkpoint['config']
    meta_path = os.path.join('data', cfg.get('dataset', ''), 'meta.pkl')
else:
    meta_path = ''

# ------------------------------------------------------------------
# Fallback: meta.pkl next to ckpt.pt / in --out_dir
# ------------------------------------------------------------------
if (not meta_path or not os.path.exists(meta_path)) and args.out_dir:
    alt_meta = os.path.join(args.out_dir, 'meta.pkl')
    if os.path.exists(alt_meta):
        meta_path = alt_meta

if meta_path and os.path.exists(meta_path):
    with open(meta_path, 'rb') as f:
        meta = pickle.load(f)
    stoi, itos = meta['stoi'], meta['itos']
    # Characters missing from the vocab map to '<unk>' when it exists,
    # otherwise to id 0.
    UNK_ID = stoi.get('<unk>', 0)
    encode = lambda s: [stoi.get(c, UNK_ID) for c in s]
    decode = lambda l: ''.join(itos[i] for i in l)
else:
    enc = tiktoken.get_encoding('gpt2')
    encode = lambda s: enc.encode(s, allowed_special={""})
    decode = lambda l: enc.decode(l)
151
+
152
+ ###############################################################################
153
+ # output path #
154
+ ###############################################################################
155
+
156
os.makedirs('out-algaGPT', exist_ok=True)
# Default output: out-algaGPT/<input basename>.tsv, unless -o was given.
outfile = args.tsv_out or os.path.join(
    'out-algaGPT',
    f"{os.path.splitext(os.path.basename(args.fasta_in))[0]}.tsv")

###############################################################################
#                              generation loop                                #
###############################################################################

with open(outfile, 'w') as tsv, torch.no_grad(), ctx:
    tsv.write('record_id\tsequence\tmodel_output\n')
    for rid, seq in stream_fasta(args.fasta_in):
        if not seq:
            print(f"[WARN] empty sequence for {rid}; skipping", file=sys.stderr)
            continue
        # Encode the record as a (1, len) batch of token ids on the device.
        prompt = torch.tensor(encode(seq), dtype=torch.long, device=args.device).unsqueeze(0)
        try:
            sampled = model.generate(prompt, args.max_new_tokens, temperature=args.temperature, top_k=args.top_k)
            cont = decode(sampled[0].tolist())
        except Exception as e:
            # Best-effort: log the per-record failure and emit an empty
            # prediction rather than aborting the whole run.
            print(f"[ERR] generation failed on {rid}: {e}", file=sys.stderr)
            cont = ''
        tsv.write(f"{rid}\t{seq}\t{cont}\n")

print(f"\n✓ Predictions saved to {outfile}\n")
179
+
la4sr_sp2.sif ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b1004d7425ffc698a13ffc537d8c25fd4073ac383d5b583d01971b4616069b7
3
+ size 7201992704
la4sr_sp2.sif.md5 ADDED
@@ -0,0 +1 @@
 
 
1
+ 0ce6720b829a4eb28d87ded1301da3ca la4sr_sp2.sif
llm-metrics-two-files.py ADDED
@@ -0,0 +1,475 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ LLM Classification Metrics Generator for Two-File Analysis
4
+
5
+ This script analyzes the LLM classification results from two separate files:
6
+ - One containing algal sequences (true algal samples)
7
+ - One containing contaminant sequences (true contaminant samples)
8
+
9
+ It extracts the predicted tags and calculates comprehensive metrics.
10
+ """
11
+
12
+ import re
13
+ import sys
14
+ import argparse
15
+ import numpy as np
16
+ import matplotlib.pyplot as plt
17
+ from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
18
+ from sklearn.metrics import classification_report
19
+
20
def _predict_label(line):
    """Map one result line to a predicted class from its tag symbols.

    '@' marks algal and '!' marks contaminant; '@' wins when both occur,
    matching the precedence of the original per-file loops.
    """
    if '@' in line:
        return 'algal'
    if '!' in line:
        return 'contaminant'
    return 'unknown'


def _parse_one_file(path, true_label, true_labels, predicted_labels, sequence_ids):
    """Parse one results file, appending to the shared tracking lists.

    Every data line in the file is assumed to belong to `true_label`.
    """
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Skip header or non-data lines
            if line.startswith('==>') or line.startswith('('):
                continue
            # Sequence id = first whitespace-delimited token of the line
            seq_id_match = re.match(r'^([^\s]+)', line)
            seq_id = seq_id_match.group(1) if seq_id_match else "unknown_id"
            true_labels.append(true_label)
            sequence_ids.append(seq_id)
            predicted_labels.append(_predict_label(line))


def parse_files(algal_file, contaminant_file):
    """
    Parse the algal and contaminant files to extract true and predicted labels.

    The two files previously had near-duplicate parsing loops; both now go
    through _parse_one_file with the appropriate true label.

    Arguments:
        algal_file (str): Path to the file containing algal sequences
        contaminant_file (str): Path to the file containing contaminant sequences

    Returns:
        tuple: (true_labels, predicted_labels, sequence_ids) lists, with the
        algal file's records first.
    """
    true_labels = []
    predicted_labels = []
    sequence_ids = []

    # Process algal file (all true labels are 'algal')
    _parse_one_file(algal_file, 'algal',
                    true_labels, predicted_labels, sequence_ids)
    # Process contaminant file (all true labels are 'contaminant')
    _parse_one_file(contaminant_file, 'contaminant',
                    true_labels, predicted_labels, sequence_ids)

    return true_labels, predicted_labels, sequence_ids
109
+
110
def calculate_metrics(true_labels, predicted_labels):
    """
    Calculate comprehensive classification metrics.

    'unknown' predictions are excluded from the sklearn precision/recall/F1
    computation but still count as errors for the overall accuracy.

    Arguments:
        true_labels (list): List of true class labels ('algal'/'contaminant')
        predicted_labels (list): List of predicted class labels
            ('algal'/'contaminant'/'unknown')

    Returns:
        dict: Dictionary containing all calculated metrics
    """
    classes = ['algal', 'contaminant']
    label_map = {label: i for i, label in enumerate(classes)}
    # (Removed dead code: true_numeric/pred_numeric were computed but never
    # used anywhere below.)

    # Filter out unknowns for the main sklearn metrics
    known_indices = [i for i, pred in enumerate(predicted_labels) if pred != 'unknown']
    true_known = [true_labels[i] for i in known_indices]
    pred_known = [predicted_labels[i] for i in known_indices]

    # Overall accuracy (including unknowns as wrong predictions).
    # Guard total == 0, which previously raised ZeroDivisionError.
    total = len(true_labels)
    accuracy = sum(t == p for t, p in zip(true_labels, predicted_labels)) / total if total else 0

    if true_known and pred_known:
        # Convert to numeric form for sklearn
        true_known_numeric = np.array([label_map[label] for label in true_known])
        pred_known_numeric = np.array([label_map[label] for label in pred_known])

        # Calculate precision, recall, and F1 (excluding unknowns)
        precision, recall, f1, support = precision_recall_fscore_support(
            true_known_numeric,
            pred_known_numeric,
            labels=[0, 1],  # algal, contaminant
            zero_division=0
        )

        # Create confusion matrix
        cm = confusion_matrix(
            true_known_numeric,
            pred_known_numeric,
            labels=[0, 1]
        )

        # Full classification report
        report = classification_report(
            true_known_numeric,
            pred_known_numeric,
            labels=[0, 1],
            target_names=classes,
            output_dict=True
        )
    else:
        # No usable predictions at all: fall back to all-zero metrics.
        # Arrays (not lists) so downstream elementwise math stays valid.
        precision = recall = f1 = support = np.zeros(2)
        cm = np.zeros((2, 2))
        report = {}

    # Count occurrences and calculate per-class metrics
    class_metrics = {}
    for class_name in classes:
        class_indices = [i for i, label in enumerate(true_labels) if label == class_name]
        n_class = len(class_indices)

        if n_class == 0:
            class_metrics[class_name] = {
                "total": 0,
                "correct": 0,
                "incorrect": 0,
                "unknown": 0,
                "accuracy": 0,
                "error_rate": 0
            }
            continue

        correct = sum(1 for i in class_indices if predicted_labels[i] == class_name)
        unknown = sum(1 for i in class_indices if predicted_labels[i] == "unknown")
        incorrect = n_class - correct - unknown

        class_metrics[class_name] = {
            "total": n_class,
            "correct": correct,
            "incorrect": incorrect,
            "unknown": unknown,
            "accuracy": correct / n_class,
            "error_rate": (incorrect + unknown) / n_class
        }

    # Compile all metrics
    metrics = {
        "accuracy": accuracy,
        "class_metrics": class_metrics,
        "confusion_matrix": cm,
        "precision": {classes[i]: precision[i] for i in range(len(classes))},
        "recall": {classes[i]: recall[i] for i in range(len(classes))},
        "f1": {classes[i]: f1[i] for i in range(len(classes))},
        "support": {classes[i]: support[i] for i in range(len(classes))},
        "classification_report": report,
        "macro_f1": np.mean(f1),
        "weighted_f1": np.sum(f1 * support) / np.sum(support) if np.sum(support) > 0 else 0,
        "total_samples": len(true_labels),
        "total_correct": sum(t == p for t, p in zip(true_labels, predicted_labels)),
        "total_unknown": predicted_labels.count("unknown")
    }

    return metrics
218
+
219
def display_results(metrics, output_file=None):
    """
    Display comprehensive results and optionally save to file.

    Arguments:
        metrics (dict): Dictionary containing all calculated metrics
        output_file (str, optional): Path to save results to; printed to
            stdout when omitted.
    """

    def _report():
        # Print header
        print("\n" + "="*60)
        print(" LLM CLASSIFICATION METRICS REPORT")
        print("="*60)

        # Overall metrics
        print("\n=== OVERALL METRICS ===")
        print(f"Total samples: {metrics['total_samples']}")
        print(f"Correctly classified: {metrics['total_correct']} ({metrics['total_correct']/metrics['total_samples']*100:.2f}%)")
        print(f"Unknown predictions: {metrics['total_unknown']} ({metrics['total_unknown']/metrics['total_samples']*100:.2f}%)")
        print(f"Overall accuracy: {metrics['accuracy']:.4f}")
        print(f"Macro F1: {metrics['macro_f1']:.4f}")
        print(f"Weighted F1: {metrics['weighted_f1']:.4f}")

        # Confusion matrix ('Bacterial' is the display name used for the
        # contaminant class throughout the reports)
        cm = metrics["confusion_matrix"]
        class_labels = ["Algal", "Bacterial"]

        print("\n=== CONFUSION MATRIX ===")
        print(f"{'':15} | {'Predicted Algal':15} | {'Predicted Bacterial':20}")
        print("-" * 55)
        for i, label in enumerate(class_labels):
            print(f"{label:15} | {int(cm[i][0]):15} | {int(cm[i][1]):20}")

        # Per-class metrics
        print("\n=== PER-CLASS METRICS ===")
        print(f"{'Class':10} | {'Precision':10} | {'Recall':10} | {'F1 Score':10} | {'Support':10}")
        print("-" * 60)
        for class_name in ['algal', 'contaminant']:
            precision = metrics['precision'][class_name]
            recall = metrics['recall'][class_name]
            f1 = metrics['f1'][class_name]
            support = metrics['support'][class_name]
            print(f"{class_name.capitalize():10} | {precision:.4f} | {recall:.4f} | {f1:.4f} | {int(support):10}")

        # Detailed class counts
        print("\n=== DETAILED CLASS COUNTS ===")
        for class_name, class_data in metrics["class_metrics"].items():
            print(f"{class_name.capitalize()} class:")
            print(f" Total samples: {class_data['total']}")
            if class_data['total'] > 0:
                print(f" Correctly classified: {class_data['correct']} ({class_data['correct']/class_data['total']*100:.2f}%)")
                print(f" Incorrectly classified: {class_data['incorrect']} ({class_data['incorrect']/class_data['total']*100:.2f}%)")
                print(f" Unknown: {class_data['unknown']} ({class_data['unknown']/class_data['total']*100:.2f}%)")
            print()

    if output_file:
        import io
        from contextlib import redirect_stdout
        # Bug fix: the old manual sys.stdout swap left stdout redirected if
        # a print raised; redirect_stdout always restores it on exit.
        buffer = io.StringIO()
        with redirect_stdout(buffer):
            _report()
        with open(output_file, 'w') as f:
            f.write(buffer.getvalue())
        print(f"Results saved to {output_file}")
    else:
        _report()
290
+
291
def generate_visualizations(metrics, output_prefix=None):
    """
    Generate visualizations of the metrics.

    Produces a confusion-matrix heatmap, a grouped bar chart of per-class
    precision/recall/F1, and pie charts of outcome counts per class.
    Figures are saved as PNGs when output_prefix is given, otherwise shown
    interactively.

    Arguments:
        metrics (dict): Dictionary containing all calculated metrics
        output_prefix (str, optional): Prefix for output image files
    """
    # --- confusion matrix heatmap ---------------------------------------
    plt.figure(figsize=(8, 6))
    cm = metrics["confusion_matrix"]
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()

    class_names = ["Algal", "Bacterial"]
    ticks = np.arange(len(class_names))
    plt.xticks(ticks, class_names, rotation=45)
    plt.yticks(ticks, class_names)

    # Annotate every cell, flipping the text color on dark backgrounds
    threshold = cm.max() / 2.0
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            plt.text(col, row, format(int(cm[row, col]), 'd'),
                     horizontalalignment="center",
                     color="white" if cm[row, col] > threshold else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

    if output_prefix:
        plt.savefig(f"{output_prefix}_confusion_matrix.png", dpi=300, bbox_inches='tight')
    else:
        plt.show()

    # --- per-class metric bars ------------------------------------------
    plt.figure(figsize=(10, 6))

    metric_labels = ['Precision', 'Recall', 'F1-Score']
    positions = np.arange(len(metric_labels))
    bar_width = 0.35

    algal_scores = [metrics['precision']['algal'], metrics['recall']['algal'], metrics['f1']['algal']]
    contam_scores = [metrics['precision']['contaminant'], metrics['recall']['contaminant'], metrics['f1']['contaminant']]

    plt.bar(positions - bar_width/2, algal_scores, bar_width, label='Algal')
    plt.bar(positions + bar_width/2, contam_scores, bar_width, label='Bacterial')

    plt.ylabel('Score')
    plt.title('Performance Metrics by Class')
    plt.xticks(positions, metric_labels)
    plt.ylim(0, 1.1)
    plt.legend()
    plt.grid(axis='y', linestyle='--', alpha=0.7)

    if output_prefix:
        plt.savefig(f"{output_prefix}_metrics_by_class.png", dpi=300, bbox_inches='tight')
    else:
        plt.show()

    # --- outcome pie charts ---------------------------------------------
    fig, (left_ax, right_ax) = plt.subplots(1, 2, figsize=(12, 5))
    slice_labels = ['Correct', 'Incorrect', 'Unknown']

    # Algal class distribution
    algal_counts = metrics['class_metrics']['algal']
    left_ax.pie([algal_counts['correct'], algal_counts['incorrect'], algal_counts['unknown']],
                labels=slice_labels, autopct='%1.1f%%', startangle=90)
    left_ax.set_title('Algal Class Predictions')

    # Contaminant ("bacterial") class distribution
    contam_counts = metrics['class_metrics']['contaminant']
    right_ax.pie([contam_counts['correct'], contam_counts['incorrect'], contam_counts['unknown']],
                 labels=slice_labels, autopct='%1.1f%%', startangle=90)
    right_ax.set_title('Bacterial Class Predictions')

    plt.tight_layout()

    if output_prefix:
        plt.savefig(f"{output_prefix}_class_distribution.png", dpi=300, bbox_inches='tight')
    else:
        plt.show()
376
+
377
def create_misclassified_report(true_labels, predicted_labels, sequence_ids, output_file=None):
    """
    Create a report of misclassified sequences.

    Arguments:
        true_labels (list): List of true class labels
        predicted_labels (list): List of predicted class labels
        sequence_ids (list): List of sequence IDs
        output_file (str, optional): Path to save the report to; printed to
            stdout when omitted.
    """
    misclassified = [
        {'id': seq_id, 'true': true, 'predicted': pred}
        for true, pred, seq_id in zip(true_labels, predicted_labels, sequence_ids)
        if true != pred
    ]

    def _report():
        # Print header
        print("\n" + "="*60)
        print(" MISCLASSIFIED SEQUENCES REPORT")
        print("="*60)
        print(f"\nTotal misclassified: {len(misclassified)} out of {len(true_labels)} ({len(misclassified)/len(true_labels)*100:.2f}%)\n")

        # Print algal sequences misclassified as contaminant
        print("\n--- ALGAL SEQUENCES MISCLASSIFIED AS BACTERIAL ---")
        algal_as_contaminant = [m for m in misclassified if m['true'] == 'algal' and m['predicted'] == 'contaminant']
        for item in algal_as_contaminant:
            print(f"ID: {item['id']}")
        print(f"Total: {len(algal_as_contaminant)}")

        # Print contaminant sequences misclassified as algal
        print("\n--- BACTERIAL SEQUENCES MISCLASSIFIED AS ALGAL ---")
        contaminant_as_algal = [m for m in misclassified if m['true'] == 'contaminant' and m['predicted'] == 'algal']
        for item in contaminant_as_algal:
            print(f"ID: {item['id']}")
        print(f"Total: {len(contaminant_as_algal)}")

        # Print unknown classifications
        print("\n--- SEQUENCES WITH UNKNOWN CLASSIFICATION ---")
        unknown = [m for m in misclassified if m['predicted'] == 'unknown']
        for item in unknown:
            print(f"ID: {item['id']} (True: {item['true']})")
        print(f"Total: {len(unknown)}")

    if output_file:
        import io
        from contextlib import redirect_stdout
        # Bug fix: the old manual sys.stdout swap left stdout redirected if
        # a print raised; redirect_stdout always restores it on exit.
        buffer = io.StringIO()
        with redirect_stdout(buffer):
            _report()
        with open(output_file, 'w') as f:
            f.write(buffer.getvalue())
        print(f"Misclassified report saved to {output_file}")
    else:
        _report()
440
+
441
def main():
    """Parse CLI arguments, compute metrics, and emit the requested reports.

    Returns:
        int: Number of misclassified sequences (also the process exit code,
        so 0 means a perfect classification).
    """
    parser = argparse.ArgumentParser(description='LLM Classification Metrics Generator for Two-File Analysis')
    parser.add_argument('algal_file', help='Path to the file containing algal sequences')
    parser.add_argument('contaminant_file', help='Path to the file containing contaminant sequences')
    parser.add_argument('-o', '--output', help='Path to save the metrics report')
    parser.add_argument('-m', '--misclassified', help='Path to save the misclassified sequences report')
    parser.add_argument('-v', '--visualize', action='store_true', help='Generate visualizations')
    parser.add_argument('-p', '--prefix', default='llm_metrics', help='Prefix for output files')

    args = parser.parse_args()

    # Parse files and calculate metrics
    true_labels, predicted_labels, sequence_ids = parse_files(args.algal_file, args.contaminant_file)
    metrics = calculate_metrics(true_labels, predicted_labels)

    # Display results.
    # Bug fix: the user-supplied -o path was previously ignored in favour of
    # the hard-coded '<prefix>_report.txt'; honour it now.
    output_file = args.output if args.output else None
    display_results(metrics, output_file)

    # Generate visualizations if requested
    if args.visualize:
        generate_visualizations(metrics, args.prefix)

    # Create misclassified report if requested.
    # Bug fix: '-m' stores a string, so the old "is True" check could never
    # succeed; use the supplied path directly.
    if args.misclassified:
        create_misclassified_report(true_labels, predicted_labels, sequence_ids, args.misclassified)

    # Return number of misclassifications (for automated testing)
    misclassifications = sum(t != p for t, p in zip(true_labels, predicted_labels))
    return misclassifications

if __name__ == "__main__":
    sys.exit(main())
meta.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75049dcc1458196491b007a46f4aa1acef9ca3c27df282f226af572877203752
3
+ size 14858
model.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Full definition of a GPT Language Model, all of it in this single file.
3
+ References:
4
+ 1) the official GPT-2 TensorFlow implementation released by OpenAI:
5
+ https://github.com/openai/gpt-2/blob/master/src/model.py
6
+ 2) huggingface/transformers PyTorch implementation:
7
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py
8
+ """
9
+
10
+ import math
11
+ import inspect
12
+ from dataclasses import dataclass
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from torch.nn import functional as F
17
+
18
class LayerNorm(nn.Module):
    """Layer normalization with an optional additive bias.

    torch.nn.LayerNorm always carries a bias term; this variant keeps the
    learnable scale but lets the bias be dropped entirely (bias=False).
    """

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        # Normalize over the trailing `ndim` features with eps=1e-5
        # (same epsilon as torch.nn.LayerNorm's default).
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
28
+
29
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with a fused QKV projection.

    Prefers torch's scaled_dot_product_attention (Flash Attention kernels,
    PyTorch >= 2.0); otherwise falls back to an explicit masked-softmax
    implementation that uses a precomputed lower-triangular mask buffer.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            # (only needed for the slow path; SDPA handles causality via is_causal=True)
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """Apply causal self-attention; input and output are both (B, T, C)."""
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            # (dropout is only applied while training)
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
77
+
78
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    The hidden width is 4x the embedding size, matching GPT-2. Submodule
    attribute names (c_fc, c_proj, ...) mirror the GPT-2 checkpoint layout.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        # Expand to 4*n_embd, apply GELU, project back down, then dropout.
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
93
+
94
class Block(nn.Module):
    """One transformer layer: pre-norm causal self-attention followed by a
    pre-norm MLP, each wrapped in a residual (skip) connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)

    def forward(self, x):
        # Residual around attention, then residual around the MLP.
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
107
+
108
@dataclass
class GPTConfig:
    """Hyperparameters defining a GPT model's architecture."""
    block_size: int = 1024   # maximum sequence length (context window)
    vocab_size: int = 50304  # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency
    n_layer: int = 12        # number of transformer Blocks
    n_head: int = 12         # attention heads per Block
    n_embd: int = 768        # embedding / hidden dimension
    dropout: float = 0.0     # dropout probability (0.0 disables dropout)
    bias: bool = True        # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
117
+
118
class GPT(nn.Module):
    """GPT language model: token + position embeddings, a stack of
    transformer Blocks, a final LayerNorm, and a weight-tied LM head."""

    def __init__(self, config):
        super().__init__()
        assert config.vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        # with weight tying when using torch.compile() some warnings get generated:
        # "UserWarning: functional_call was passed multiple values for tied weights.
        # This behavior is deprecated and will be an error in future versions"
        # not 100% sure what this is, so far seems to be harmless. TODO investigate
        self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))

        # report number of parameters
        print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def _init_weights(self, module):
        # GPT-2 style init: N(0, 0.02) for Linear/Embedding weights, zero biases.
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Run the model on token indices idx of shape (b, t).

        Returns (logits, loss). With targets given, loss is cross-entropy
        (entries equal to -1 are ignored) and logits cover all positions;
        without targets, loss is None and logits cover only the last
        position as an inference-time shortcut.
        """
        device = idx.device
        b, t = idx.size()
        assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)

        # forward the GPT model itself
        tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.lm_head(x)
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the lm_head on the very last position
            logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
            loss = None

        return logits, loss

    def crop_block_size(self, block_size):
        """Shrink the context window to `block_size` in place (model surgery)."""
        # model surgery to decrease the block size if necessary
        # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
        # but want to use a smaller block size for some smaller, simpler model
        assert block_size <= self.config.block_size
        self.config.block_size = block_size
        self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size])
        for block in self.transformer.h:
            if hasattr(block.attn, 'bias'):
                block.attn.bias = block.attn.bias[:,:,:block_size,:block_size]

    @classmethod
    def from_pretrained(cls, model_type, override_args=None):
        """Construct a GPT and copy in matching GPT-2 weights from huggingface."""
        assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}
        override_args = override_args or {} # default to empty dict
        # only dropout can be overridden see more notes below
        assert all(k == 'dropout' for k in override_args)
        from transformers import GPT2LMHeadModel
        print("loading weights from pretrained gpt: %s" % model_type)

        # n_layer, n_head and n_embd are determined from model_type
        config_args = {
            'gpt2':         dict(n_layer=12, n_head=12, n_embd=768),  # 124M params
            'gpt2-medium':  dict(n_layer=24, n_head=16, n_embd=1024), # 350M params
            'gpt2-large':   dict(n_layer=36, n_head=20, n_embd=1280), # 774M params
            'gpt2-xl':      dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params
        }[model_type]
        print("forcing vocab_size=50257, block_size=1024, bias=True")
        config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints
        config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints
        config_args['bias'] = True # always True for GPT model checkpoints
        # we can override the dropout rate, if desired
        if 'dropout' in override_args:
            print(f"overriding dropout rate to {override_args['dropout']}")
            config_args['dropout'] = override_args['dropout']
        # create a from-scratch initialized minGPT model
        config = GPTConfig(**config_args)
        model = GPT(config)
        sd = model.state_dict()
        sd_keys = sd.keys()
        sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param

        # init a huggingface/transformers model
        model_hf = GPT2LMHeadModel.from_pretrained(model_type)
        sd_hf = model_hf.state_dict()

        # copy while ensuring all of the parameters are aligned and match in names and shapes
        sd_keys_hf = sd_hf.keys()
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer
        sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer)
        transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
        # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear
        # this means that we have to transpose these weights when we import them
        assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}"
        for k in sd_keys_hf:
            if any(k.endswith(w) for w in transposed):
                # special treatment for the Conv1D weights we need to transpose
                assert sd_hf[k].shape[::-1] == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k].t())
            else:
                # vanilla copy over the other parameters
                assert sd_hf[k].shape == sd[k].shape
                with torch.no_grad():
                    sd[k].copy_(sd_hf[k])

        return model

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        """Build an AdamW optimizer with weight decay applied only to >=2D tensors."""
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = self.get_num_params()
        cfg = self.config
        L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt) # per second
        flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
run_la4sr_TI-inc-algaGPT.sh ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
set -euo pipefail
###############################################################################
# run_la4sr_TI-inc.sh — LA4SR pipeline: model inference (FASTA→TSV) + metrics
#
# Classifies an algal FASTA and a contaminant FASTA with the LA4SR model
# inside a Singularity container, rewrites the predicted class tags to
# single characters, and generates a metrics report plus optional plots.
###############################################################################

# ---------------------- Configuration ----------------------
SCRIPT_DIR="$(pwd)"
INFER_SCRIPT="$SCRIPT_DIR/infer_TI-inc-algaGPT.py"
METRICS_SCRIPT="$SCRIPT_DIR/llm-metrics-two-files.py"
SIF="$SCRIPT_DIR/la4sr_sp2.sif"

# cache for HF tokenizers & models
tcache="$SCRIPT_DIR/cache"
mkdir -p "$tcache"

# ---------------------- Usage ----------------------
if [[ $# -ne 3 ]]; then
    cat <<EOF
Usage: $(basename "$0") <model_name|resume> <algal_fasta> <bacterial_fasta>

If you pass the literal word resume the script loads ckpt.pt + meta.pkl
from the current directory. Otherwise the value is forwarded to --init_from
(e.g. GreenGenomicsLab/LA4SR-gpt-neo125-ALMGA-FL).

EOF
    exit 1
fi

MODEL_NAME="$1"   # "resume" OR HF repo / local path
algal_fasta="$2"
bact_fasta="$3"

# Derived output paths, all under ./results/
prefix="$(basename "${algal_fasta%.*}")_vs_$(basename "${bact_fasta%.*}")"
mkdir -p results
alg_out="results/${prefix}_algal.tsv"
bac_out="results/${prefix}_bacterial.tsv"
alg_out_tagged="results/${prefix}_algal_tagged.tsv"
bac_out_tagged="results/${prefix}_bacterial_tagged.tsv"
report="results/${prefix}_report.txt"
miscl="results/${prefix}_misclassified.txt"

# ---------------------- Inference ----------------------

# run_infer <fasta> <out_tsv>
# Run next-token inference on one FASTA file inside the container; the
# FASTA is bind-mounted read-only at /input.fasta and the project dir
# at /workdir so relative output paths land back on the host.
run_infer () {
    local fasta=$1 out=$2
    echo -e "\n→ Inferring $(basename "$fasta")..."

    # Build common args
    PY_ARGS=( --init_from "$MODEL_NAME" )
    [[ "$MODEL_NAME" == "resume" ]] && PY_ARGS+=( --out_dir /workdir )

    # NOTE(review): "${PY_ARGS[*]}" flattens the array into the inner shell
    # string, so argument values containing whitespace would be re-split —
    # safe for the current flag values, but confirm before adding new args.
    singularity exec --nv \
        -B "$fasta:/input.fasta" \
        -B "$(pwd):/workdir" \
        -B "$tcache:$tcache" \
        --env TRANSFORMERS_CACHE="$tcache" \
        "$SIF" \
        bash -c 'cd /workdir && \
python3 infer_TI-inc-algaGPT.py '"${PY_ARGS[*]}"' /input.fasta -o "'"$out"'"'
}

run_infer "$algal_fasta" "$alg_out"
run_infer "$bact_fasta" "$bac_out"

# ---------------------- Post-process Tags ----------------------
# convert_tags <in_tsv> <out_tsv>
# Collapse the verbose class labels to single-character tags expected by
# the metrics script: 'algae' -> '@', 'conta' -> '!'.
convert_tags () {
    local infile=$1 outfile=$2
    echo -e "\n→ Converting 'algae'→@ and 'conta'→! in $(basename "$infile")..."
    sed -E 's/algae/@/g; s/conta/!/g' "$infile" > "$outfile"
    echo "  ✔ Wrote $outfile"
}

convert_tags "$alg_out" "$alg_out_tagged"
convert_tags "$bac_out" "$bac_out_tagged"

# ---------------------- Metrics ----------------------
echo -e "\n→ Generating metrics report..."
singularity exec \
    -B "$METRICS_SCRIPT:/metrics.py" \
    -B "$(pwd):/workdir" \
    "$SIF" \
    bash -c "source /opt/conda/etc/profile.d/conda.sh && \
        conda activate la4sr && cd /workdir && \
        python3 /metrics.py \
            \"$alg_out_tagged\" \"$bac_out_tagged\" \
            -o \"$report\" \
            -m \"$miscl\" \
            -v \
            -p \"results/$prefix\""

# ---------------------- Finished ----------------------
echo -e '\n🎉 Done! Results in ./results/'
echo "  Algal TSV:     $alg_out_tagged"
echo "  Bact TSV:      $bac_out_tagged"
echo "  Report:        $report"
echo "  Misclassified: $miscl"
echo "  Plots:         results/${prefix}_*.png"
+
run_la4sr_loop.sbatch ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ #SBATCH -o slurm-logs/arrayJob_%A_%a.out
4
+ #SBATCH -e slurm-logs/arrayJob_%A_%a.err
5
+ #SBATCH -a 1-12 #5-112 # <-- set to length of the *longer* file
6
+ #SBATCH --mem=40G
7
+ #SBATCH --time=12:00:00
8
+ #SBATCH -p nvidia
9
+ #SBATCH --gres=gpu:1
10
+ #SBATCH --cpus-per-task=20
11
+
12
+ # Get line count of each file
13
+ NUM_ALGAE=$(wc -l < algae-filelist.txt)
14
+ NUM_CONTAM=$(wc -l < contam-filelist.txt)
15
+
16
+ # Use raw SLURM task ID
17
+ TASK_ID=$SLURM_ARRAY_TASK_ID
18
+
19
+ # Modulo wrap if needed
20
+ IDX_ALGAE=$(( (TASK_ID - 1) % NUM_ALGAE + 1 ))
21
+ IDX_CONTAM=$(( (TASK_ID - 1) % NUM_CONTAM + 1 ))
22
+
23
+ # Extract lines from files
24
+ ALINE=$(sed -n "${IDX_ALGAE}p" algae-filelist.txt)
25
+ CLINE=$(sed -n "${IDX_CONTAM}p" contam-filelist.txt)
26
+
27
+ # Run your classification script
28
+ ./run_la4sr_TI-inc-algaGPT.sh resume "$ALINE" "$CLINE"
29
+
30
+ ## EXAMPLE:
31
+
32
+ ##./run_la4sr.sh ./test-data/TI-free/AlgalTop10000-10holdout-headed.fa ./test-data/TI-free/BactTop10000-10holdout-headed.fa
slurm-10718799.out ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ algae-filelist.txt
2
+ cache/
3
+ ckpt.pt
4
+ contam-filelist.txt
5
+ filelist.txt
6
+ generated_prompts_algae1.txt_headed.fa
7
+ generated_prompts_algae2.txt_headed.fa
8
+ generated_prompts_algae3.txt_headed.fa
9
+ generated_prompts_archa1.txt_headed.fa
10
+ generated_prompts_archa2.txt_headed.fa
11
+ generated_prompts_archa3.txt_headed.fa
12
+ generated_prompts_bact1.txt_headed.fa
13
+ generated_prompts_bact2.txt_headed.fa
14
+ generated_prompts_bact3.txt_headed.fa
15
+ generated_prompts_fungi1.txt_headed.fa
16
+ generated_prompts_fungi2.txt_headed.fa
17
+ generated_prompts_fungi3.txt_headed.fa
18
+ generated_prompts_virus1.txt_headed.fa
19
+ generated_prompts_virus2.txt_headed.fa
20
+ generated_prompts_virus3.txt_headed.fa
21
+ infer_TI-inc-algaGPT.py
22
+ la4sr_sp2.sif
23
+ la4sr_sp2.sif.md5
24
+ llm-metrics-two-files.py
25
+ meta.pkl
26
+ model.py
27
+ out-algaGPT/
28
+ __pycache__/
29
+ __pycache__/model.cpython-312.pyc
30
+ README.txt
31
+ results-archive/
32
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_misclassified.txt
33
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_bacterial_tagged.tsv
34
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_metrics_by_class.png
35
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_confusion_matrix.png
36
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_algal_tagged.tsv
37
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_report.txt
38
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_algal.tsv
39
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_algal_tagged.tsv
40
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_metrics_by_class.png
41
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_report.txt
42
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_bacterial.tsv
43
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_bacterial.tsv
44
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_algal.tsv
45
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_metrics_by_class.png
46
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_bacterial.tsv
47
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_class_distribution.png
48
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_metrics_by_class.png
49
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_misclassified.txt
50
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_algal.tsv
51
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_bacterial.tsv
52
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_metrics_by_class.png
53
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_class_distribution.png
54
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_bacterial_tagged.tsv
55
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_metrics_by_class.png
56
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_misclassified.txt
57
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_confusion_matrix.png
58
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_class_distribution.png
59
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_algal.tsv
60
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_confusion_matrix.png
61
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_bacterial_tagged.tsv
62
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_bacterial.tsv
63
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_bacterial_tagged.tsv
64
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_bacterial.tsv
65
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_class_distribution.png
66
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_misclassified.txt
67
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_report.txt
68
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_algal_tagged.tsv
69
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_class_distribution.png
70
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_class_distribution.png
71
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_misclassified.txt
72
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_bacterial_tagged.tsv
73
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_algal.tsv
74
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_algal.tsv
75
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_class_distribution.png
76
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_algal_tagged.tsv
77
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_bacterial_tagged.tsv
78
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_bacterial.tsv
79
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_report.txt
80
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_algal.tsv
81
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_algal.tsv
82
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_metrics_by_class.png
83
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_algal.tsv
84
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_algal.tsv
85
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_class_distribution.png
86
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_class_distribution.png
87
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_bacterial.tsv
88
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_class_distribution.png
89
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_algal_tagged.tsv
90
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_algal_tagged.tsv
91
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_algal_tagged.tsv
92
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_algal_tagged.tsv
93
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_confusion_matrix.png
94
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_algal.tsv
95
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_algal_tagged.tsv
96
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_report.txt
97
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_report.txt
98
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_algal.tsv
99
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_misclassified.txt
100
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_misclassified.txt
101
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_report.txt
102
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_report.txt
103
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_misclassified.txt
104
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_report.txt
105
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_metrics_by_class.png
106
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_confusion_matrix.png
107
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_misclassified.txt
108
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_archa2.txt_headed_bacterial.tsv
109
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_bacterial_tagged.tsv
110
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_bacterial_tagged.tsv
111
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_metrics_by_class.png
112
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_class_distribution.png
113
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_bacterial.tsv
114
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_bacterial_tagged.tsv
115
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_report.txt
116
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_virus2.txt_headed_bacterial_tagged.tsv
117
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_class_distribution.png
118
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_confusion_matrix.png
119
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_confusion_matrix.png
120
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_algal_tagged.tsv
121
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_bact1.txt_headed_misclassified.txt
122
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_bacterial.tsv
123
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_metrics_by_class.png
124
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_algal_tagged.tsv
125
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_bacterial_tagged.tsv
126
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_report.txt
127
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_bacterial_tagged.tsv
128
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_confusion_matrix.png
129
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_fungi1.txt_headed_metrics_by_class.png
130
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_metrics_by_class.png
131
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_confusion_matrix.png
132
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_virus1.txt_headed_confusion_matrix.png
133
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_archa3.txt_headed_misclassified.txt
134
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_bact2.txt_headed_algal_tagged.tsv
135
+ results-archive/generated_prompts_algae1.txt_headed_vs_generated_prompts_archa1.txt_headed_confusion_matrix.png
136
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_virus3.txt_headed_misclassified.txt
137
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_fungi3.txt_headed_bacterial.tsv
138
+ results-archive/generated_prompts_algae2.txt_headed_vs_generated_prompts_fungi2.txt_headed_report.txt
139
+ results-archive/generated_prompts_algae3.txt_headed_vs_generated_prompts_bact3.txt_headed_confusion_matrix.png
140
+ run_la4sr_loop.sbatch
141
+ run_la4sr_TI-inc-algaGPT.sh
142
+ slurm-10718799.out
143
+ tar: slurm-10718799.out: file changed as we read it
144
+ slurm-logs/
145
+ slurm-logs/arrayJob_10718778_2.err
146
+ slurm-logs/arrayJob_10718764_10.err
147
+ slurm-logs/arrayJob_10718778_8.err
148
+ slurm-logs/arrayJob_10718778_7.err
149
+ slurm-logs/arrayJob_10718764_9.out
150
+ slurm-logs/arrayJob_10718778_11.out
151
+ slurm-logs/arrayJob_10718764_4.err
152
+ slurm-logs/arrayJob_10718778_9.err
153
+ slurm-logs/arrayJob_10718764_10.out
154
+ slurm-logs/arrayJob_10718778_6.err
155
+ slurm-logs/arrayJob_10718778_1.err
156
+ slurm-logs/arrayJob_10718764_11.out
157
+ slurm-logs/arrayJob_10718778_12.out
158
+ slurm-logs/arrayJob_10718778_9.out
159
+ slurm-logs/arrayJob_10718764_11.err
160
+ slurm-logs/arrayJob_10718778_10.out
161
+ slurm-logs/arrayJob_10718764_9.err
162
+ slurm-logs/arrayJob_10718764_2.out
163
+ slurm-logs/arrayJob_10718778_2.out
164
+ slurm-logs/arrayJob_10718778_5.out
165
+ slurm-logs/arrayJob_10718764_12.out
166
+ slurm-logs/arrayJob_10718764_1.err
167
+ slurm-logs/arrayJob_10718764_8.err
168
+ slurm-logs/arrayJob_10718764_3.err
169
+ slurm-logs/arrayJob_10718778_1.out
170
+ slurm-logs/arrayJob_10718764_2.err
171
+ slurm-logs/arrayJob_10718778_10.err
172
+ slurm-logs/arrayJob_10718778_3.err
173
+ slurm-logs/arrayJob_10718778_5.err
174
+ slurm-logs/arrayJob_10718778_8.out
175
+ slurm-logs/arrayJob_10718778_4.out
176
+ slurm-logs/arrayJob_10718764_7.err
177
+ slurm-logs/arrayJob_10718778_3.out
178
+ slurm-logs/arrayJob_10718778_12.err
179
+ slurm-logs/arrayJob_10718764_3.out
180
+ slurm-logs/arrayJob_10718764_6.out
181
+ slurm-logs/arrayJob_10718764_5.err
182
+ slurm-logs/arrayJob_10718778_7.out
183
+ slurm-logs/arrayJob_10718764_7.out
184
+ slurm-logs/arrayJob_10718764_4.out
185
+ slurm-logs/arrayJob_10718764_6.err
186
+ slurm-logs/arrayJob_10718778_4.err
187
+ slurm-logs/arrayJob_10718764_5.out
188
+ slurm-logs/arrayJob_10718778_6.out
189
+ slurm-logs/arrayJob_10718764_1.out
190
+ slurm-logs/arrayJob_10718764_8.out
191
+ slurm-logs/arrayJob_10718764_12.err
192
+ slurm-logs/arrayJob_10718778_11.err
193
+ targz.sh
targz.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+
3
+ #SBATCH --mem=90GB
4
+ #SBATCH --time=96:00:00
5
+ #SBATCH --cpus-per-task=12
6
+
7
+ #tar -zcvf la4sr.tar.gz la4sr_sp* run_la4sr.sh run_la4sr.sbatch run_la4sr_loop.sbatch infer-ByT5tok-attn-fastaParser.py llm-metrics-two-files.py examples/Nannochloris_eucaryotum.AAC.fa.aa.fa examples/GCF_900016285.2_ANSES_11-930-9S.fa
8
+ tar -zcvf la4sr-TI-inc-algaGPT.tar.gz *
9
+