"""
sample_fasta3.3_softmax_error_handling3e.py — batch inference for ppiGPLM.
Loads a ckpt.pt from <model_dir> (default "out") and runs inference on a file of
prompts (one prompt per line, no quotes), producing:
- <output_prefix>_classifications.txt : FASTA-like dump of model output
- <output_prefix>_probabilities.csv : per-pair softmax probabilities for "1" and "0"
Robustness:
- Block-size detection from checkpoint['model_args']['block_size'] (or
model.config.n_positions for GPT-2 variants).
- Input clipping: if a prompt exceeds block_size, the head is clipped
(start_ids = start_ids[-block_size:]) so the label position stays intact.
- Unknown-token replacement: out-of-vocab characters are mapped to the '<unk>' token id when the vocabulary defines one, otherwise to id 0.
Usage:
python sample_fasta3.3_softmax_error_handling3e.py \\
--input_file my_prompts.txt \\
--output_dir results \\
--output_prefix myoutput
"""
import os
import sys
import argparse
from contextlib import nullcontext
import pickle
import torch
import torch.nn.functional as F
from model import GPTConfig, GPT
# -----------------------------------------------------------------------------
# Parse command-line arguments
# -----------------------------------------------------------------------------
parser = argparse.ArgumentParser(
    description='Sample from a trained model with prompt input.',
)
parser.add_argument(
    '--input_file',
    type=str,
    default='generated_prompts.txt',
    help='Path to file containing prompts',
)
parser.add_argument(
    '--output_dir',
    type=str,
    default='out-ppi',
    help='Directory to save outputs',
)
parser.add_argument(
    '--output_prefix',
    type=str,
    default='generated',
    help='Prefix for output files',
)
args = parser.parse_args()

# Strip our own flags so that only the program name is left in sys.argv
# (presumably so configurator.py does not try to interpret them — confirm there).
sys.argv = [sys.argv[0]]

prompts_file_path = args.input_file
output_dir = args.output_dir

# Create the output directory up front; an already existing directory is fine.
os.makedirs(output_dir, exist_ok=True)

# Both output files share the directory and the user-supplied prefix.
fasta_output_filename = os.path.join(output_dir, f'{args.output_prefix}_classifications.txt')
csv_output_filename = os.path.join(output_dir, f'{args.output_prefix}_probabilities.csv')
# -----------------------------------------------------------------------------
# Sampling parameters and model init overrides
# -----------------------------------------------------------------------------
init_from = 'resume'  # 'resume' loads <model_dir>/ckpt.pt; any other value goes to GPT.from_pretrained
model_dir = 'out'  # directory that holds ckpt.pt
max_new_tokens = 1  # tokens generated per prompt
temperature = 0.1  # forwarded to model.generate
top_k = 2  # forwarded to model.generate
seed = int.from_bytes(os.urandom(4), 'big')  # fresh random 32-bit seed on every run
# Use the GPU when there is one, otherwise fall back to CPU so that checkpoint
# loading and model.to(device) do not fail on CUDA-less hosts.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
compile = False  # shadows the builtin; read later to decide whether to torch.compile the model

# NOTE: exec runs configurator.py with full access to this module's globals; only
# run this script next to a trusted configurator.py.
with open('configurator.py') as configurator_file:
    exec(configurator_file.read())  # overrides from command line or config file

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Autocast is only entered for CUDA; on CPU inference runs in a no-op context.
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
# -----------------------------------------------------------------------------
# Model Initialization
# -----------------------------------------------------------------------------
if init_from == 'resume':
    # Rebuild the model from the constructor arguments stored in the checkpoint,
    # then load the saved weights into it.
    ckpt_path = os.path.join(model_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    # Strip the '_orig_mod.' prefix from parameter names so they match the plain
    # model (presumably left behind by torch.compile at training time — confirm).
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict):  # snapshot of the keys: the dict is mutated below
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    # Context length used later to clip over-long prompts; 1024 if not recorded.
    block_size = checkpoint['model_args'].get('block_size', 1024)
else:
    # Any other init_from value is handed to GPT.from_pretrained.
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
    block_size = getattr(model.config, 'n_positions', 1024)

model.eval()
model.to(device)
if compile:
    model = torch.compile(model)
# -----------------------------------------------------------------------------
# Load vocabulary mapping for character-level tokenization
# -----------------------------------------------------------------------------
# The vocabulary file is located through the dataset name stored in the resumed
# checkpoint. Without a checkpoint that name does not exist, so fail with a clear
# message instead of a NameError on `checkpoint`.
if init_from != 'resume':
    raise RuntimeError(
        "Character-level vocabulary lookup requires init_from='resume'; "
        f'no checkpoint config is available for init_from={init_from!r}'
    )
meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
if not os.path.exists(meta_path):
    raise FileNotFoundError(f'Meta file not found: {meta_path}')

# NOTE(security): pickle.load can execute arbitrary code embedded in the file;
# only point this script at a meta.pkl you trust.
with open(meta_path, 'rb') as meta_file:
    meta = pickle.load(meta_file)

stoi = meta['stoi']  # character -> token id
itos = meta['itos']  # token id -> character

# Id substituted for characters missing from the vocabulary: the '<unk>' entry
# when present, otherwise 0.
unk_id = stoi.get('<unk>', 0)


def encode(s):
    """Return the list of token ids for string *s*, one id per character."""
    return [stoi.get(ch, unk_id) for ch in s]


def decode(l):
    """Return the string for the id sequence *l*; unknown ids decode to ''."""
    return ''.join(itos.get(i, '') for i in l)
# -----------------------------------------------------------------------------
# Read prompts
# -----------------------------------------------------------------------------
# One prompt per line; surrounding whitespace is removed and blank lines are dropped.
with open(prompts_file_path, 'r', encoding='utf-8') as f:
    stripped_lines = (line.strip() for line in f)
    prompts = [line for line in stripped_lines if line]
# -----------------------------------------------------------------------------
# Token IDs for classification tokens
# -----------------------------------------------------------------------------
# Look the label characters up directly in the vocabulary. Going through encode()
# can never produce None: it substitutes the '<unk>'/0 fallback id for a missing
# character, so the reported probability would silently belong to another token.
# With None the generation loop reports 0.0 for that label instead.
one_id = stoi.get('1')
zero_id = stoi.get('0')
if one_id is None:
    print("Warning: label token '1' is not in the vocabulary", file=sys.stderr)
if zero_id is None:
    print("Warning: label token '0' is not in the vocabulary", file=sys.stderr)
# -----------------------------------------------------------------------------
# FASTA formatting helper
# -----------------------------------------------------------------------------
def format_as_fasta(sequence: str, sample_number: int) -> str:
    """Return *sequence* as a FASTA-like record.

    The record is a ``>Sample_<sample_number>`` header line followed by the
    sequence on a single line, each terminated by a newline.
    """
    return f'>Sample_{sample_number}\n{sequence}\n'
# -----------------------------------------------------------------------------
# Generate outputs and write to files
# -----------------------------------------------------------------------------
def _token_probability(probs, token_id):
    """Return probs[0, token_id] as a float, or 0.0 if the id is None or out of range."""
    if token_id is None or token_id >= probs.shape[-1]:
        return 0.0
    return probs[0, token_id].item()


with open(fasta_output_filename, 'w', encoding='utf-8') as fasta_file, \
        open(csv_output_filename, 'w', encoding='utf-8') as csv_file:
    # NOTE(review): the header names seven columns while each row is written as
    # "<prompt>,<p1>,<p0>"; this assumes the prompt itself is the comma-separated
    # "l1,Seq1,l2,Seq2,l3" part — TODO confirm against the prompt format.
    csv_file.write('l1,Seq1,l2,Seq2,l3,Probability_of_1,Probability_of_0\n')
    with torch.no_grad(), ctx:
        for sample_idx, prompt in enumerate(prompts):
            start_ids = encode(prompt)
            # Keep only the last block_size tokens of an over-long prompt.
            if len(start_ids) > block_size:
                start_ids = start_ids[-block_size:]
            x = torch.tensor(start_ids, dtype=torch.long, device=device).unsqueeze(0)

            # Forward pass for the next-token distribution; the model may return
            # either the logits alone or a tuple whose first item is the logits.
            logits = model(x)
            if isinstance(logits, tuple):
                logits = logits[0]
            last_logits = logits[:, -1, :]
            probs = F.softmax(last_logits, dim=-1)
            prob_for_1 = _token_probability(probs, one_id)
            prob_for_0 = _token_probability(probs, zero_id)

            # Separate sampling pass that produces the text written to the FASTA file.
            y = model.generate(idx=x, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
            sample_text = decode(y[0].tolist())
            fasta_record = format_as_fasta(sample_text, sample_idx)

            fasta_file.write(fasta_record + '\n')
            csv_file.write(f'{prompt},{prob_for_1},{prob_for_0}\n')

            print(fasta_record)
            print(f'Probability(next_token=1) = {prob_for_1}')
            print(f'Probability(next_token=0) = {prob_for_0}')
            print('---------------')

print(f'FASTA-like samples saved to {fasta_output_filename}')
print(f'CSV with probabilities saved to {csv_output_filename}')
|