Spaces:
Running
Running
| """ | |
| Copyright (c) 2023, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import logging | |
| import os | |
| import torch | |
| import numpy as np | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from collections.abc import Mapping | |
| from torch.cuda.amp import autocast as autocast | |
| from lavis.models.blip2_models.blip2 import disabled_train | |
| from model.blip2 import Blip2Base | |
| from transformers import AutoTokenizer | |
| from transformers import OPTForCausalLM | |
| from opendelta import LoraModel | |
| from torch.cuda.amp import autocast | |
| from opendelta.delta_models.lora import LoraConfig as DeltaLoraConfig | |
| from transformers import BertTokenizer, BitsAndBytesConfig | |
| from transformers.tokenization_utils_base import BatchEncoding | |
| from model.help_funcs import hf_enable_gradient_checkpointing | |
# Hugging Face model identifiers of the Galactica checkpoints (125M to 30B parameters).
# NOTE(review): presumably the LLMs that can be loaded through the OPT classes;
# the list is not referenced in the visible part of this file - confirm usage.
opt_model_list = [
    "facebook/galactica-125m",
    "facebook/galactica-1.3b",
    "facebook/galactica-6.7b",
    "facebook/galactica-30b",
]
def get_gpu_memory(device=0):
    """Report the memory state of a CUDA device in GiB.

    Args:
        device: Device passed to ``torch.cuda.mem_get_info``.

    Returns:
        Tuple ``(free, used, total)`` where ``used`` is ``total - free``.
    """
    bytes_per_gib = 1024 ** 3
    free_bytes, total_bytes = torch.cuda.mem_get_info(device)
    free_gib = free_bytes / bytes_per_gib
    total_gib = total_bytes / bytes_per_gib
    return free_gib, total_gib - free_gib, total_gib
# Reliability head: the four discrete reliability labels and their class indices.
# Label -> class index: {-1.0: 0, 0.0: 1, 0.5: 2, 1.0: 3}
RELIABILITY_CLASSES = [-1.0, 0.0, 0.5, 1.0]
RELIABILITY_NUM_CLASSES = len(RELIABILITY_CLASSES)
RELIABILITY_VAL_TO_IDX = dict(zip(RELIABILITY_CLASSES, range(RELIABILITY_NUM_CLASSES)))
def _r_value_to_class(r_tensor):
    """Map reliability values to their class indices.

    Each element is compared (absolute tolerance 1e-4) against every known
    reliability value; elements that match none of them keep index 0.
    """
    labels = torch.zeros_like(r_tensor, dtype=torch.long)
    for class_value, class_idx in RELIABILITY_VAL_TO_IDX.items():
        # Reference scalar with the same dtype/device as the input.
        reference = r_tensor.new_tensor(class_value)
        hits = torch.isclose(r_tensor, reference, atol=1e-4)
        labels = labels.masked_fill(hits, class_idx)
    return labels
def _class_to_r_value(class_indices):
    """Look up the float32 reliability value for each class index."""
    value_table = torch.tensor(
        RELIABILITY_CLASSES,
        dtype=torch.float32,
        device=class_indices.device,
    )
    r_values = value_table[class_indices]
    return r_values
| def _r_value_to_binary_class(r_tensor): | |
| """Binary mapping: 1 if r ~= 1.0 (positive class), else 0.""" | |
| pos = torch.isclose(r_tensor, torch.tensor(1.0, device=r_tensor.device, dtype=r_tensor.dtype), atol=1e-4) | |
| return pos.long() | |
def _scan_r_counts(json_path):
    """Count reliability labels in a JSON-lines file.

    Every non-blank line is parsed as JSON and its element at index 2 is read
    as the reliability value.

    Returns:
        ``(counts, total)``: ``counts`` is a float32 tensor indexed by
        RELIABILITY_VAL_TO_IDX, ``total`` the number of non-blank rows. Rows
        whose value matches no known class (tolerance 1e-4) are included in
        ``total`` but in none of the per-class counts. A missing or empty
        path yields all-zero counts and a total of 0.
    """
    import json as _json

    per_class = torch.zeros(RELIABILITY_NUM_CLASSES, dtype=torch.float32)
    if not (json_path and os.path.isfile(json_path)):
        return per_class, 0

    n_rows = 0
    with open(json_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            record = raw_line.strip()
            if record:
                value = float(_json.loads(record)[2])
                n_rows += 1
                # First class whose value lies within tolerance, if any.
                matched = next(
                    (idx for cls_val, idx in RELIABILITY_VAL_TO_IDX.items()
                     if abs(value - cls_val) < 1e-4),
                    None,
                )
                if matched is not None:
                    per_class[matched] += 1
    return per_class, n_rows
| def _inverse_freq_from_counts(counts): | |
| """Inverse-frequency weights w_k = N / (K_present * n_k), mean weight over | |
| present classes = 1. Empty classes get weight 0.""" | |
| if counts.sum() == 0: | |
| return None | |
| nonzero = counts > 0 | |
| weights = torch.zeros_like(counts) | |
| weights[nonzero] = counts.sum() / (nonzero.sum().float() * counts[nonzero]) | |
| return weights | |
def _compute_inverse_freq_class_weights(json_path):
    """Build 4-class inverse-frequency weights from a JSON-lines file.

    Returns None when the file is missing or contains no rows; otherwise the
    result of ``_inverse_freq_from_counts`` on the per-class counts.
    """
    per_class, n_rows = _scan_r_counts(json_path)
    return _inverse_freq_from_counts(per_class) if n_rows != 0 else None
def _compute_binary_inverse_freq_class_weights(json_path):
    """Binary inverse-frequency weights from a JSON-lines file.

    Index 0 = negative (r != 1), index 1 = positive (r == 1).

    Returns:
        A float32 tensor of two weights, or None when the file is missing or
        contains no rows.
    """
    counts, total = _scan_r_counts(json_path)
    if total == 0:
        return None
    pos = float(counts[RELIABILITY_VAL_TO_IDX[1.0]])
    # Every row that is not positive is a negative, including rows whose r
    # matches none of the four known classes (those are counted in `total`
    # but not in `counts`). This keeps the weights consistent with
    # `_r_value_to_binary_class`, which labels any r != 1 as class 0.
    neg = float(total) - pos
    binary_counts = torch.tensor([neg, pos], dtype=torch.float32)
    return _inverse_freq_from_counts(binary_counts)
def mask_by_len(input, lens, fill_value=0):
    """Overwrite the leading ``lens[i]`` entries of each row with ``fill_value``.

    The tensor is modified in place and also returned.

    input: shape = [N, D]
    lens: shape = [N]
    """
    n_cols = input.shape[1]
    col_idx = torch.arange(n_cols, device=input.device).unsqueeze(0)
    # True for positions whose column index is below the row's length.
    within_len = col_idx < lens.reshape(-1, 1)
    input[within_len] = fill_value
    return input
class Blip2OPT(Blip2Base):
    """
    BLIP2 model for protein function prediction with Q-former (generation only).

    A protein language model (PLM) encodes the sequence, a Q-Former compresses
    the residue embeddings into a fixed number of query tokens, and a linear
    projection maps them into the LLM embedding space where they are prepended
    to the text tokens. A small classification head on the pooled Q-Former and
    PLM features predicts a reliability label (4-class or binary).
    """

    def __init__(
        self,
        bert_name,
        num_query_token=32,
        cross_attention_freq=2,
        plm_model="facebook/esm2_t30_150M_UR50D",
        plm_tune='freeze',
        llm_name="facebook/galactica-1.3b",
        llm_tune='freeze',
        peft_dir='',
        args=None,
    ):
        """Build the encoder, Q-Former, LLM and reliability head.

        Args:
            bert_name: Passed to ``init_Qformer``.
            num_query_token: Number of Q-Former query tokens.
            cross_attention_freq: Passed to ``init_Qformer``.
            plm_model: Protein encoder name passed to ``init_protein_encoder``.
            plm_tune: 'freeze' or 'lora'.
            llm_name: Model name passed to ``load_llm``.
            llm_tune: 'freeze', 'full', 'lora', 'mid_lora' or 'peft_lora'.
            peft_dir: Not used in this constructor.
            args: Namespace-like configuration. Despite the ``None`` default it
                is required: do_sample, num_beams, max_inference_len,
                min_inference_len, enbale_gradient_checkpointing and
                plm_lora_{r,alpha,dropout} are read unconditionally;
                lora_{r,alpha,dropout} for the LoRA LLM modes; id2go,
                reliability_binary, train_reliability_head_only and
                reliability_finetune_data are optional.

        Raises:
            NotImplementedError: For an unknown ``plm_tune`` or ``llm_tune``.
        """
        super().__init__()
        self.args = args
        self.id2go = getattr(args, 'id2go', {})
        self.do_sample = args.do_sample
        self.num_beams = args.num_beams
        self.max_inference_len = args.max_inference_len
        self.min_inference_len = args.min_inference_len
        # The attribute/argument spelling ("enbale") is kept for compatibility.
        self.enbale_gradient_checkpointing = args.enbale_gradient_checkpointing
        self.plm_lora_r = args.plm_lora_r
        self.plm_lora_alpha = args.plm_lora_alpha
        self.plm_lora_dropout = args.plm_lora_dropout

        # Detect encoder type based on model name
        self.plm_model = plm_model
        self.is_esmc = str(plm_model).startswith('esmc_')
        self.plm_tokenizer, self.plm, self.ln_layer = self.init_protein_encoder(plm_model)
        self.plm_tune = plm_tune
        if plm_tune == 'freeze':
            for name, param in self.plm.named_parameters():
                param.requires_grad = False
            self.plm = self.plm.eval()
            # Replace train() so later .train() calls do not leave eval mode
            # (presumably what disabled_train does - defined outside this file).
            self.plm.train = disabled_train
            logging.info("freeze plm encoder")
        elif plm_tune == 'lora':
            plm_cfg = DeltaLoraConfig(self.plm_lora_r,
                                      self.plm_lora_alpha,
                                      self.plm_lora_dropout,
                                      modified_modules=["attn.layernorm_qkv.1", "ffn.1"])
            self.plm_delta = LoraModel.from_config(plm_cfg, self.plm)
            self.plm_delta.freeze_module(set_state_dict=False)
            self.plm_delta.log()
        else:
            raise NotImplementedError()

        self.num_query_token = num_query_token
        _, self.Qformer, self.query_tokens = self.init_Qformer(
            bert_name, num_query_token, self.plm.num_features, cross_attention_freq
        )
        # Drop the Q-Former parts that are not needed when only query
        # embeddings are fed through it.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        # NOTE(review): the layer norm and every parameter reachable through
        # self.plm.parameters() are frozen unconditionally here, i.e. also
        # after the 'lora' setup above - confirm this is intended.
        for p in self.ln_layer.parameters():
            p.requires_grad = False
        for p in self.plm.parameters():
            p.requires_grad = False

        self.llm_model, self.llm_tokenizer = self.load_llm(
            llm_name,
            load_4bit=False,
            enable_gradient_checkpointing=self.enbale_gradient_checkpointing,
        )
        if llm_tune == 'freeze':
            for name, param in self.llm_model.named_parameters():
                param.requires_grad = False
        elif llm_tune == 'full':
            for name, param in self.llm_model.named_parameters():
                param.requires_grad = True
        elif llm_tune == 'lora':
            lora_config = DeltaLoraConfig(args.lora_r,
                                          args.lora_alpha,
                                          args.lora_dropout,
                                          )
            self.lm_delta = LoraModel.from_config(lora_config, self.llm_model)
            self.lm_delta.freeze_module(set_state_dict=False)
            self.lm_delta.log()
        elif llm_tune == 'mid_lora':
            lora_config = DeltaLoraConfig(
                args.lora_r, args.lora_alpha, args.lora_dropout,
                modified_modules=["q_proj", "v_proj", 'k_proj', "out_proj", "fc1", "fc2"],
            )
            self.delta = LoraModel.from_config(lora_config, self.llm_model)
            self.delta.freeze_module(set_state_dict=False)
            self.delta.log()
        elif llm_tune == 'peft_lora':
            # NOTE(review): PeftLoraConfig, get_peft_model and LoraLayer are not
            # imported in the visible import block; this branch raises NameError
            # unless they are provided elsewhere.
            config = PeftLoraConfig(
                r=args.lora_r,
                lora_alpha=args.lora_alpha,
                # target_modules=modules,
                lora_dropout=args.lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.llm_model = get_peft_model(self.llm_model, config)
            # dtype policy: LoRA layers in bf16, norm layers in fp32,
            # fp32 lm_head / embed_tokens weights cast to bf16.
            for name, module in self.llm_model.named_modules():
                if isinstance(module, LoraLayer):
                    module = module.to(torch.bfloat16)
                if 'norm' in name:
                    module = module.to(torch.float32)
                if 'lm_head' in name or 'embed_tokens' in name:
                    if hasattr(module, 'weight'):
                        if module.weight.dtype == torch.float32:
                            module = module.to(torch.bfloat16)
        else:
            raise NotImplementedError()

        # Generation stops at the first token id of a newline.
        self.eos_token_id = self.llm_tokenizer(
            "\n", add_special_tokens=False
        ).input_ids[0]
        self.opt_proj = nn.Linear(self.Qformer.config.hidden_size, self.llm_model.config.hidden_size)

        # Reliability head can be either 4-class ({-1, 0, 0.5, 1}) or binary (positive=r==1, else negative).
        self.reliability_binary = bool(getattr(args, 'reliability_binary', False))
        head_out_dim = 2 if self.reliability_binary else RELIABILITY_NUM_CLASSES
        # Input = pooled Q-Former features concatenated with pooled PLM features.
        reliability_input_dim = self.Qformer.config.hidden_size + self.plm.num_features
        self.reliability_head = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(reliability_input_dim, head_out_dim),
        )
        self.reliability_head = self.reliability_head.to(torch.float32)

        # Class weights for reliability cross-entropy. Computed from the reliability
        # finetune JSON when train_reliability_head_only is set; uniform otherwise.
        cw = None
        if getattr(args, 'train_reliability_head_only', False):
            json_path = getattr(args, 'reliability_finetune_data', '')
            if self.reliability_binary:
                cw = _compute_binary_inverse_freq_class_weights(json_path)
            else:
                cw = _compute_inverse_freq_class_weights(json_path)
            if cw is not None:
                mode_tag = 'binary' if self.reliability_binary else '4-class'
                logging.info(f"reliability {mode_tag} inverse-freq class weights: {cw.tolist()}")
        if cw is None:
            cw = torch.ones(head_out_dim, dtype=torch.float32)
        self.register_buffer('reliability_class_weights', cw)

    def load_llm(self, llm_model, load_4bit=False, enable_gradient_checkpointing=True):
        """Load the OPT causal LM and its tokenizer.

        A ``<pad>`` token is added to the tokenizer and the model's token
        embeddings are resized to the tokenizer length.

        Args:
            llm_model: Model name or path for ``from_pretrained``.
            load_4bit: Load with 4-bit (nf4) bitsandbytes quantisation.
            enable_gradient_checkpointing: Applied on the non-quantised path;
                the 4-bit path always enables checkpointing.

        Returns:
            ``(llm_model, llm_tokenizer)``.
        """
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False, padding_side='right')
        llm_tokenizer.add_special_tokens({'pad_token': '<pad>'})
        if load_4bit:
            quant_config = BitsAndBytesConfig(
                load_in_4bit=True,
                load_in_8bit=False,
                llm_int8_threshold=6.0,
                llm_int8_has_fp16_weight=False,
                bnb_4bit_compute_dtype=torch.bfloat16,
                bnb_4bit_use_double_quant=True,
                bnb_4bit_quant_type='nf4',
            )
            # CUDA_VISIBLE_DEVICES lists physical ids (or UUIDs), but torch and
            # device_map address the visible GPUs by logical index 0..n-1, so
            # only the number of listed devices is taken from the variable.
            visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0')
            num_listed = len([d for d in visible_devices.split(',') if d.strip()])
            used_memory = get_gpu_memory(0)[1]
            # Put the LLM on the second visible GPU when the first one already
            # holds more than 1 GiB and a second one is listed.
            if used_memory > 1 and num_listed > 1:
                device_map = {"": 1}
            else:
                device_map = {"": 0}
            # The quantisation mode is carried by quantization_config alone;
            # also passing load_in_4bit/load_in_8bit is rejected by Transformers.
            llm_model = OPTForCausalLM.from_pretrained(
                llm_model,
                quantization_config=quant_config,
                device_map=device_map,
                torch_dtype=torch.bfloat16,
            )
            llm_model.resize_token_embeddings(len(llm_tokenizer))
            # NOTE(review): prepare_model_for_kbit_training is not imported in
            # the visible import block; this call raises NameError unless the
            # name is provided elsewhere.
            llm_model = prepare_model_for_kbit_training(llm_model, use_gradient_checkpointing=True)
        else:
            llm_model = OPTForCausalLM.from_pretrained(llm_model, torch_dtype=torch.bfloat16)
            llm_model.resize_token_embeddings(len(llm_tokenizer))
            if enable_gradient_checkpointing:
                llm_model = hf_enable_gradient_checkpointing(llm_model)
        return llm_model, llm_tokenizer

    def _normalize_batch(self, batch):
        """Normalize batch to dict format.

        A mapping is returned unchanged; from a list/tuple the first mapping
        containing ``"input_ids"`` is returned.

        Raises:
            TypeError: If no suitable mapping is found.
        """
        if isinstance(batch, (Mapping, BatchEncoding)):
            return batch
        if isinstance(batch, (list, tuple)):
            for x in batch:
                if isinstance(x, (Mapping, BatchEncoding)) and "input_ids" in x:
                    return x
        raise TypeError(f"Batch must be a mapping with input_ids; got {type(batch)}")

    def prot_encode(self, prot_batch):
        """Encode protein sequence to embeddings.

        ``input_ids`` may be a tensor, a numpy array, or a (nested) list of
        token ids; lists are padded/truncated to ``self.prot_max_len`` when
        that attribute is set, otherwise to the longest sequence.

        Returns:
            ``(embeddings, attention_mask)`` where the embeddings are the PLM
            outputs passed through ``self.ln_layer``.

        Raises:
            TypeError: For an unsupported ``input_ids`` type.
        """
        prot_batch = self._normalize_batch(prot_batch)
        device = next(self.plm.parameters()).device
        raw_ids = prot_batch.get("input_ids")
        raw_mask = prot_batch.get("attention_mask", None)
        # Fall back to 1 both when the tokenizer lacks the attribute and when
        # it is present but None.
        pad_id = getattr(self.plm_tokenizer, "pad_token_id", None)
        if pad_id is None:
            pad_id = 1

        if torch.is_tensor(raw_ids):
            input_ids = raw_ids.to(device)
            attention_mask = raw_mask.to(device) if torch.is_tensor(raw_mask) else None
        elif isinstance(raw_ids, np.ndarray):
            input_ids = torch.as_tensor(raw_ids, dtype=torch.long, device=device)
            attention_mask = (
                torch.as_tensor(raw_mask, dtype=torch.long, device=device)
                if raw_mask is not None else None
            )
        elif isinstance(raw_ids, list):
            # A flat (or empty) list is treated as a single sequence.
            ids_list = [raw_ids] if (not raw_ids or not isinstance(raw_ids[0], list)) else raw_ids
            Lfix = getattr(self, "prot_max_len", None) or max(len(x) for x in ids_list)
            padded, masks = [], []
            for ids in ids_list:
                curr_len = len(ids)
                ids = ids[:Lfix] + [pad_id] * max(0, Lfix - curr_len)
                padded.append(ids)
                masks.append([1] * min(curr_len, Lfix) + [0] * max(0, Lfix - curr_len))
            input_ids = torch.tensor(padded, dtype=torch.long, device=device)
            attention_mask = torch.tensor(masks, dtype=torch.long, device=device)
        else:
            raise TypeError(f"Unsupported input_ids type: {type(raw_ids)}")

        if attention_mask is None:
            # No usable mask supplied: derive it from the padding id.
            attention_mask = (input_ids != pad_id).long()

        out = self.plm(input_ids=input_ids, attention_mask=attention_mask)
        prot_embeds = out.last_hidden_state if hasattr(out, "last_hidden_state") else out[0]
        if self.plm_tune == "freeze":
            prot_embeds = prot_embeds.detach()
        # Match the layer norm's parameter dtype before applying it.
        ln_dtype = next(self.ln_layer.parameters()).dtype
        if prot_embeds.dtype != ln_dtype:
            prot_embeds = prot_embeds.to(ln_dtype)
        return self.ln_layer(prot_embeds), attention_mask

    def prot_qformer(self, prot_embeds, attention_mask):
        """Apply Q-former to protein embeddings.

        Returns:
            Tuple ``(prot_tokens, prot_feats, qformer_feats, plm_pooled)``:
            the projected query tokens, their L2-normalised mean, the detached
            mean of the raw Q-Former outputs, and the detached mask-weighted
            mean of the PLM embeddings.
        """
        B = prot_embeds.size(0)
        query_tokens = self.query_tokens.expand(B, -1, -1)
        q_out = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=prot_embeds,
            encoder_attention_mask=attention_mask,
            return_dict=True,
        )
        qformer_raw = q_out.last_hidden_state  # [B, num_query, qformer_hidden]
        prot_tokens = self.opt_proj(qformer_raw)
        prot_feats = F.normalize(prot_tokens.mean(dim=1), dim=-1, p=2)
        # Mean-pooled Q-Former features before projection (for reliability head)
        qformer_feats = qformer_raw.mean(dim=1).detach()
        # Mean-pooled PLM embeddings (masked) for reliability head; the clamp
        # avoids division by zero for an all-padding row.
        plm_pooled = (
            (prot_embeds * attention_mask.unsqueeze(-1)).sum(dim=1)
            / attention_mask.sum(dim=1, keepdim=True).clamp(min=1)
        )
        plm_pooled = plm_pooled.detach()
        return prot_tokens, prot_feats, qformer_feats, plm_pooled

    def _extract_reliability_features(self, last_hidden, attention_mask, prot_feats, protein_token_count):
        """
        Extract features for reliability prediction: masked mean of LLM last hidden layer.

        Args:
            last_hidden: Hidden states from LLM [B, seq_len, hidden_size]
            attention_mask: Attention mask for valid tokens [B, seq_len]
            prot_feats: (unused, kept for API compatibility)
            protein_token_count: (unused, kept for API compatibility)
        Returns:
            Detached float32 mean-pooled feature tensor [B, hidden_size]
        """
        last_hidden_fp32 = last_hidden.to(torch.float32)
        mask_expanded = attention_mask.unsqueeze(-1).to(torch.float32)
        hidden_masked = last_hidden_fp32 * mask_expanded
        # Clamp so an all-zero mask does not divide by zero.
        seq_lengths = mask_expanded.sum(dim=1).clamp(min=1)
        mean_pool = hidden_masked.sum(dim=1) / seq_lengths
        return mean_pool.detach()

    def forward(self, batch, return_pred=False):
        """Forward pass for training (generation + reliability classification).

        Args:
            batch: Tuple ``(prot_batch, text_batch, prompt_batch, r_tensor)``.
            return_pred: Also run ``generate`` and return its predictions.

        Returns:
            ``(lm_loss, r_loss)``, or
            ``(lm_loss, r_loss, pred_texts, r_pred)`` when ``return_pred``.
        """
        prot_batch, text_batch, prompt_batch, r_tensor = batch
        samples = {'prot_batch': prot_batch, 'prompt_batch': prompt_batch, 'reliability': r_tensor}
        prot_embeds, prot_attn = self.prot_encode(prot_batch)
        prot_tokens, _, qformer_feats, plm_pooled = self.prot_qformer(prot_embeds, prot_attn)
        device = prot_embeds.device

        text_batch = self._normalize_batch(text_batch)
        input_ids = text_batch["input_ids"].to(device)
        attention_mask = text_batch["attention_mask"].to(device)
        token_type_ids = text_batch.get("token_type_ids")

        # Padding never contributes to the LM loss; positions with
        # token_type_id 0 are excluded as well when type ids are provided.
        targets = input_ids.masked_fill(input_ids == self.llm_tokenizer.pad_token_id, -100)
        if token_type_ids is not None:
            targets = targets.masked_fill(token_type_ids.to(device) == 0, -100)

        B, Q = prot_tokens.shape[:2]
        prot_mask = torch.ones((B, Q), dtype=attention_mask.dtype, device=device)
        # The protein query tokens are inputs only, never prediction targets.
        prot_empty_targets = torch.full((B, Q), -100, dtype=torch.long, device=device)

        text_embeds = self.llm_model.get_input_embeddings()(input_ids)
        if prot_tokens.dtype != text_embeds.dtype:
            prot_tokens = prot_tokens.to(text_embeds.dtype)
        inputs_embeds = torch.cat([prot_tokens, text_embeds], dim=1)
        full_attn = torch.cat([prot_mask, attention_mask], dim=1)
        targets = torch.cat([prot_empty_targets, targets], dim=1)

        # Hidden states are not requested: nothing in this method reads them,
        # and keeping every layer's activations alive only costs memory.
        outputs = self.llm_model(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attn,
            return_dict=True,
            labels=targets,
        )

        # The reliability head runs in its own (float32) dtype outside
        # autocast; its inputs are detached, so r_loss only trains the head.
        with autocast(enabled=False):
            head_param = next(self.reliability_head.parameters())
            qf = qformer_feats.to(device=head_param.device, dtype=head_param.dtype)
            pf = plm_pooled.to(device=head_param.device, dtype=head_param.dtype)
            r_input = torch.cat([qf, pf], dim=-1)
            r_logits = self.reliability_head(r_input)  # [B, NUM_CLASSES] or [B, 2]
            r_target = r_tensor.to(device=r_logits.device, dtype=r_logits.dtype).view(-1)
            if self.reliability_binary:
                r_class_target = _r_value_to_binary_class(r_target).to(r_logits.device)
            else:
                r_class_target = _r_value_to_class(r_target).to(r_logits.device)
            cw = self.reliability_class_weights.to(device=r_logits.device, dtype=r_logits.dtype)
            r_loss = nn.functional.cross_entropy(r_logits, r_class_target, weight=cw)

        if return_pred:
            with torch.no_grad():
                pred_texts, r_pred, conf, _, _prob = self.generate(
                    samples,
                    do_sample=self.do_sample,
                    num_beams=self.num_beams,
                    max_length=self.max_inference_len,
                    min_length=self.min_inference_len
                )
            return outputs.loss, r_loss, pred_texts, r_pred
        return outputs.loss, r_loss

    def generate(
        self,
        samples,
        do_sample=False,
        num_beams=3,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.3,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        """
        Generate text for a batch of proteins and predict their reliability.

        Args:
            samples (dict): Must contain
                - prot_batch: tokenised protein sequences (see ``prot_encode``)
                - prompt_batch: tokenised prompts exposing ``input_ids`` and
                  ``attention_mask`` attributes
            do_sample (bool): Sample instead of greedy/beam decoding.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            length_penalty (float): Length penalty passed to the LLM's generate.
            num_captions (int): Number of sequences returned per protein.
            temperature (float): Sampling temperature.
        Returns:
            Tuple ``(texts, r_pred, conf, emb_out, r_probs)``: decoded texts,
            predicted reliability values, per-sequence confidence (exp of the
            mean transition score), a dict of fp16 CPU embeddings, and the
            reliability class probabilities.
        """
        prot_batch = samples['prot_batch']
        prompt_batch = samples['prompt_batch']

        # Encode protein sequence
        prot_embeds, prot_attn = self.prot_encode(prot_batch)
        prot_tokens, prot_feats, qformer_feats, plm_pooled = self.prot_qformer(prot_embeds, prot_attn)
        device = prot_embeds.device

        # Move the prompt tensors to the model device, as forward() does for
        # its text batch (no-op when they are already there).
        prompt_ids = prompt_batch.input_ids.to(device)
        prompt_attn = prompt_batch.attention_mask.to(device)

        prompt_embeds = self.llm_model.get_input_embeddings()(prompt_ids)
        if prot_tokens.dtype != prompt_embeds.dtype:
            prot_tokens = prot_tokens.to(prompt_embeds.dtype)
        inputs_embeds = torch.cat((prot_tokens, prompt_embeds), dim=1)
        prot_mask = torch.ones(prot_tokens.shape[:2], dtype=prompt_attn.dtype, device=prompt_embeds.device)
        attention_mask = torch.cat([prot_mask, prompt_attn], dim=1)

        gen_out = self.llm_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            eos_token_id=self.eos_token_id,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
            use_cache=True,
            return_dict_in_generate=True,
            output_scores=True,
            output_hidden_states=True
        )
        texts = self.llm_tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True)
        texts = [t.strip() for t in texts]

        B_gen = prot_tokens.shape[0]
        generated_token_ids = gen_out.sequences
        if generated_token_ids.shape[1] > 0 and (generated_token_ids != self.llm_tokenizer.pad_token_id).any():
            # Re-run the LLM on protein + prompt + generated tokens to obtain
            # the last-layer hidden states of the full sequence.
            generated_embeds = self.llm_model.get_input_embeddings()(generated_token_ids)
            full_embeds = torch.cat([prot_tokens, prompt_embeds, generated_embeds], dim=1)
            gen_mask = torch.ones((B_gen, generated_token_ids.shape[1]),
                                  dtype=attention_mask.dtype, device=device)
            full_attention_mask = torch.cat([attention_mask, gen_mask], dim=1)
            with torch.no_grad():
                outputs = self.llm_model(
                    inputs_embeds=full_embeds,
                    attention_mask=full_attention_mask,
                    output_hidden_states=True,
                    return_dict=True
                )
            last_hidden = outputs.hidden_states[-1]
            with autocast(enabled=False):
                head_param = next(self.reliability_head.parameters())
                qf = qformer_feats.to(device=head_param.device, dtype=head_param.dtype)
                pf = plm_pooled.to(device=head_param.device, dtype=head_param.dtype)
                r_input = torch.cat([qf, pf], dim=-1)
                r_logits = self.reliability_head(r_input)  # [B, K]
                r_probs = torch.softmax(r_logits, dim=-1)  # [B, K]
                pred_idx = r_logits.argmax(dim=-1)
                if self.reliability_binary:
                    r_pred = pred_idx.to(torch.float32)  # 0.0 (neg) or 1.0 (pos)
                else:
                    r_pred = _class_to_r_value(pred_idx)
            # LLM last layer embedding: mean over sequence (protein + prompt + generated)
            llm_last_mean = last_hidden.to(torch.float32).mean(dim=1).detach().to(torch.float16).cpu()
        else:
            # Nothing but padding was generated: fall back to a fixed
            # prediction (0.0 in binary mode, 0.5 otherwise) and zero probabilities.
            r_pred = torch.ones(B_gen, device=device) * (0.0 if self.reliability_binary else 0.5)
            r_probs = torch.zeros(B_gen, 2 if self.reliability_binary else RELIABILITY_NUM_CLASSES, device=device)
            # LLM last layer embedding: forward on input only (protein + prompt), then mean over sequence
            with torch.no_grad():
                outputs = self.llm_model(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    return_dict=True
                )
            last_hidden = outputs.hidden_states[-1]
            llm_last_mean = last_hidden.to(torch.float32).mean(dim=1).detach().to(torch.float16).cpu()

        # Confidence = exp(mean per-token log-probability), clamped to (0, 1].
        ts = self.llm_model.compute_transition_scores(
            gen_out.sequences, gen_out.scores,
            beam_indices=getattr(gen_out, "beam_indices", None),
            normalize_logits=True
        )
        conf = torch.exp(ts.to(torch.float32).mean(dim=1)).clamp(1e-9, 1.0).tolist()

        # NOTE(review): "qformer_feats_fp16" stores prot_feats (normalised mean
        # of the projected tokens), not the qformer_feats used by the
        # reliability head - confirm this is what consumers expect.
        emb_out = {
            "plm_mean_fp16": prot_embeds.to(torch.float32).mean(dim=1).detach().to(torch.float16).cpu(),
            "qformer_feats_fp16": prot_feats.detach().to(torch.float16).cpu(),
            "llm_last_fp16": llm_last_mean
        }
        return texts, r_pred, conf, emb_out, r_probs