""" Copyright (c) 2023, salesforce.com, inc. All rights reserved. SPDX-License-Identifier: BSD-3-Clause For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause """ import logging import os import torch import numpy as np import torch.nn as nn import torch.nn.functional as F from collections.abc import Mapping from torch.cuda.amp import autocast as autocast from lavis.models.blip2_models.blip2 import disabled_train from model.blip2 import Blip2Base from transformers import AutoTokenizer from transformers import OPTForCausalLM from opendelta import LoraModel from torch.cuda.amp import autocast from opendelta.delta_models.lora import LoraConfig as DeltaLoraConfig from transformers import BertTokenizer, BitsAndBytesConfig from transformers.tokenization_utils_base import BatchEncoding from model.help_funcs import hf_enable_gradient_checkpointing opt_model_list = [ "facebook/galactica-125m", "facebook/galactica-1.3b", "facebook/galactica-6.7b", "facebook/galactica-30b", ] def get_gpu_memory(device=0): free, total = torch.cuda.mem_get_info(device) free = free / (1024 ** 3) total = total / (1024 ** 3) return free, total-free, total # Reliability head: 4-class classification mapping # Class indices: {-1: 0, 0: 1, 0.5: 2, 1: 3} RELIABILITY_CLASSES = [-1.0, 0.0, 0.5, 1.0] RELIABILITY_VAL_TO_IDX = {v: i for i, v in enumerate(RELIABILITY_CLASSES)} RELIABILITY_NUM_CLASSES = len(RELIABILITY_CLASSES) def _r_value_to_class(r_tensor): """Convert reliability float values to class indices.""" class_indices = torch.zeros_like(r_tensor, dtype=torch.long) for val, idx in RELIABILITY_VAL_TO_IDX.items(): class_indices[torch.isclose(r_tensor, torch.tensor(val, device=r_tensor.device, dtype=r_tensor.dtype), atol=1e-4)] = idx return class_indices def _class_to_r_value(class_indices): """Convert class indices back to reliability float values.""" classes = torch.tensor(RELIABILITY_CLASSES, dtype=torch.float32, device=class_indices.device) return classes[class_indices] def _r_value_to_binary_class(r_tensor): """Binary mapping: 1 if r ~= 1.0 (positive class), else 0.""" pos = torch.isclose(r_tensor, torch.tensor(1.0, device=r_tensor.device, dtype=r_tensor.dtype), atol=1e-4) return pos.long() def _scan_r_counts(json_path): """Return (raw_counts_4class, total). raw_counts indexed by RELIABILITY_VAL_TO_IDX.""" import json as _json counts = torch.zeros(RELIABILITY_NUM_CLASSES, dtype=torch.float32) if not json_path or not os.path.isfile(json_path): return counts, 0 total = 0 with open(json_path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if not line: continue row = _json.loads(line) r = float(row[2]) total += 1 for cls_val, idx in RELIABILITY_VAL_TO_IDX.items(): if abs(r - cls_val) < 1e-4: counts[idx] += 1 break return counts, total def _inverse_freq_from_counts(counts): """Inverse-frequency weights w_k = N / (K_present * n_k), mean weight over present classes = 1. Empty classes get weight 0.""" if counts.sum() == 0: return None nonzero = counts > 0 weights = torch.zeros_like(counts) weights[nonzero] = counts.sum() / (nonzero.sum().float() * counts[nonzero]) return weights def _compute_inverse_freq_class_weights(json_path): """4-class inverse-frequency weights from JSON-lines file.""" counts, total = _scan_r_counts(json_path) if total == 0: return None return _inverse_freq_from_counts(counts) def _compute_binary_inverse_freq_class_weights(json_path): """Binary inverse-frequency weights. Index 0 = negative (r != 1), index 1 = positive (r == 1).""" counts, total = _scan_r_counts(json_path) if total == 0: return None pos = counts[RELIABILITY_VAL_TO_IDX[1.0]] neg = counts.sum() - pos binary_counts = torch.tensor([neg, pos], dtype=torch.float32) return _inverse_freq_from_counts(binary_counts) def mask_by_len(input, lens, fill_value=0): ''' input: shape = [N, D] lens: shape = [N] ''' mask = torch.arange(input.shape[1], device=input.device).reshape(1, -1) mask = mask < lens.reshape(-1, 1) input[mask] = fill_value return input class Blip2OPT(Blip2Base): """ BLIP2 model for protein function prediction with Q-former (generation only). """ def __init__( self, bert_name, num_query_token=32, cross_attention_freq=2, plm_model="facebook/esm2_t30_150M_UR50D", plm_tune='freeze', llm_name="facebook/galactica-1.3b", llm_tune='freeze', peft_dir='', args=None, ): super().__init__() self.args = args self.id2go = getattr(args, 'id2go', {}) self.do_sample = args.do_sample self.num_beams = args.num_beams self.max_inference_len = args.max_inference_len self.min_inference_len = args.min_inference_len self.enbale_gradient_checkpointing = args.enbale_gradient_checkpointing self.plm_lora_r = args.plm_lora_r self.plm_lora_alpha = args.plm_lora_alpha self.plm_lora_dropout = args.plm_lora_dropout # Detect encoder type based on model name self.plm_model = plm_model self.is_esmc = str(plm_model).startswith('esmc_') self.plm_tokenizer, self.plm, self.ln_layer = self.init_protein_encoder(plm_model) self.plm_tune = plm_tune if plm_tune == 'freeze': for name, param in self.plm.named_parameters(): param.requires_grad = False self.plm = self.plm.eval() self.plm.train = disabled_train logging.info("freeze plm encoder") elif plm_tune == 'lora': plm_cfg = DeltaLoraConfig(self.plm_lora_r, self.plm_lora_alpha, self.plm_lora_dropout, modified_modules=["attn.layernorm_qkv.1","ffn.1"]) self.plm_delta = LoraModel.from_config(plm_cfg, self.plm) self.plm_delta.freeze_module(set_state_dict=False) self.plm_delta.log() else: raise NotImplementedError() self.num_query_token = num_query_token _, self.Qformer, self.query_tokens = self.init_Qformer(bert_name, num_query_token, self.plm.num_features, cross_attention_freq) self.Qformer.cls = None self.Qformer.bert.embeddings.word_embeddings = None self.Qformer.bert.embeddings.position_embeddings = None for layer in self.Qformer.bert.encoder.layer: layer.output = None layer.intermediate = None for p in self.ln_layer.parameters(): p.requires_grad = False for p in self.plm.parameters(): p.requires_grad = False self.llm_model, self.llm_tokenizer = self.load_llm(llm_name, load_4bit=False, enable_gradient_checkpointing=self.enbale_gradient_checkpointing) if llm_tune == 'freeze': for name, param in self.llm_model.named_parameters(): param.requires_grad = False elif llm_tune == 'full': for name, param in self.llm_model.named_parameters(): param.requires_grad = True elif llm_tune == 'lora': lora_config = DeltaLoraConfig(args.lora_r, args.lora_alpha, args.lora_dropout, ) self.lm_delta = LoraModel.from_config(lora_config, self.llm_model) self.lm_delta.freeze_module(set_state_dict=False) self.lm_delta.log() elif llm_tune == 'mid_lora': lora_config = DeltaLoraConfig(args.lora_r, args.lora_alpha, args.lora_dropout, modified_modules=["q_proj", "v_proj", 'k_proj', "out_proj", "fc1", "fc2"]) self.delta = LoraModel.from_config(lora_config, self.llm_model) self.delta.freeze_module(set_state_dict=False) self.delta.log() elif llm_tune == 'peft_lora': config = PeftLoraConfig( r=args.lora_r, lora_alpha=args.lora_alpha, # target_modules=modules, lora_dropout=args.lora_dropout, bias="none", task_type="CAUSAL_LM", ) self.llm_model = get_peft_model(self.llm_model, config) for name, module in self.llm_model.named_modules(): if isinstance(module, LoraLayer): if True: module = module.to(torch.bfloat16) if 'norm' in name: module = module.to(torch.float32) if 'lm_head' in name or 'embed_tokens' in name: if hasattr(module, 'weight'): if True and module.weight.dtype == torch.float32: module = module.to(torch.bfloat16) else: raise NotImplementedError() self.eos_token_id = self.llm_tokenizer( "\n", add_special_tokens=False ).input_ids[0] self.opt_proj = nn.Linear(self.Qformer.config.hidden_size, self.llm_model.config.hidden_size) # Reliability head can be either 4-class ({-1, 0, 0.5, 1}) or binary (positive=r==1, else negative). self.reliability_binary = bool(getattr(args, 'reliability_binary', False)) head_out_dim = 2 if self.reliability_binary else RELIABILITY_NUM_CLASSES reliability_input_dim = self.Qformer.config.hidden_size + self.plm.num_features self.reliability_head = nn.Sequential( nn.Dropout(0.3), nn.Linear(reliability_input_dim, head_out_dim), ) self.reliability_head = self.reliability_head.to(torch.float32) # Class weights for reliability cross-entropy. Computed from the reliability # finetune JSON when train_reliability_head_only is set; uniform otherwise. cw = None if getattr(args, 'train_reliability_head_only', False): json_path = getattr(args, 'reliability_finetune_data', '') if self.reliability_binary: cw = _compute_binary_inverse_freq_class_weights(json_path) else: cw = _compute_inverse_freq_class_weights(json_path) if cw is not None: mode_tag = 'binary' if self.reliability_binary else '4-class' logging.info(f"reliability {mode_tag} inverse-freq class weights: {cw.tolist()}") if cw is None: cw = torch.ones(head_out_dim, dtype=torch.float32) self.register_buffer('reliability_class_weights', cw) def load_llm(self, llm_model, load_4bit=False, enable_gradient_checkpointing=True): llm_tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False, padding_side='right') llm_tokenizer.add_special_tokens({'pad_token': ''}) if load_4bit: quant_config = BitsAndBytesConfig( load_in_4bit=True, load_in_8bit=False, llm_int8_threshold=6.0, llm_int8_has_fp16_weight=False, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type='nf4', ) import os visible_devices = os.environ.get('CUDA_VISIBLE_DEVICES', '0') device_ids = [int(d) for d in visible_devices.split(',') if d.strip()] outputs = get_gpu_memory(device_ids[0] if device_ids else 0) used_memory = outputs[1] if used_memory > 1 and len(device_ids) > 1: device_map = {"": device_ids[1]} else: device_map = {"": device_ids[0] if device_ids else 0} llm_model = OPTForCausalLM.from_pretrained( llm_model, quantization_config=quant_config, load_in_4bit=True, load_in_8bit=False, device_map=device_map, torch_dtype=torch.bfloat16, ) llm_model.resize_token_embeddings(len(llm_tokenizer)) llm_model = prepare_model_for_kbit_training(llm_model, use_gradient_checkpointing=True) else: llm_model = OPTForCausalLM.from_pretrained(llm_model, torch_dtype=torch.bfloat16) llm_model.resize_token_embeddings(len(llm_tokenizer)) if enable_gradient_checkpointing: llm_model = hf_enable_gradient_checkpointing(llm_model) return llm_model, llm_tokenizer def _normalize_batch(self, batch): """Normalize batch to dict format.""" if isinstance(batch, (Mapping, BatchEncoding)): return batch if isinstance(batch, (list, tuple)): for x in batch: if isinstance(x, (Mapping, BatchEncoding)) and "input_ids" in x: return x raise TypeError(f"Batch must be a mapping with input_ids; got {type(batch)}") def prot_encode(self, prot_batch): """Encode protein sequence to embeddings.""" prot_batch = self._normalize_batch(prot_batch) device = next(self.plm.parameters()).device raw_ids = prot_batch.get("input_ids") raw_mask = prot_batch.get("attention_mask", None) pad_id = getattr(self.plm_tokenizer, "pad_token_id", 1) if torch.is_tensor(raw_ids): input_ids = raw_ids.to(device) attention_mask = raw_mask.to(device) if torch.is_tensor(raw_mask) else None elif isinstance(raw_ids, np.ndarray): input_ids = torch.as_tensor(raw_ids, dtype=torch.long, device=device) attention_mask = torch.as_tensor(raw_mask, dtype=torch.long, device=device) if raw_mask is not None else None elif isinstance(raw_ids, list): ids_list = [raw_ids] if (not raw_ids or not isinstance(raw_ids[0], list)) else raw_ids Lfix = getattr(self, "prot_max_len", None) or max(len(x) for x in ids_list) padded, masks = [], [] for ids in ids_list: curr_len = len(ids) ids = ids[:Lfix] + [pad_id] * max(0, Lfix - curr_len) padded.append(ids) masks.append([1] * min(curr_len, Lfix) + [0] * max(0, Lfix - curr_len)) input_ids = torch.tensor(padded, dtype=torch.long, device=device) attention_mask = torch.tensor(masks, dtype=torch.long, device=device) else: raise TypeError(f"Unsupported input_ids type: {type(raw_ids)}") if attention_mask is None: attention_mask = (input_ids != pad_id).long() out = self.plm(input_ids=input_ids, attention_mask=attention_mask) prot_embeds = out.last_hidden_state if hasattr(out, "last_hidden_state") else out[0] if self.plm_tune == "freeze": prot_embeds = prot_embeds.detach() ln_dtype = next(self.ln_layer.parameters()).dtype if prot_embeds.dtype != ln_dtype: prot_embeds = prot_embeds.to(ln_dtype) return self.ln_layer(prot_embeds), attention_mask def prot_qformer(self, prot_embeds, attention_mask): """Apply Q-former to protein embeddings.""" B = prot_embeds.size(0) query_tokens = self.query_tokens.expand(B, -1, -1) q_out = self.Qformer.bert( query_embeds=query_tokens, encoder_hidden_states=prot_embeds, encoder_attention_mask=attention_mask, return_dict=True, ) qformer_raw = q_out.last_hidden_state # [B, num_query, qformer_hidden] prot_tokens = self.opt_proj(qformer_raw) prot_feats = F.normalize(prot_tokens.mean(dim=1), dim=-1, p=2) # Mean-pooled Q-Former features before projection (for reliability head) qformer_feats = qformer_raw.mean(dim=1).detach() # Mean-pooled PLM embeddings (masked) for reliability head plm_pooled = (prot_embeds * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(dim=1, keepdim=True).clamp(min=1) plm_pooled = plm_pooled.detach() return prot_tokens, prot_feats, qformer_feats, plm_pooled def _extract_reliability_features(self, last_hidden, attention_mask, prot_feats, protein_token_count): """ Extract features for reliability prediction: masked mean of LLM last hidden layer. Args: last_hidden: Hidden states from LLM [B, seq_len, hidden_size] attention_mask: Attention mask for valid tokens [B, seq_len] prot_feats: (unused, kept for API compatibility) protein_token_count: (unused, kept for API compatibility) Returns: Mean-pooled feature tensor [B, hidden_size] """ last_hidden_fp32 = last_hidden.to(torch.float32) mask_expanded = attention_mask.unsqueeze(-1).to(torch.float32) hidden_masked = last_hidden_fp32 * mask_expanded seq_lengths = mask_expanded.sum(dim=1).clamp(min=1) mean_pool = hidden_masked.sum(dim=1) / seq_lengths return mean_pool.detach() def forward(self, batch, return_pred=False): """Forward pass for training (generation + reliability classification).""" prot_batch, text_batch, prompt_batch, r_tensor = batch samples = {'prot_batch': prot_batch, 'prompt_batch': prompt_batch, 'reliability': r_tensor} prot_embeds, prot_attn = self.prot_encode(prot_batch) prot_tokens, prot_feats, qformer_feats, plm_pooled = self.prot_qformer(prot_embeds, prot_attn) device = prot_embeds.device text_batch = self._normalize_batch(text_batch) input_ids = text_batch["input_ids"].to(device) attention_mask = text_batch["attention_mask"].to(device) token_type_ids = text_batch.get("token_type_ids") targets = input_ids.masked_fill(input_ids == self.llm_tokenizer.pad_token_id, -100) if token_type_ids is not None: targets = targets.masked_fill(token_type_ids.to(device) == 0, -100) B, Q = prot_tokens.shape[:2] prot_mask = torch.ones((B, Q), dtype=attention_mask.dtype, device=device) prot_empty_targets = torch.full((B, Q), -100, dtype=torch.long, device=device) text_embeds = self.llm_model.get_input_embeddings()(input_ids) if prot_tokens.dtype != text_embeds.dtype: prot_tokens = prot_tokens.to(text_embeds.dtype) inputs_embeds = torch.cat([prot_tokens, text_embeds], dim=1) full_attn = torch.cat([prot_mask, attention_mask], dim=1) targets = torch.cat([prot_empty_targets, targets], dim=1) outputs = self.llm_model( inputs_embeds=inputs_embeds, attention_mask=full_attn, return_dict=True, labels=targets, output_hidden_states=True ) last_hidden = outputs.hidden_states[-1] last_hidden = last_hidden[0] if isinstance(last_hidden, (tuple, list)) else last_hidden with autocast(enabled=False): head_param = next(self.reliability_head.parameters()) qf = qformer_feats.to(device=head_param.device, dtype=head_param.dtype) pf = plm_pooled.to(device=head_param.device, dtype=head_param.dtype) r_input = torch.cat([qf, pf], dim=-1) r_logits = self.reliability_head(r_input) # [B, NUM_CLASSES] or [B, 2] r_target = r_tensor.to(device=r_logits.device, dtype=r_logits.dtype).view(-1) if self.reliability_binary: r_class_target = _r_value_to_binary_class(r_target).to(r_logits.device) else: r_class_target = _r_value_to_class(r_target).to(r_logits.device) cw = self.reliability_class_weights.to(device=r_logits.device, dtype=r_logits.dtype) r_loss = nn.functional.cross_entropy(r_logits, r_class_target, weight=cw) if return_pred: with torch.no_grad(): pred_texts, r_pred, conf, _, _prob = self.generate( samples, do_sample=self.do_sample, num_beams=self.num_beams, max_length=self.max_inference_len, min_length=self.min_inference_len ) return outputs.loss, r_loss, pred_texts, r_pred return outputs.loss, r_loss @torch.no_grad() def generate( self, samples, do_sample=False, num_beams=3, max_length=128, min_length=1, top_p=0.9, repetition_penalty=1.3, length_penalty=1.0, num_captions=1, temperature=1, ): """ Args: samples (dict): A dictionary containing the following keys: - image (torch.Tensor): A tensor of shape (batch_size, 3, H, W) num_beams (int): Number of beams for beam search. 1 means no beam search. max_length (int): The maximum length of the sequence to be generated. min_length (int): The minimum length of the sequence to be generated. top_p (float): The cumulative probability for nucleus sampling. repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty. num_captions (int): Number of captions to be generated for each image. Returns: captions (list): A list of strings of length batch_size * num_captions. """ prot_batch = samples['prot_batch'] prompt_batch = samples['prompt_batch'] # Encode protein sequence prot_embeds, prot_attn = self.prot_encode(prot_batch) prot_tokens, prot_feats, qformer_feats, plm_pooled = self.prot_qformer(prot_embeds, prot_attn) device = prot_embeds.device prompt_embeds = self.llm_model.get_input_embeddings()(prompt_batch.input_ids) if prot_tokens.dtype != prompt_embeds.dtype: prot_tokens = prot_tokens.to(prompt_embeds.dtype) inputs_embeds = torch.cat((prot_tokens, prompt_embeds), dim=1) prot_mask = torch.ones(prot_tokens.shape[:2], dtype=prompt_batch.attention_mask.dtype, device=prompt_embeds.device) attention_mask = torch.cat([prot_mask, prompt_batch.attention_mask], dim=1) gen_out = self.llm_model.generate( inputs_embeds=inputs_embeds, attention_mask=attention_mask, do_sample=do_sample, top_p=top_p, temperature=temperature, num_beams=num_beams, max_length=max_length, min_length=min_length, eos_token_id=self.eos_token_id, repetition_penalty=repetition_penalty, length_penalty=length_penalty, num_return_sequences=num_captions, use_cache=True, return_dict_in_generate=True, output_scores=True, output_hidden_states=True ) texts = self.llm_tokenizer.batch_decode(gen_out.sequences, skip_special_tokens=True) texts = [t.strip() for t in texts] B_gen = prot_tokens.shape[0] Q_gen = prot_tokens.shape[1] generated_token_ids = gen_out.sequences if generated_token_ids.shape[1] > 0 and (generated_token_ids != self.llm_tokenizer.pad_token_id).any(): generated_embeds = self.llm_model.get_input_embeddings()(generated_token_ids) full_embeds = torch.cat([prot_tokens, prompt_embeds, generated_embeds], dim=1) gen_mask = torch.ones((B_gen, generated_token_ids.shape[1]), dtype=attention_mask.dtype, device=device) full_attention_mask = torch.cat([attention_mask, gen_mask], dim=1) with torch.no_grad(): outputs = self.llm_model( inputs_embeds=full_embeds, attention_mask=full_attention_mask, output_hidden_states=True, return_dict=True ) last_hidden = outputs.hidden_states[-1] with autocast(enabled=False): head_param = next(self.reliability_head.parameters()) qf = qformer_feats.to(device=head_param.device, dtype=head_param.dtype) pf = plm_pooled.to(device=head_param.device, dtype=head_param.dtype) r_input = torch.cat([qf, pf], dim=-1) r_logits = self.reliability_head(r_input) # [B, K] r_probs = torch.softmax(r_logits, dim=-1) # [B, K] pred_idx = r_logits.argmax(dim=-1) if self.reliability_binary: r_pred = pred_idx.to(torch.float32) # 0.0 (neg) or 1.0 (pos) else: r_pred = _class_to_r_value(pred_idx) # LLM last layer embedding: mean over sequence (protein + prompt + generated) llm_last_mean = last_hidden.to(torch.float32).mean(dim=1).detach().to(torch.float16).cpu() else: r_pred = torch.ones(B_gen, device=device) * (0.0 if self.reliability_binary else 0.5) r_probs = torch.zeros(B_gen, 2 if self.reliability_binary else RELIABILITY_NUM_CLASSES, device=device) # LLM last layer embedding: forward on input only (protein + prompt), then mean over sequence with torch.no_grad(): outputs = self.llm_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, output_hidden_states=True, return_dict=True ) last_hidden = outputs.hidden_states[-1] llm_last_mean = last_hidden.to(torch.float32).mean(dim=1).detach().to(torch.float16).cpu() ts = self.llm_model.compute_transition_scores( gen_out.sequences, gen_out.scores, beam_indices=getattr(gen_out, "beam_indices", None), normalize_logits=True ) conf = torch.exp(ts.to(torch.float32).mean(dim=1)).clamp(1e-9, 1.0).tolist() emb_out = { "plm_mean_fp16": prot_embeds.to(torch.float32).mean(dim=1).detach().to(torch.float16).cpu(), "qformer_feats_fp16": prot_feats.detach().to(torch.float16).cpu(), "llm_last_fp16": llm_last_mean } return texts, r_pred, conf, emb_out, r_probs