File size: 7,350 Bytes
fc65443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
"""
sample_fasta3.3_softmax_error_handling3e.py — batch inference for ppiGPLM.

Loads a ckpt.pt from <model_dir> (default "out") and runs inference on a file of
prompts (one prompt per line, no quotes), producing:

  - <output_prefix>_classifications.txt : FASTA-like dump of model output
  - <output_prefix>_probabilities.csv   : per-pair softmax probabilities for "1" and "0"

Robustness:
  - Block-size detection from checkpoint['model_args']['block_size'] (or
    model.config.n_positions for GPT-2 variants), falling back to 1024.
  - Input clipping: if a prompt exceeds block_size, the head is clipped
    (start_ids = start_ids[-block_size:]) so the end of the prompt, where the
    label is predicted, stays intact.
  - Unknown-token replacement: out-of-vocab characters are mapped to the
    '<unk>' token id if the vocabulary defines one, otherwise to token id 0.

Usage:
    python sample_fasta3.3_softmax_error_handling3e.py \\
        --input_file my_prompts.txt \\
        --output_dir results \\
        --output_prefix myoutput
"""

import os
import sys
import argparse
from contextlib import nullcontext
import pickle
import torch
import torch.nn.functional as F
from model import GPTConfig, GPT

# -----------------------------------------------------------------------------
# Parse command-line arguments
# -----------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Sample from a trained model with prompt input.')
parser.add_argument(
    '--input_file', type=str, default='generated_prompts.txt',
    help='Path to file containing prompts',
)
parser.add_argument(
    '--output_dir', type=str, default='out-ppi',
    help='Directory to save outputs',
)
parser.add_argument(
    '--output_prefix', type=str, default='generated',
    help='Prefix for output files',
)
args = parser.parse_args()

# argparse has consumed this script's own flags; leave only the program name in
# sys.argv so configurator.py (exec'd further down) does not see them.
sys.argv = [sys.argv[0]]

prompts_file_path = args.input_file
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

# Both output files share the <output_dir>/<output_prefix> stem.
output_stem = os.path.join(output_dir, args.output_prefix)
fasta_output_filename = output_stem + '_classifications.txt'
csv_output_filename = output_stem + '_probabilities.csv'

# -----------------------------------------------------------------------------
# Sampling parameters and model init overrides
# -----------------------------------------------------------------------------
init_from = 'resume'  # 'resume' -> load <model_dir>/ckpt.pt; anything else is passed to GPT.from_pretrained
model_dir = 'out'  # directory containing ckpt.pt
max_new_tokens = 1  # tokens sampled per prompt by model.generate
temperature = 0.1  # sampling temperature for model.generate
top_k = 2  # top-k filter for model.generate
seed = int.from_bytes(os.urandom(4), 'big')  # fresh random seed every run, so sampled outputs differ between runs
# Use the GPU when present; otherwise fall back to CPU instead of failing later in
# torch.load(map_location='cuda') / torch.tensor(device='cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
compile = False  # shadows the builtin; name kept because configurator.py presumably overrides globals by name
# NOTE(review): sys.argv was reset to just the program name earlier in this script,
# so configurator.py presumably applies no command-line overrides here - confirm.
with open('configurator.py') as configurator_file:
    exec(configurator_file.read())

torch.manual_seed(seed)
torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Mixed-precision autocast on GPU only; on CPU the model runs without autocast.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# -----------------------------------------------------------------------------
# Model Initialization
# -----------------------------------------------------------------------------
if init_from == 'resume':
    # Rebuild the network from the hyper-parameters saved alongside the weights.
    ckpt_path = os.path.join(model_dir, 'ckpt.pt')
    checkpoint = torch.load(ckpt_path, map_location=device)
    model_args = checkpoint['model_args']
    gptconf = GPTConfig(**model_args)
    model = GPT(gptconf)
    # Parameter names may carry an '_orig_mod.' prefix (as when a
    # torch.compile-wrapped module's state_dict is saved); strip it so the
    # keys match the plain GPT module.
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'
    for key in [name for name in state_dict if name.startswith(unwanted_prefix)]:
        state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    # Context length recorded in the checkpoint; 1024 if it is absent.
    block_size = model_args.get('block_size', 1024)
else:
    # Named pretrained variant, with dropout overridden to 0.0 for inference.
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
    block_size = getattr(model.config, 'n_positions', 1024)

model.eval()
model.to(device)
model = torch.compile(model) if compile else model

# -----------------------------------------------------------------------------
# Load vocabulary mapping for character-level tokenization
# -----------------------------------------------------------------------------
# The vocabulary location comes from the training config stored in the
# checkpoint, which is only loaded on the init_from == 'resume' path; fail
# clearly here instead of with a NameError on `checkpoint`.
if init_from != 'resume':
    raise ValueError(
        f"init_from={init_from!r} is not supported: the character-level vocabulary "
        "(meta.pkl) is located via the resumed checkpoint's config; use init_from='resume'"
    )
meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
if not os.path.exists(meta_path):
    raise FileNotFoundError(f'Meta file not found: {meta_path}')
# SECURITY: pickle.load can execute arbitrary code; meta.pkl must come from a trusted source.
with open(meta_path, 'rb') as meta_file:
    meta = pickle.load(meta_file)
stoi = meta['stoi']  # character -> token id
itos = meta['itos']  # token id -> character
# Out-of-vocabulary characters map to the '<unk>' id if the vocabulary has one, else to id 0.
unk_id = stoi.get('<unk>', 0)


def encode(s):
    """Return the token ids for the characters of *s* (unknown characters -> unk_id)."""
    return [stoi.get(ch, unk_id) for ch in s]


def decode(l):
    """Return the string for token ids *l*; ids missing from itos are dropped."""
    return ''.join([itos.get(i, '') for i in l])

# -----------------------------------------------------------------------------
# Read prompts
# -----------------------------------------------------------------------------
# 'utf-8-sig' drops a leading byte-order mark if present (plain UTF-8 files read
# identically). With 'utf-8' the BOM survives strip() - U+FEFF is not whitespace -
# and would be encoded as an out-of-vocabulary character at the start of the first prompt.
with open(prompts_file_path, 'r', encoding='utf-8-sig') as f:
    # One prompt per line; surrounding whitespace is trimmed and blank lines are skipped.
    prompts = [line.strip() for line in f if line.strip()]

# -----------------------------------------------------------------------------
# Token IDs for classification tokens
# -----------------------------------------------------------------------------
# Look the label characters up directly in the vocabulary. encode() always
# returns one id per character and maps unknown characters to a fallback id,
# so going through it would silently report the fallback token's probability.
# None here makes the later probability lookups report 0.0.
one_id = stoi.get('1')
zero_id = stoi.get('0')

# -----------------------------------------------------------------------------
# FASTA formatting helper
# -----------------------------------------------------------------------------
def format_as_fasta(sequence, sample_number):
    """Return a FASTA-style record for one sample.

    The record is a '>Sample_<sample_number>' header line followed by
    *sequence* on its own line; both lines are newline-terminated.
    """
    header = '>Sample_{}'.format(sample_number)
    body = '{}'.format(sequence)
    return '\n'.join((header, body)) + '\n'

# -----------------------------------------------------------------------------
# Generate outputs and write to files
# -----------------------------------------------------------------------------
def _token_prob(probs, token_id):
    """Return ``probs[0, token_id]`` as a Python float.

    Returns 0.0 when *token_id* is None or is not below the size of the last
    dimension of *probs*.
    """
    if token_id is None or token_id >= probs.shape[-1]:
        return 0.0
    return probs[0, token_id].item()


with open(fasta_output_filename, 'w', encoding='utf-8') as fasta_file, \
     open(csv_output_filename, 'w', encoding='utf-8') as csv_file:
    csv_file.write('l1,Seq1,l2,Seq2,l3,Probability_of_1,Probability_of_0\n')
    # No gradients are needed; ctx adds autocast on GPU.
    with torch.no_grad(), ctx:
        for sample_idx, prompt in enumerate(prompts):
            token_ids = encode(prompt)
            # Over-long prompts lose their head, not their tail, so the
            # position whose next token is scored is preserved.
            if len(token_ids) > block_size:
                token_ids = token_ids[-block_size:]
            x = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)

            # The forward pass may return a tuple; its first element is taken as the logits.
            output = model(x)
            logits = output[0] if isinstance(output, tuple) else output
            # Distribution over the token that would follow the prompt.
            probs = F.softmax(logits[:, -1, :], dim=-1)
            prob_for_1 = _token_prob(probs, one_id)
            prob_for_0 = _token_prob(probs, zero_id)

            # Separately sample a continuation for the FASTA dump.
            y = model.generate(idx=x, max_new_tokens=max_new_tokens, temperature=temperature, top_k=top_k)
            record = format_as_fasta(decode(y[0].tolist()), sample_idx)

            fasta_file.write(record + '\n')
            # NOTE(review): the prompt is written unquoted, so its own commas
            # presumably fill the l1/Seq1/l2/Seq2/l3 columns - confirm the
            # prompt format matches the header.
            csv_file.write(f'{prompt},{prob_for_1},{prob_for_0}\n')

            print(record)
            print(f'Probability(next_token=1) = {prob_for_1}')
            print(f'Probability(next_token=0) = {prob_for_0}')
            print('---------------')

print(f'FASTA-like samples saved to {fasta_output_filename}')
print(f'CSV with probabilities saved to {csv_output_filename}')