"""
sample_fasta3.3_softmax_error_handling3e.py — batch inference for ppiGPLM.

Loads ckpt.pt from <model_dir> (default "out") and runs inference on a file of
prompts (one prompt per line, no quotes), producing:

  - <output_prefix>_classifications.txt : FASTA-like dump of the model output
  - <output_prefix>_probabilities.csv   : per-prompt softmax probabilities of
                                          the next token being "1" or "0"

Robustness:
  - Block-size detection from checkpoint['model_args']['block_size'] (or from
    the pretrained model's config for GPT-2 variants).
  - Input clipping: if a prompt exceeds block_size, its head is clipped
    (start_ids = start_ids[-block_size:]) so the label position stays intact.
  - Unknown-token replacement: out-of-vocab characters are mapped to 'A'
    (or to '<unk>' if the vocabulary defines such a token).

Usage:
  python sample_fasta3.3_softmax_error_handling3e.py \\
      --input_file my_prompts.txt \\
      --output_dir results \\
      --output_prefix myoutput
"""
|
|
| import os |
| import sys |
| import argparse |
| from contextlib import nullcontext |
| import pickle |
| import torch |
| import torch.nn.functional as F |
| from model import GPTConfig, GPT |
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Command-line interface, output paths and run configuration.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Sample from a trained model with prompt input.')
parser.add_argument('--input_file', type=str, default='generated_prompts.txt', help='Path to file containing prompts')
parser.add_argument('--output_dir', type=str, default='out-ppi', help='Directory to save outputs')
parser.add_argument('--output_prefix', type=str, default='generated', help='Prefix for output files')
parser.add_argument('--model_dir', type=str, default='out',
                    help='Directory containing ckpt.pt (used when init_from is "resume")')
args = parser.parse_args()

# configurator.py is exec'd below; presumably (nanoGPT-style) it parses
# sys.argv itself, so hide the argparse flags from it. Because of this reset,
# configurator overrides cannot come from the command line, which is why
# model_dir is exposed as a regular flag above.
sys.argv = [sys.argv[0]]

prompts_file_path = args.input_file
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

fasta_output_filename = os.path.join(output_dir, args.output_prefix + '_classifications.txt')
csv_output_filename = os.path.join(output_dir, args.output_prefix + '_probabilities.csv')

# Sampling defaults. configurator.py is exec'd in this module's namespace and
# may rebind any of these names.
init_from = 'resume'  # 'resume' -> load <model_dir>/ckpt.pt; otherwise passed to GPT.from_pretrained
model_dir = args.model_dir
max_new_tokens = 1
temperature = 0.1
top_k = 2
seed = int.from_bytes(os.urandom(4), 'big')  # fresh random seed each run (printed below)
# Fall back to the CPU instead of failing inside torch.load when no GPU is present.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
compile = False  # config flag; intentionally keeps the name configurator.py expects
with open('configurator.py', encoding='utf-8') as config_file:  # closes the handle
    exec(config_file.read())

# Print the seed so a run's sampled continuations can be reproduced.
print(f'Using seed {seed}')
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Mixed-precision autocast on GPU; no autocast on CPU.
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Model loading. Also sets block_size, the maximum number of tokens fed to the
# model; longer prompts are clipped to it during inference.
# ---------------------------------------------------------------------------
if init_from == 'resume':
    ckpt_path = os.path.join(model_dir, 'ckpt.pt')
    # Deserialize onto the CPU: the model is moved to `device` after loading,
    # so this avoids keeping a second copy of the weights (and any other
    # tensors stored in the checkpoint) resident in GPU memory.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    # pop() so the checkpoint dict stops referencing the weights once loaded;
    # only checkpoint['model_args'] / ['config'] are needed afterwards.
    state_dict = checkpoint.pop('model')
    # Strip the '_orig_mod.' key prefix (presumably added when a
    # torch.compile'd model was saved) so keys match the plain GPT module.
    unwanted_prefix = '_orig_mod.'
    for key in list(state_dict.keys()):
        if key.startswith(unwanted_prefix):
            state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
    model.load_state_dict(state_dict)
    del state_dict  # load_state_dict copied the tensors into `model`
    block_size = checkpoint['model_args'].get('block_size', 1024)
else:
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))
    # Prefer `block_size` (the field read from model_args above, presumably
    # also carried by GPTConfig), then HF-style `n_positions`, then GPT-2's
    # 1024-token context as a last resort.
    block_size = getattr(model.config, 'block_size', getattr(model.config, 'n_positions', 1024))
|
|
# Switch to inference mode, then place the weights on the target device.
# nn.Module.eval() and .to() both return the module itself.
model = model.eval().to(device)
# Optional graph compilation; `compile` is the module-level config flag
# (it shadows the builtin of the same name).
model = torch.compile(model) if compile else model
|
|
| |
| |
| |
# ---------------------------------------------------------------------------
# Character vocabulary (stoi / itos) and the encode / decode helpers.
# ---------------------------------------------------------------------------
if init_from != 'resume':
    # `checkpoint`, which names the training dataset, only exists on the
    # 'resume' path; fail with a clear message instead of a NameError.
    raise RuntimeError(
        f"init_from={init_from!r} is not supported: the character vocabulary "
        "(meta.pkl) is located via the resumed checkpoint's dataset name, so "
        "init_from must be 'resume'."
    )
meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
if not os.path.exists(meta_path):
    raise FileNotFoundError(f'Meta file not found: {meta_path}')
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at trusted data directories.
with open(meta_path, 'rb') as meta_file:
    meta = pickle.load(meta_file)
stoi = meta['stoi']  # character -> token id
itos = meta['itos']  # token id -> character

# Fallback id for out-of-vocabulary characters: an explicit '<unk>' entry if
# the vocabulary has one, otherwise 'A', and id 0 only as a last resort.
# A bare id 0 would substitute whichever character happens to sit at index 0
# (which could be a separator such as ',' or a newline).
unk_id = stoi.get('<unk>', stoi.get('A', 0))


def encode(s):
    """Map each character of `s` to its token id, using `unk_id` for unknown characters."""
    return [stoi.get(ch, unk_id) for ch in s]


def decode(ids):
    """Map token ids back to characters, dropping ids that are missing from `itos`."""
    return ''.join(itos.get(i, '') for i in ids)
|
|
| |
| |
| |
# Read one prompt per line, skipping blank lines. 'utf-8-sig' also drops a
# leading byte-order mark (e.g. from files saved by Windows editors). str.strip()
# does not remove U+FEFF, so without this the BOM would be encoded as an unknown
# character at the start of the first prompt.
with open(prompts_file_path, 'r', encoding='utf-8-sig') as f:
    prompts = [p for p in (line.strip() for line in f) if p]
|
|
| |
| |
| |
# Token ids of the two class labels, or None if a label character is not in
# the vocabulary. Look them up in `stoi` directly instead of via encode():
# encode() substitutes a fallback id for unknown characters, so it never returns
# an empty list for a one-character string. A missing label would then silently
# report the probability of an unrelated token.
one_id = stoi.get('1')
zero_id = stoi.get('0')
for label, label_id in (('1', one_id), ('0', zero_id)):
    if label_id is None:
        print(f"Warning: label {label!r} is not in the vocabulary; its probability "
              "will be reported as 0.0", file=sys.stderr)
|
|
| |
| |
| |
def format_as_fasta(sequence, sample_number):
    """Build one FASTA-style record for a generated sample.

    The record is a '>Sample_<sample_number>' header line followed by
    `sequence`; both lines end with a newline.
    """
    template = '>Sample_{}\n{}\n'
    return template.format(sample_number, sequence)
|
|
| |
| |
| |
def _label_probability(probs, token_id):
    """Return ``probs[0, token_id]`` as a Python float.

    Returns 0.0 when ``token_id`` is None or is not a valid index into the
    last axis of ``probs``.
    """
    if token_id is None or token_id >= probs.shape[-1]:
        return 0.0
    return probs[0, token_id].item()


# ---------------------------------------------------------------------------
# Inference: for each prompt, record P(next token = '1' / '0') and one sampled
# continuation.
# ---------------------------------------------------------------------------
with open(fasta_output_filename, 'w', encoding='utf-8') as fasta_file, \
        open(csv_output_filename, 'w', encoding='utf-8') as csv_file:
    # NOTE(review): rows are written as `<prompt>,<p1>,<p0>`, so each prompt
    # presumably supplies the l1,Seq1,l2,Seq2,l3 columns itself -- confirm.
    csv_file.write('l1,Seq1,l2,Seq2,l3,Probability_of_1,Probability_of_0\n')
    with torch.no_grad(), ctx:
        for sample_idx, prompt in enumerate(prompts):
            # Keep at most the last block_size tokens (a no-op for short
            # prompts) so the final position, whose logits are read below,
            # is never clipped away.
            token_ids = encode(prompt)[-block_size:]
            x = torch.tensor([token_ids], dtype=torch.long, device=device)  # batch of one

            output = model(x)
            logits = output[0] if isinstance(output, tuple) else output
            probs = F.softmax(logits[:, -1, :], dim=-1)  # next-token distribution
            p1 = _label_probability(probs, one_id)
            p0 = _label_probability(probs, zero_id)

            generated = model.generate(idx=x, max_new_tokens=max_new_tokens,
                                       temperature=temperature, top_k=top_k)
            record = format_as_fasta(decode(generated[0].tolist()), sample_idx)

            fasta_file.write(record + '\n')  # extra newline leaves a blank line between records
            csv_file.write(f'{prompt},{p1},{p0}\n')

            print(record)
            print(f'Probability(next_token=1) = {p1}')
            print(f'Probability(next_token=0) = {p0}')
            print('---------------')

print(f'FASTA-like samples saved to {fasta_output_filename}')
print(f'CSV with probabilities saved to {csv_output_filename}')
|
|