Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| from tqdm import tqdm | |
| from model.blip2_opt import Blip2OPT | |
| import pytorch_lightning as pl | |
| from torch import optim | |
| from lavis.common.optims import LinearWarmupCosineLRScheduler, LinearWarmupStepLRScheduler | |
| import json | |
| import ast | |
| import pickle | |
| from evals.tools.InfoAccretion import compute_InfoAccretion_distance | |
| from evals.tools.wang_similarity import compute_wang_similarity | |
| from evals.tools.jaccard_similarity import compute_jaccard_similarity | |
| from evals.tools.extraction import process_texts_with_api | |
| import numpy as np | |
| import torch.distributed as dist | |
| from typing import Any, Dict | |
| from model.help_funcs import ( | |
| caption_evaluate, | |
| AttrDict, | |
| _mean_conf, | |
| _json_default, | |
| load_or_process, | |
| load_mf_go_ids_from_tsv, | |
| filter_go_terms_by_set, | |
| build_joint_nonempty_mask, | |
| filter_parallel_by_mask, | |
| ) | |
| def _batch_to_device(x, device): | |
| """Move batch (dict/tensor/list) to device; skip non-tensor values (e.g. lists of ints from ESM-C).""" | |
| if torch.is_tensor(x): | |
| return x.to(device) | |
| if isinstance(x, dict): | |
| return {k: _batch_to_device(v, device) for k, v in x.items()} | |
| if isinstance(x, (list, tuple)): | |
| return type(x)(_batch_to_device(v, device) if torch.is_tensor(v) else v for v in x) | |
| return x | |
| class Blip2Stage2(pl.LightningModule): | |
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """Checkpoint-save hook that intentionally leaves ``checkpoint`` untouched."""
    return None
def __init__(self, args):
    """Build the stage-2 module from ``args`` (a namespace-like object or a dict).

    A dict is wrapped in ``AttrDict``. Required fields are read as attributes;
    optional ones fall back to the defaults given below. Raises ``ValueError``
    when ``encoder_type`` is 'esm2'/'esmc' but ``plm_model`` lacks the matching
    prefix, and ``NotImplementedError`` when ``llm_name`` does not contain
    'galactica'.
    """
    super().__init__()
    if isinstance(args, dict):
        args = AttrDict(**args)
    self.args = args

    def opt(name, default):
        # Optional setting: use the default when args has no such attribute.
        return getattr(args, name, default)

    # Generation / evaluation settings (required).
    self.caption_eval_epoch = args.caption_eval_epoch
    self.do_sample = args.do_sample
    self.num_beams = args.num_beams
    self.max_inference_len = args.max_inference_len
    self.min_inference_len = args.min_inference_len
    self.llm_tune = args.llm_tune

    # GO-term evaluation settings and data locations.
    data_root = opt('root', 'data/SwissProtV3')
    self.report_go_wang_on_test = opt('report_go_wang_on_test', False)
    self.ia_path = opt('ia_path', 'evals/tools/IA.txt')
    # An empty/missing path falls back to <root>/<name>_set.json.
    self.test_set_path = opt('test_set_path', '') or os.path.join(data_root, 'test_set.json')
    self.valid_set_path = opt('valid_set_path', '') or os.path.join(data_root, 'valid_set.json')
    self.go_files_tsv_path = opt('go_files_tsv_path', 'evals/tools/go_files.tsv')
    self.report_go_wang_on_val = opt('report_go_wang_on_val', False)

    # Prediction collection / training-mode switches.
    # NOTE(review): this instance attribute has the same name as the
    # ``save_predictions`` method defined on this class, so it shadows the
    # method on every instance - confirm which of the two is intended.
    self.save_predictions = opt('save_predictions', False)
    self.inference_on_training_data = opt('inference_on_training_data', False)
    self.train_reliability_head_only = opt('train_reliability_head_only', False)

    # encoder_type ('auto' skips the check) must agree with the plm_model prefix.
    encoder_type = opt('encoder_type', 'auto')
    if encoder_type == 'esm2' and not args.plm_model.startswith('facebook/esm2'):
        raise ValueError(f"encoder_type='{encoder_type}' but plm_model='{args.plm_model}' does not start with 'facebook/esm2'")
    if encoder_type == 'esmc' and not args.plm_model.startswith('esmc_'):
        raise ValueError(f"encoder_type='{encoder_type}' but plm_model='{args.plm_model}' does not start with 'esmc_'")

    # Only galactica-named LLMs are wired up.
    if 'galactica' not in args.llm_name:
        raise NotImplementedError()
    self.blip2 = Blip2OPT(
        args.bert_name,
        args.num_query_token,
        args.cross_attention_freq,
        args.plm_model,
        args.plm_tune,
        args.llm_name,
        args.llm_tune,
        args.peft_dir,
        args,
    )

    # Default training mode: freeze every blip2 parameter whose name contains
    # 'reliability_head' or 'plm'.
    # NOTE(review): the previous comment also listed ln_layer and Qformer as
    # frozen; this loop only matches those two substrings - verify.
    default_training = not (self.inference_on_training_data or self.train_reliability_head_only)
    if default_training:
        for name, param in self.blip2.named_parameters():
            if 'reliability_head' in name or 'plm' in name:
                param.requires_grad = False

    self.save_hyperparameters(args)
def load_from_stage1_checkpoint(self, path):
    """Load stage-1 weights stored under the 'blip2qformer.' prefix into ``self.blip2``.

    Args:
        path: checkpoint file readable by ``torch.load`` that contains a
            'state_dict' mapping.

    Returns:
        ``self``, so the call can be chained.

    Only entries whose key contains 'blip2qformer.' are used; the key is
    replaced by whatever follows the first occurrence of that prefix. Entries
    without the prefix are skipped instead of raising ``IndexError``. Loading
    is non-strict (``strict=False``).
    """
    prefix = 'blip2qformer.'
    ckpt = torch.load(path, map_location='cpu')
    stage1_state = {}
    for key, value in ckpt['state_dict'].items():
        _, found, remainder = key.partition(prefix)
        if found:
            stage1_state[remainder] = value
    self.blip2.load_state_dict(stage1_state, strict=False)
    return self
def freeze_for_reliability_finetune(self):
    """Leave only parameters whose name contains 'reliability_head' trainable.

    Every parameter's ``requires_grad`` is overwritten, so previously frozen
    reliability-head parameters are re-enabled as well. Meant to be called
    after loading a checkpoint when training the reliability head only.
    """
    for param_name, parameter in self.named_parameters():
        is_head = 'reliability_head' in param_name
        parameter.requires_grad = is_head
def run_inference_on_training_subset(
    self,
    dm,
    output_path,
    min_go=2,
    sample_size=2000,
    seed=42,
    reliability_label_zero=False,
):
    """Generate predictions for a training subset and save them as JSON lines.

    The dataloader comes from ``dm.get_inference_training_dataloader(min_go,
    sample_size, seed)``. The module is put in eval mode and, when it lives on
    CPU while CUDA is available, moved to 'cuda'.

    Args:
        dm: data module providing the dataloader and ``train_dataset.data_path``.
        output_path: destination file; parent directories are created.
        min_go, sample_size, seed: forwarded to the dataloader factory.
        reliability_label_zero: when True every written score is 0.0 and no GO
            extraction is run; otherwise GO terms are extracted from the
            predictions with ``process_texts_with_api`` and the per-sample
            values returned by ``compute_wang_similarity`` are used.

    Returns:
        ``output_path``.

    Each output row is the original training row (a JSON list) with column 1
    replaced by the predicted text and column 2 by the score; column 3 is read
    as the ground-truth GO terms. Sample indices address the non-blank lines of
    the training file.
    """
    dataloader = dm.get_inference_training_dataloader(min_go=min_go, sample_size=sample_size, seed=seed)
    self.eval()
    device = next(self.parameters()).device
    if device.type == 'cpu' and torch.cuda.is_available():
        self.to('cuda')
        device = next(self.parameters()).device

    idx_to_pred = []
    n_batches = len(dataloader)
    print(f"Inference on training subset: {n_batches} batches (sample_size={sample_size})", flush=True)
    # Pure inference: disable autograd so no graph or activations are retained.
    with torch.no_grad():
        for batch in tqdm(dataloader, total=n_batches, desc="train_inference", unit="batch"):
            prot_tokens, prompt_tokens, r_tensor, target_dict = batch
            prot_tokens = _batch_to_device(prot_tokens, device)
            if hasattr(prompt_tokens, 'to'):
                prompt_tokens = prompt_tokens.to(device)
            else:
                # Mapping without .to(): move each value, keep the container type.
                prompt_tokens = type(prompt_tokens)({k: v.to(device) for k, v in prompt_tokens.items()})
            r_tensor = r_tensor.to(device)
            samples = {'prot_batch': prot_tokens, 'prompt_batch': prompt_tokens, 'reliability': r_tensor}
            pred_texts, _, _, _, _ = self.blip2.generate(
                samples,
                do_sample=self.do_sample,
                num_beams=self.num_beams,
                max_length=self.max_inference_len,
                min_length=self.min_inference_len,
            )
            indices = target_dict['indices']
            if hasattr(indices, 'tolist'):
                indices = indices.tolist()
            for i, idx in enumerate(indices):
                idx_to_pred.append((idx, pred_texts[i]))

    # Restore dataset order before pairing predictions with source rows.
    idx_to_pred.sort(key=lambda x: x[0])
    sorted_indices = [x[0] for x in idx_to_pred]
    pred_texts_ordered = [x[1] for x in idx_to_pred]

    with open(dm.train_dataset.data_path, 'r', encoding='utf-8') as f:
        train_lines = [line.strip() for line in f if line.strip()]
    train_rows = [json.loads(line) for line in train_lines]

    if reliability_label_zero:
        per_scores = [0.0] * len(sorted_indices)
    else:
        gt_go_list = []
        for idx in sorted_indices:
            g = train_rows[idx][3]
            # GO terms may be stored as a Python-literal string or as a list.
            gt_go_list.append(ast.literal_eval(g) if isinstance(g, str) else g)
        print(f"Extracting GO terms from {len(pred_texts_ordered)} predictions (API calls)...", flush=True)
        predicted_go_terms = process_texts_with_api(pred_texts_ordered)
        _, per_scores = compute_wang_similarity(gt_go_list, predicted_go_terms)

    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        for i, idx in enumerate(sorted_indices):
            row = list(train_rows[idx])
            row[1] = pred_texts_ordered[i]  # predicted text
            row[2] = per_scores[i]  # reliability score
            f.write(json.dumps(row, ensure_ascii=True) + '\n')

    r_desc = "r=0" if reliability_label_zero else "r (wang score)"
    print(f"Saved {len(sorted_indices)} rows with {r_desc} to {output_path}", flush=True)
    return output_path
def run_inference_on_validation_set(self, dm, output_path, reliability_label_zero=False):
    """Generate predictions for the validation set and save them as JSON lines.

    The dataloader comes from ``dm.get_validation_inference_dataloader()`` and
    the source rows from ``dm.valid_set_path``. The module is put in eval mode
    and, when it lives on CPU while CUDA is available, moved to 'cuda'.

    Args:
        dm: data module providing the dataloader and ``valid_set_path``.
        output_path: destination file; parent directories are created.
        reliability_label_zero: when True every written score is 0.0 and no GO
            extraction is run; otherwise GO terms are extracted from the
            predictions with ``process_texts_with_api`` and the per-sample
            values returned by ``compute_wang_similarity`` are used.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: ``dm.valid_set_path`` is missing, empty or does not
            exist on disk.

    Each output row is the original validation row (a JSON list) with column 1
    replaced by the predicted text and column 2 by the score; column 3 is read
    as the ground-truth GO terms. Sample indices address the non-blank lines of
    the validation file.
    """
    dataloader = dm.get_validation_inference_dataloader()
    valid_path = getattr(dm, 'valid_set_path', None)
    if not valid_path or not os.path.exists(valid_path):
        raise FileNotFoundError(f"Validation set path not found: {valid_path}")

    self.eval()
    device = next(self.parameters()).device
    if device.type == 'cpu' and torch.cuda.is_available():
        self.to('cuda')
        device = next(self.parameters()).device

    idx_to_pred = []
    n_batches = len(dataloader)
    print(f"Inference on validation set: {n_batches} batches", flush=True)
    # Pure inference: disable autograd so no graph or activations are retained.
    with torch.no_grad():
        for batch in tqdm(dataloader, total=n_batches, desc="val_inference", unit="batch"):
            prot_tokens, prompt_tokens, r_tensor, target_dict = batch
            prot_tokens = _batch_to_device(prot_tokens, device)
            if hasattr(prompt_tokens, 'to'):
                prompt_tokens = prompt_tokens.to(device)
            else:
                # Mapping without .to(): move each value, keep the container type.
                prompt_tokens = type(prompt_tokens)({k: v.to(device) for k, v in prompt_tokens.items()})
            r_tensor = r_tensor.to(device)
            samples = {'prot_batch': prot_tokens, 'prompt_batch': prompt_tokens, 'reliability': r_tensor}
            pred_texts, _, _, _, _ = self.blip2.generate(
                samples,
                do_sample=self.do_sample,
                num_beams=self.num_beams,
                max_length=self.max_inference_len,
                min_length=self.min_inference_len,
            )
            indices = target_dict['indices']
            if hasattr(indices, 'tolist'):
                indices = indices.tolist()
            for i, idx in enumerate(indices):
                idx_to_pred.append((idx, pred_texts[i]))

    # Restore dataset order before pairing predictions with source rows.
    idx_to_pred.sort(key=lambda x: x[0])
    sorted_indices = [x[0] for x in idx_to_pred]
    pred_texts_ordered = [x[1] for x in idx_to_pred]

    with open(valid_path, 'r', encoding='utf-8') as f:
        valid_lines = [line.strip() for line in f if line.strip()]
    valid_rows = [json.loads(line) for line in valid_lines]

    if reliability_label_zero:
        per_scores = [0.0] * len(sorted_indices)
    else:
        gt_go_list = []
        for idx in sorted_indices:
            g = valid_rows[idx][3]
            # GO terms may be stored as a Python-literal string or as a list.
            gt_go_list.append(ast.literal_eval(g) if isinstance(g, str) else g)
        print(f"Extracting GO terms from {len(pred_texts_ordered)} predictions (API calls)...", flush=True)
        predicted_go_terms = process_texts_with_api(pred_texts_ordered)
        _, per_scores = compute_wang_similarity(gt_go_list, predicted_go_terms)

    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        for i, idx in enumerate(sorted_indices):
            row = list(valid_rows[idx])
            row[1] = pred_texts_ordered[i]  # predicted text
            row[2] = per_scores[i]  # reliability score
            f.write(json.dumps(row, ensure_ascii=True) + '\n')

    r_desc = "r=0" if reliability_label_zero else "r (wang score)"
    print(f"Saved {len(sorted_indices)} validation rows with {r_desc} to {output_path}", flush=True)
    return output_path
def configure_optimizers(self):
    """Create the AdamW optimizer and store the LR scheduler on ``self.scheduler``.

    * ``train_reliability_head_only``: a single group holding the trainable
      parameters whose name contains 'reliability_head'.
    * otherwise: two groups - all other trainable parameters at ``init_lr`` and
      the reliability-head parameters at ``reliability_lr``.

    ``reliability_lr`` falls back to ``init_lr`` when it is None. The number of
    warmup steps is capped at the length of the training dataloader.

    Returns:
        The optimizer (the scheduler is not returned).

    Raises:
        NotImplementedError: ``args.scheduler`` is not one of
            'linear_warmup_cosine_lr', 'linear_warmup_step_lr', 'None'.
    """
    # The train dataloader must exist before its length can be queried.
    self.trainer.fit_loop.setup_data()
    warmup_steps = min(len(self.trainer.train_dataloader), self.args.warmup_steps)

    cfg = self.args
    head_lr = cfg.init_lr if cfg.reliability_lr is None else cfg.reliability_lr
    trainable = [(n, p) for n, p in self.named_parameters() if p.requires_grad]
    head_params = [p for n, p in trainable if 'reliability_head' in n]

    if self.train_reliability_head_only:
        optimizer = optim.AdamW(head_params, lr=head_lr, weight_decay=cfg.weight_decay)
    else:
        backbone_params = [p for n, p in trainable if 'reliability_head' not in n]
        param_groups = [
            {'params': backbone_params, 'lr': cfg.init_lr, 'weight_decay': cfg.weight_decay},
            {'params': head_params, 'lr': head_lr, 'weight_decay': cfg.weight_decay},
        ]
        optimizer = optim.AdamW(param_groups)

    scheduler_name = cfg.scheduler
    if scheduler_name == 'linear_warmup_cosine_lr':
        self.scheduler = LinearWarmupCosineLRScheduler(
            optimizer, cfg.max_epochs, cfg.min_lr, cfg.init_lr, warmup_steps, cfg.warmup_lr)
    elif scheduler_name == 'linear_warmup_step_lr':
        self.scheduler = LinearWarmupStepLRScheduler(
            optimizer, cfg.max_epochs, cfg.min_lr, cfg.init_lr, cfg.lr_decay_rate, cfg.warmup_lr, warmup_steps)
    elif scheduler_name == 'None':
        # The literal string 'None' disables scheduling.
        self.scheduler = None
    else:
        raise NotImplementedError()
    return optimizer
def save_predictions(self, predictions, targets, q_types=None, log_prefix=''):
    """Write prediction/target pairs as JSON lines into the logger directory.

    The file is ``<log_prefix>_predictions.txt`` (or ``predictions.txt`` when
    the prefix is empty) under ``self.logger.log_dir``. Each line is an object
    with 'prediction' and 'target', plus 'q_type' when ``q_types`` is given.
    ``predictions`` and ``targets`` must have equal length (asserted).
    """
    assert len(predictions) == len(targets)
    filename = f'{log_prefix}_predictions.txt' if log_prefix else 'predictions.txt'
    destination = os.path.join(self.logger.log_dir, filename)
    with open(destination, 'w', encoding='utf8') as out:
        if q_types is None:
            records = (
                {'prediction': pred, 'target': tgt}
                for pred, tgt in zip(predictions, targets)
            )
        else:
            records = (
                {'prediction': pred, 'target': tgt, 'q_type': qt}
                for pred, tgt, qt in zip(predictions, targets, q_types)
            )
        for record in records:
            out.write(json.dumps(record, ensure_ascii=True) + '\n')
def on_validation_epoch_start(self) -> None:
    """Reset the per-epoch accumulators to fresh, independent empty lists."""
    accumulator_names = (
        'saved_dict_list',
        'prediction_list0',
        'target_list0',
        'prediction_list1',
        'target_list1',
        '_saved_even_list',
        'val_saved_list_for_go',
    )
    for attr in accumulator_names:
        setattr(self, attr, [])
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Run one validation step; the behaviour depends on the dataloader parity.

    Even ``dataloader_idx``: forward pass through ``self.blip2`` with
    ``return_pred=True``; the indices and predicted texts (plus reliability
    predictions/labels and the dataloader index when training the reliability
    head only) are appended to ``self.val_saved_list_for_go`` and the losses
    are logged.

    Odd ``dataloader_idx``: only on epochs selected by ``caption_eval_epoch``;
    texts are generated with ``self.blip2.generate`` and the batch's target
    dict - extended with predictions, confidences, reliability values, class
    probabilities and per-sample embedding clones - is appended to
    ``self.saved_dict_list``.

    Returns None in every case.
    """

    def _as_py_list(values):
        # Tensor -> python list; other iterables -> list; scalars are wrapped.
        out = values.cpu().tolist() if torch.is_tensor(values) else list(values)
        return out if isinstance(out, list) else [out]

    def _round4(values):
        return [round(float(v), 4) for v in values]

    parity = dataloader_idx % 2

    if parity == 0:
        batch_size = batch[1].input_ids.shape[0]
        loss, r_loss, pred_texts, r_pred = self.blip2(batch[:4], return_pred=True)

        idx_list = batch[4].detach().cpu().tolist()
        if not isinstance(idx_list, list):
            idx_list = [idx_list]
        saved_dict = {'indices': idx_list, 'predictions': pred_texts}

        head_only = self.train_reliability_head_only
        if head_only:
            r_pred_list = _as_py_list(r_pred)
            r_gt_list = _as_py_list(batch[3])
            saved_dict['r_pred'] = r_pred_list
            saved_dict['r_gt'] = r_gt_list
            saved_dict['dataloader_idx'] = [dataloader_idx] * len(r_pred_list)
        self.val_saved_list_for_go.append(saved_dict)

        epoch_log = dict(on_step=False, on_epoch=True, sync_dist=True, batch_size=batch_size)
        self.log(f"dataloader{dataloader_idx}/val_loss", loss, prog_bar=True, **epoch_log)
        self.log(f"dataloader{dataloader_idx}/reliability_loss", r_loss, prog_bar=False, **epoch_log)
        if head_only:
            self.log("val/reliability_loss", r_loss, prog_bar=False, **epoch_log)
        return

    if parity != 1:
        return

    # Odd dataloaders: generation happens only every caption_eval_epoch epochs.
    if (self.current_epoch + 1) % self.caption_eval_epoch != 0:
        return

    prot_batch, prompt_batch, r_tensor, target_dict = batch
    samples = {'prot_batch': prot_batch, 'prompt_batch': prompt_batch, 'reliability': r_tensor}
    predictions, r_pred, avg_conf, emb_out, r_probs = self.blip2.generate(
        samples,
        do_sample=self.do_sample,
        num_beams=self.num_beams,
        max_length=self.max_inference_len,
        min_length=self.min_inference_len,
    )

    target_dict['predictions'] = predictions
    target_dict['confidences'] = _round4(avg_conf)
    target_dict['reliability'] = _round4(r_pred)
    if getattr(self.blip2, 'reliability_binary', False):
        # Binary head: column 1 of r_probs is stored as the class-1 probability.
        target_dict['reliability_prob_class1'] = _round4(r_probs[:, 1])
    else:
        from model.blip2_opt import RELIABILITY_VAL_TO_IDX
        for cls_val, col in RELIABILITY_VAL_TO_IDX.items():
            target_dict[f'reliability_prob_class{cls_val}'] = _round4(r_probs[:, col])

    n_samples = len(predictions)
    for emb_key in ('plm_mean_fp16', 'qformer_feats_fp16', 'llm_last_fp16'):
        # One cloned entry per sample.
        target_dict[emb_key] = [emb_out[emb_key][i].clone() for i in range(n_samples)]
    self.saved_dict_list.append(target_dict)
def gather_dict_results(self, dict_list):
    """Merge per-batch result dicts (across ranks when distributed) into per-sample dicts.

    Args:
        dict_list: list of dicts whose values hold one batch of results each
            (a list/tuple per key, or a single scalar/str for a one-sample
            batch). Every dict must have the keys of the first dict, including
            'predictions'.

    Returns:
        A list with one dict per sample (same keys), or ``[]`` when nothing was
        collected.

    When a torch.distributed process group is initialised the batches of all
    ranks are exchanged with ``all_gather_object``. Every rank takes part in
    that collective even if its own list is empty, so all ranks must call this
    method together. Without an initialised process group only the local list
    is used.
    """
    local_batches = list(dict_list) if dict_list else []

    if dist.is_available() and dist.is_initialized():
        per_rank = [None for _ in range(self.trainer.world_size)]
        dist.all_gather_object(per_rank, local_batches)
        batches = [d for rank_batches in per_rank for d in rank_batches]
    else:
        batches = local_batches

    if not batches:
        return []

    def _flatten_field(v):
        """Expand a batch field to a list (handles scalar tolist() / single int / str)."""
        if isinstance(v, (list, tuple)):
            return list(v)
        return [v]

    keys = list(batches[0].keys())
    # Concatenate each field over all batches, then re-slice per sample.
    gathered = {key: [x for d in batches for x in _flatten_field(d[key])] for key in keys}
    n_samples = len(gathered['predictions'])
    return [{k: gathered[k][i] for k in keys} for i in range(n_samples)]
def load_ground_truth_go_from_test_set(self, result_list):
    """Read the GO-term column of every row in ``self.test_set_path``.

    Args:
        result_list: not used by this method; kept for the callers' signature.

    Returns:
        dict mapping the 0-based line number in the file (blank lines count
        towards the numbering but get no entry) to that row's GO terms. Rows
        with fewer than four columns get no entry; rows that fail to parse map
        to ``[]``. Returns ``{}`` (after printing a warning) when the path is
        empty or missing.
    """
    path = self.test_set_path
    if not (path and os.path.exists(path)):
        print(f"[Warning] Test set path not provided or not found: {path}")
        return {}

    def _coerce_go(cell):
        # Column 3 is either a list or the string form of a Python list.
        if isinstance(cell, list):
            return cell
        if isinstance(cell, str) and cell.startswith('['):
            try:
                return ast.literal_eval(cell)
            except Exception:
                return []
        return []

    go_by_line = {}
    with open(path, 'r', encoding='utf-8') as handle:
        for line_no, raw in enumerate(handle):
            text = raw.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
                if len(record) >= 4:
                    go_by_line[line_no] = _coerce_go(record[3])
            except Exception as err:
                print(f"[Warning] Error parsing line {line_no}: {err}")
                go_by_line[line_no] = []
    return go_by_line
def load_ground_truth_text_from_valid_set(self):
    """Read the target text (column 1) of every row in ``self.valid_set_path``.

    Returns:
        dict mapping the 0-based line number in the file (blank lines count
        towards the numbering but get no entry) to the stripped string form of
        column 1. Rows with fewer than two columns, or rows that fail to
        parse, map to ''. Returns ``{}`` when the path is empty or missing.
    """
    path = self.valid_set_path
    if not (path and os.path.exists(path)):
        return {}

    text_by_line = {}
    with open(path, 'r', encoding='utf-8') as handle:
        for line_no, raw in enumerate(handle):
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                record = json.loads(stripped)
                value = str(record[1]).strip() if len(record) >= 2 else ''
            except Exception:
                value = ''
            text_by_line[line_no] = value
    return text_by_line
def load_ground_truth_go_from_valid_set(self):
    """Read the GO-term column (column 3) of every row in ``self.valid_set_path``.

    Returns:
        dict mapping the 0-based line number in the file (blank lines count
        towards the numbering but get no entry) to that row's GO terms. Rows
        with fewer than four columns get no entry; rows that fail to parse map
        to ``[]``. Returns ``{}`` when the path is empty or missing.
    """
    path = self.valid_set_path
    if not (path and os.path.exists(path)):
        return {}

    def _coerce_go(cell):
        # Column 3 is either a list or the string form of a Python list.
        if isinstance(cell, list):
            return cell
        if isinstance(cell, str) and cell.startswith('['):
            try:
                return ast.literal_eval(cell)
            except Exception:
                return []
        return []

    go_by_line = {}
    with open(path, 'r', encoding='utf-8') as handle:
        for line_no, raw in enumerate(handle):
            text = raw.strip()
            if not text:
                continue
            try:
                record = json.loads(text)
                if len(record) >= 4:
                    go_by_line[line_no] = _coerce_go(record[3])
            except Exception:
                go_by_line[line_no] = []
    return go_by_line
def save_results(self, dict_list, log_prefix=""):
    """Write each dict of ``dict_list`` as one JSON line into the logger directory.

    The file is ``<log_prefix>_predictions.txt`` (or ``predictions.txt`` when
    the prefix is empty) under ``self.logger.log_dir``. Values that ``json``
    cannot encode natively are passed to ``_json_default``.
    """
    filename = f'{log_prefix}_predictions.txt' if log_prefix else 'predictions.txt'
    destination = os.path.join(self.logger.log_dir, filename)
    with open(destination, 'w', encoding='utf8') as out:
        for record in dict_list:
            encoded = json.dumps(record, ensure_ascii=True, default=_json_default)
            out.write(encoded + '\n')
| def on_validation_epoch_end(self): | |
| if getattr(self.trainer, "sanity_checking", False): | |
| return | |
| # Validation set BLEU/ROUGE: compute every epoch (val_go_gathered is always collected) | |
| val_go_gathered = self.gather_dict_results(self.val_saved_list_for_go) if self.val_saved_list_for_go else [] | |
| if val_go_gathered: | |
| val_sorted = sorted(val_go_gathered, key=lambda d: d['indices'] if isinstance(d['indices'], (int, float)) else d['indices'][0]) | |
| val_indices = [d['indices'] for d in val_sorted] | |
| val_predictions = [d['predictions'] for d in val_sorted] | |
| text_dict_val = self.load_ground_truth_text_from_valid_set() | |
| val_targets = [text_dict_val.get(i, '') for i in val_indices] | |
| if val_targets and any(t for t in val_targets): | |
| bleu2_val, bleu4_val, rouge_1_val, rouge_2_val, rouge_l_val, meteor_val = caption_evaluate( | |
| val_predictions, val_targets, self.blip2.llm_tokenizer, self.max_inference_len, | |
| verbose=(self.global_rank == 0)) | |
| acc_val = evaluate_exact_match(val_predictions, val_targets) | |
| self.log("val/acc", acc_val, sync_dist=False) | |
| self.log("val/bleu2", bleu2_val, sync_dist=False) | |
| self.log("val/bleu4", bleu4_val, sync_dist=False) | |
| self.log("val/rouge_1", rouge_1_val, sync_dist=False) | |
| self.log("val/rouge_2", rouge_2_val, sync_dist=False) | |
| self.log("val/rouge_l", rouge_l_val, sync_dist=False) | |
| self.log("val/meteor_score", meteor_val, sync_dist=False) | |
| if self.global_rank == 0: | |
| print(f'[Validation set] BLEU-2: {bleu2_val:.2f} BLEU-4: {bleu4_val:.2f} ROUGE-L: {rouge_l_val:.2f}') | |
| # Reliability head training: report Pearson/Spearman correlation every epoch (val and train) | |
| if self.train_reliability_head_only and val_go_gathered and 'r_pred' in val_go_gathered[0]: | |
| def _compute_accuracy(gathered, dl_idx, log_prefix, display_name): | |
| subset = [d for d in gathered if d.get('dataloader_idx', 0) == dl_idx] | |
| if not subset: | |
| return | |
| all_r_pred, all_r_gt = [], [] | |
| for d in subset: | |
| rp, rg = d['r_pred'], d['r_gt'] | |
| if isinstance(rp, (list, tuple)): | |
| all_r_pred.extend(float(x) for x in rp) | |
| all_r_gt.extend(float(x) for x in rg) | |
| else: | |
| all_r_pred.append(float(rp)) | |
| all_r_gt.append(float(rg)) | |
| if not all_r_pred: | |
| return | |
| pred_arr = np.array(all_r_pred, dtype=np.float64) | |
| gt_arr = np.array(all_r_gt, dtype=np.float64) | |
| binary_mode = getattr(self.blip2, 'reliability_binary', False) | |
| if binary_mode: | |
| gt_pos = np.isclose(gt_arr, 1.0, atol=1e-4) | |
| pred_pos = np.isclose(pred_arr, 1.0, atol=1e-4) | |
| acc = float((gt_pos == pred_pos).mean()) | |
| self.log(f"{log_prefix}/reliability_accuracy", acc, sync_dist=False) | |
| f1_per_class = {} | |
| per_class_lines = [] | |
| for tag, gt_is, pred_is in [("pos", gt_pos, pred_pos), ("neg", ~gt_pos, ~pred_pos)]: | |
| tp = int((gt_is & pred_is).sum()) | |
| fp = int((~gt_is & pred_is).sum()) | |
| fn = int((gt_is & ~pred_is).sum()) | |
| n_cls = int(gt_is.sum()) | |
| recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0 | |
| self.log(f"{log_prefix}/reliability_{tag}_recall", recall, sync_dist=False) | |
| self.log(f"{log_prefix}/reliability_{tag}_precision", precision, sync_dist=False) | |
| self.log(f"{log_prefix}/reliability_{tag}_f1", f1, sync_dist=False) | |
| if n_cls > 0: | |
| f1_per_class[tag] = f1 | |
| per_class_lines.append(f"{tag}: P={precision:.3f} R={recall:.3f} F1={f1:.3f} n={n_cls}") | |
| macro_f1 = float(np.mean(list(f1_per_class.values()))) if f1_per_class else 0.0 | |
| self.log(f"{log_prefix}/reliability_macro_f1", macro_f1, sync_dist=False) | |
| if self.global_rank == 0: | |
| print(f'[{display_name}] Reliability (binary) accuracy: {acc:.4f} (n={len(pred_arr)}), pos-F1: {f1_per_class.get("pos", 0.0):.4f}, macro-F1: {macro_f1:.4f}') | |
| for line in per_class_lines: | |
| print(f' {line}') | |
| return | |
| acc = float(np.isclose(pred_arr, gt_arr, atol=1e-4).mean()) | |
| self.log(f"{log_prefix}/reliability_accuracy", acc, sync_dist=False) | |
| # Per-class precision/recall/F1 + macro-F1 over classes present in GT. | |
| class_values = [-1.0, 0.0, 0.5, 1.0] | |
| f1_per_class = {} | |
| per_class_lines = [] | |
| for cls_val in class_values: | |
| gt_is = np.isclose(gt_arr, cls_val, atol=1e-4) | |
| pred_is = np.isclose(pred_arr, cls_val, atol=1e-4) | |
| tp = int((gt_is & pred_is).sum()) | |
| fp = int((~gt_is & pred_is).sum()) | |
| fn = int((gt_is & ~pred_is).sum()) | |
| n_cls = int(gt_is.sum()) | |
| recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0 | |
| precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0 | |
| f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0 | |
| tag = f"class{cls_val}".replace('-', 'neg').replace('.', 'p') | |
| self.log(f"{log_prefix}/reliability_{tag}_recall", recall, sync_dist=False) | |
| self.log(f"{log_prefix}/reliability_{tag}_precision", precision, sync_dist=False) | |
| self.log(f"{log_prefix}/reliability_{tag}_f1", f1, sync_dist=False) | |
| if n_cls > 0: | |
| f1_per_class[cls_val] = f1 | |
| per_class_lines.append(f"{cls_val}: P={precision:.3f} R={recall:.3f} F1={f1:.3f} n={n_cls}") | |
| macro_f1 = float(np.mean(list(f1_per_class.values()))) if f1_per_class else 0.0 | |
| self.log(f"{log_prefix}/reliability_macro_f1", macro_f1, sync_dist=False) | |
| # Backward-compat: keep old class-1 metric names. | |
| cls1_f1 = f1_per_class.get(1.0, 0.0) | |
| self.log(f"{log_prefix}/reliability_class1_f1", cls1_f1, sync_dist=False) | |
| if self.global_rank == 0: | |
| print(f'[{display_name}] Reliability accuracy: {acc:.4f} (n={len(pred_arr)}), macro-F1: {macro_f1:.4f}') | |
| for line in per_class_lines: | |
| print(f' {line}') | |
| # Validation set (dataloader_idx 0) | |
| _compute_accuracy(val_go_gathered, 0, "val", "Validation set") | |
| # Training set (dataloader_idx 2) | |
| _compute_accuracy(val_go_gathered, 2, "train", "Training set") | |
| self.val_saved_list_for_go = [] | |
| if (self.current_epoch+1) % self.caption_eval_epoch != 0: | |
| return | |
| if self.save_predictions and hasattr(self, '_saved_even_list') and self._saved_even_list: | |
| even_list = self.gather_dict_results(self._saved_even_list) | |
| self._saved_even_list = [] | |
| if self.global_rank == 0: | |
| out_even = os.path.join(self.logger.log_dir, f"val_epoch_end_{self.current_epoch+1}.json") | |
| with open(out_even, "w", encoding="utf-8") as f: | |
| for d in even_list: | |
| f.write(json.dumps({ | |
| 'indices': d.get('indices'), | |
| 'predictions': d.get('predictions'), | |
| }, ensure_ascii=True) + "\n") | |
| # result_list is from test set only (saved_dict_list filled in validation_step when dataloader_idx==1) | |
| result_list = self.gather_dict_results(self.saved_dict_list) | |
| self.saved_dict_list = [] | |
| last_epoch = (self.current_epoch + 1) == self.trainer.max_epochs | |
| # val_go_gathered already gathered at top of this function (every epoch) | |
| if self.global_rank == 0: | |
| # Test/dataset eval: only on last epoch | |
| if last_epoch: | |
| print('Store the result.') | |
| result_list_sorted = sorted(result_list, key=lambda x: x.get('indices', float('inf'))) | |
| result_list = result_list_sorted | |
| run_name = getattr(self.args, 'filename', 'run') | |
| out_dir = os.path.join("saved_results", run_name) | |
| if result_list and 'plm_mean_fp16' in result_list[0]: | |
| os.makedirs(out_dir, exist_ok=True) | |
| plm_stack = torch.stack([d['plm_mean_fp16'] for d in result_list]) | |
| qformer_stack = torch.stack([d['qformer_feats_fp16'] for d in result_list]) | |
| torch.save(plm_stack, os.path.join(out_dir, f"plm_mean_fp16_epoch{self.current_epoch+1}.pt")) | |
| torch.save(qformer_stack, os.path.join(out_dir, f"qformer_feats_fp16_epoch{self.current_epoch+1}.pt")) | |
| if 'llm_last_fp16' in result_list[0]: | |
| llm_last_stack = torch.stack([d['llm_last_fp16'] for d in result_list]) | |
| torch.save(llm_last_stack, os.path.join(out_dir, f"llm_last_fp16_epoch{self.current_epoch+1}.pt")) | |
| saved_names = ['plm_mean_fp16', 'qformer_feats_fp16'] + (['llm_last_fp16'] if 'llm_last_fp16' in result_list[0] else []) | |
| print(f'[Test eval] Saved {", ".join(saved_names)} to {out_dir}') | |
| # print(f'result_list sample: {result_list[0] if result_list else "empty"}') | |
| all_predictions = [i['predictions'] for i in result_list] | |
| all_targets = [i['targets'] for i in result_list] | |
| all_confidences = [i['confidences'] for i in result_list] | |
| all_reliability = [i['reliability'] for i in result_list] | |
| all_indices = [i.get('indices', idx) for idx, i in enumerate(result_list)] | |
| ground_truth_go_dict = self.load_ground_truth_go_from_test_set(result_list) | |
| all_ground_truth_go = [ground_truth_go_dict.get(idx) for idx in all_indices] | |
| for idx, (result_idx, gt_go) in enumerate(zip(all_indices, all_ground_truth_go)): | |
| result_list[idx]['gt_go'] = gt_go | |
| if self.report_go_wang_on_test: | |
| print("Starting GO term extraction from test set predictions and references...") | |
| cache_key = "go_extraction" | |
| os.makedirs("saved_results", exist_ok=True) | |
| run_name = getattr(self.args, 'filename', 'run') | |
| prediction_file = os.path.join("saved_results", f"go_terms_from_predictions_{run_name}_epoch{self.current_epoch+1}.pkl") | |
| reference_file = os.path.join("saved_results", f"go_terms_from_references_epoch{self.current_epoch+1}.pkl") | |
| # Predictions: always extract (do not reuse cache), different models give different results | |
| extracted_go_terms = process_texts_with_api(all_predictions) | |
| # References: reuse cache from saved_results when available | |
| reference_go_terms = load_or_process(reference_file, all_targets, "reference", cache_key) | |
| # Filter all GO terms to molecular_function only using evals/tools/go_files.tsv | |
| mf_go_ids = set() | |
| if os.path.exists(self.go_files_tsv_path): | |
| mf_go_ids = load_mf_go_ids_from_tsv(self.go_files_tsv_path, 'molecular_function') | |
| print(f"Filtering GO terms to molecular_function only: {len(mf_go_ids)} MF terms from {self.go_files_tsv_path}") | |
| else: | |
| print(f"[Warning] go_files.tsv not found at {self.go_files_tsv_path}, skipping MF filter") | |
| gt_go_raw = all_ground_truth_go | |
| ref_go_raw = reference_go_terms | |
| pred_go_raw = extracted_go_terms | |
| if mf_go_ids: | |
| gt_go_raw = filter_go_terms_by_set(gt_go_raw, mf_go_ids) | |
| ref_go_raw = filter_go_terms_by_set(ref_go_raw, mf_go_ids) | |
| pred_go_raw = filter_go_terms_by_set(pred_go_raw, mf_go_ids) | |
| assert len(gt_go_raw) == len(pred_go_raw) == len(ref_go_raw) == len(result_list), ( | |
| f"Length mismatch: gt={len(gt_go_raw)} pred={len(pred_go_raw)} ref={len(ref_go_raw)} result_list={len(result_list)}" | |
| ) | |
| for idx in range(len(result_list)): | |
| result_list[idx]['gt_go'] = gt_go_raw[idx] | |
| result_list[idx]['go_terms_from_predictions'] = pred_go_raw[idx] | |
| result_list[idx]['go_terms_from_references'] = ref_go_raw[idx] | |
| with open(prediction_file, 'wb') as f: | |
| pickle.dump(pred_go_raw, f) | |
| if all_ground_truth_go: | |
| print("Computing ontology metrics (ground truth vs reference GO, ground truth vs prediction GO, MF only)...") | |
| try: | |
| assert len(gt_go_raw) == len(ref_go_raw) == len(pred_go_raw), ( | |
| f"GO list length mismatch: gt={len(gt_go_raw)} ref={len(ref_go_raw)} pred={len(pred_go_raw)}" | |
| ) | |
| gt_go = gt_go_raw | |
| ref_go = ref_go_raw | |
| pred_go = pred_go_raw | |
| ref_wang_similarity, _ = compute_wang_similarity(gt_go, ref_go) | |
| ref_ia_distance = compute_InfoAccretion_distance(gt_go, ref_go, ia_file=self.ia_path, k=2) | |
| ref_jaccard_similarity = compute_jaccard_similarity(gt_go, ref_go) | |
| pred_wang_similarity, _ = compute_wang_similarity(gt_go, pred_go) | |
| pred_ia_distance = compute_InfoAccretion_distance(gt_go, pred_go, ia_file=self.ia_path, k=2) | |
| pred_jaccard_similarity = compute_jaccard_similarity(gt_go, pred_go) | |
| self.log("dataset/go_wang_similarity_reference", ref_wang_similarity, sync_dist=False) | |
| self.log("dataset/go_ia_distance_reference", ref_ia_distance, sync_dist=False) | |
| self.log("dataset/go_jaccard_similarity_reference", ref_jaccard_similarity, sync_dist=False) | |
| self.log("dataset/go_wang_similarity_prediction", pred_wang_similarity, sync_dist=False) | |
| self.log("dataset/go_ia_distance_prediction", pred_ia_distance, sync_dist=False) | |
| self.log("dataset/go_jaccard_similarity_prediction", pred_jaccard_similarity, sync_dist=False) | |
| print(f'Reference vs GT: Wang {ref_wang_similarity:.4f} IA {ref_ia_distance:.4f} Jaccard {ref_jaccard_similarity:.4f}') | |
| print(f'Prediction vs GT: Wang {pred_wang_similarity:.4f} IA {pred_ia_distance:.4f} Jaccard {pred_jaccard_similarity:.4f}') | |
| except Exception as e: | |
| print(f"[Warning] Failed to compute ontology metrics: {e}") | |
| self.save_results(result_list, 'dataset') | |
| log_prefix = 'dataset' | |
| mean_confidences = _mean_conf(all_confidences) | |
| # Note: BLEU/ROUGE/Wang above are on the *test* set (dataloader_idx 1), not validation set. | |
| mean_reliability = _mean_conf(all_reliability) | |
| print('[Inference training subset (dataset)] BLEU/ROUGE/Meteor:') | |
| bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = \ | |
| caption_evaluate(all_predictions, all_targets, self.blip2.llm_tokenizer, self.max_inference_len) | |
| acc = evaluate_exact_match(all_predictions, all_targets) | |
| self.log(f"{log_prefix}/acc", acc, sync_dist=False) | |
| self.log(f"{log_prefix}/bleu2", bleu2, sync_dist=False) | |
| self.log(f"{log_prefix}/bleu4", bleu4, sync_dist=False) | |
| self.log(f"{log_prefix}/rouge_1", rouge_1, sync_dist=False) | |
| self.log(f"{log_prefix}/rouge_2", rouge_2, sync_dist=False) | |
| self.log(f"{log_prefix}/rouge_l", rouge_l, sync_dist=False) | |
| self.log(f"{log_prefix}/meteor_score", meteor_score, sync_dist=False) | |
| self.log(f"{log_prefix}/avg_confidences", mean_confidences, sync_dist=False) | |
| self.log(f"{log_prefix}/avg_reliability", mean_reliability, sync_dist=False) | |
| # print('avg_confidences', mean_confidences) | |
| # print('mean_reliability', mean_reliability) | |
| # Validation set BLEU/ROUGE already computed every epoch at top of this function | |
| # Last-epoch only (and only when --report_go_wang_on_val): validation set GO extraction + Wang vs valid_set.json | |
| if self.report_go_wang_on_val and last_epoch and val_go_gathered: | |
| val_sorted = sorted(val_go_gathered, key=lambda d: d['indices'] if isinstance(d['indices'], (int, float)) else d['indices'][0]) | |
| val_indices = [d['indices'] for d in val_sorted] | |
| val_predictions = [d['predictions'] for d in val_sorted] | |
| go_dict_val = self.load_ground_truth_go_from_valid_set() | |
| gt_go_list = [go_dict_val.get(i, []) for i in val_indices] | |
| try: | |
| pred_go_list = process_texts_with_api(val_predictions) | |
| val_wang_mean, _ = compute_wang_similarity(gt_go_list, pred_go_list) | |
| self.log("val/go_wang_similarity", val_wang_mean, sync_dist=False) | |
| print(f'[Validation set] Last epoch GO Wang similarity (pred vs valid_set.json): {val_wang_mean:.4f}') | |
| except Exception as e: | |
| print(f'[Warning] Validation set GO Wang failed: {e}') | |
| self.val_saved_list_for_go = [] | |
def training_step(self, batch, batch_idx):
    """Run a single training step and return the loss to optimise.

    Steps the learning-rate scheduler (when one is configured), forwards the
    relevant slice of ``batch`` through ``self.blip2`` and logs both returned
    losses together with the current learning rate.

    Args:
        batch: Indexable/sliceable batch; ``batch[1].input_ids.size(0)`` is
            used as the batch size reported to the logger.
        batch_idx: Index of the batch within the epoch (unused here).

    Returns:
        The second loss returned by ``self.blip2`` (logged as
        ``reliability_loss``) when ``self.train_reliability_head_only`` is
        truthy, otherwise the first one (logged as ``loss``).
    """
    lr_scheduler = self.scheduler
    if lr_scheduler:
        # This scheduler is stepped manually with the epoch and global step.
        trainer = self.trainer
        lr_scheduler.step(trainer.current_epoch, trainer.global_step)

    n_samples = batch[1].input_ids.size(0)

    # Reliability-head-only mode feeds the first four batch entries;
    # otherwise everything except the last entry is passed to the model.
    if self.train_reliability_head_only:
        model_inputs = batch[:4]
    else:
        model_inputs = batch[:-1]
    gen_loss, rel_loss = self.blip2(model_inputs, return_pred=False)

    self.log("loss", gen_loss, sync_dist=True, batch_size=n_samples)
    self.log("reliability_loss", rel_loss, batch_size=n_samples, sync_dist=True)
    current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
    self.log("lr", current_lr, batch_size=n_samples, sync_dist=True)

    # Either full-dataset generation loss or subset reliability loss
    return rel_loss if self.train_reliability_head_only else gen_loss
def add_model_specific_args(parent_parser):
    """Register the model's command-line options on ``parent_parser``.

    All options are added to a new argument group named ``"ProtBlip2"``.

    Args:
        parent_parser: An ``argparse`` parser (anything exposing
            ``add_argument_group``) to extend in place.

    Returns:
        The same ``parent_parser`` object, for call chaining.

    Notes:
        * Defined without ``self``; presumably invoked on the class
          (``Blip2Stage2.add_model_specific_args(parser)``) - verify callers.
        * ``--plm_lora_dropout`` and ``--lora_dropout`` are parsed as floats.
          They used to be declared ``type=int``, which made argparse reject
          any fractional value given on the command line even though the
          default itself is 0.1.
        * ``--enbale_gradient_checkpointing`` keeps its historical (misspelled)
          name so existing launch scripts and ``args`` attribute lookups
          continue to work.
        * NOTE(review): the ``--reliability_lr`` help text mentions ``None``
          but the default is 1e-4; behaviour is left untouched here.
    """
    parser = parent_parser.add_argument_group("ProtBlip2")

    # train mode
    parser.add_argument('--save_every_n_epochs', type=int, default=0)

    # Bert
    parser.add_argument('--bert_name', type=str, default='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract')
    parser.add_argument('--cross_attention_freq', type=int, default=2)
    parser.add_argument('--num_query_token', type=int, default=8)

    # OPT
    parser.add_argument('--llm_name', type=str, default="facebook/galactica-1.3b")
    parser.add_argument('--num_beams', type=int, default=3)
    parser.add_argument('--do_sample', action='store_true', default=False)
    parser.add_argument('--max_inference_len', type=int, default=256)
    parser.add_argument('--min_inference_len', type=int, default=1)
    parser.add_argument('--llm_tune', type=str, default='freeze')
    parser.add_argument('--peft_config', type=str, default='')
    parser.add_argument('--peft_dir', type=str, default='')

    ## plm model
    parser.add_argument('--plm_model', type=str, default='facebook/esm2_t30_150M_UR50D')
    parser.add_argument('--plm_tune', type=str, default='freeze')
    parser.add_argument('--plm_lora_r', type=int, default=8)
    parser.add_argument('--plm_lora_alpha', type=int, default=8)
    # Dropout is a probability, so it must be parsed as a float.
    parser.add_argument('--plm_lora_dropout', type=float, default=0.1)

    ## lora config
    parser.add_argument('--lora_r', type=int, default=16)
    parser.add_argument('--lora_alpha', type=int, default=16)
    # Dropout is a probability, so it must be parsed as a float.
    parser.add_argument('--lora_dropout', type=float, default=0.1)
    # Flag name is misspelled on purpose: kept for backward compatibility.
    parser.add_argument('--enbale_gradient_checkpointing', action='store_true', default=False)

    # optimization
    parser.add_argument('--weight_decay', type=float, default=0.05, help='optimizer weight decay')
    parser.add_argument('--init_lr', type=float, default=1e-4, help='optimizer init learning rate')
    parser.add_argument('--min_lr', type=float, default=1e-5, help='optimizer min learning rate')
    parser.add_argument('--warmup_lr', type=float, default=1e-6, help='optimizer warmup learning rate')
    parser.add_argument('--warmup_steps', type=int, default=1000, help='optimizer warmup steps')
    parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='optimizer lr decay rate')
    parser.add_argument('--scheduler', type=str, default='linear_warmup_cosine_lr', help='type of scheduler')

    # Reliability head specific optimization parameters
    parser.add_argument('--reliability_lr', type=float, default=1e-4, help='learning rate for reliability head (if None, uses init_lr)')
    parser.add_argument('--stage1_path', type=str, default='')
    parser.add_argument('--stage2_path', type=str, default='')
    parser.add_argument('--init_checkpoint', type=str, default='')
    parser.add_argument('--caption_eval_epoch', type=int, default=10)
    parser.add_argument('--save_predictions', action='store_true', default=False,
                        help='Save training and validation predictions to JSON files')

    # Encoder selection (automatically inferred from plm_model but can be explicit)
    parser.add_argument('--encoder_type', type=str, default='auto', choices=['auto', 'esm2', 'esmc'], help='Protein encoder type: auto (infer from plm_model), esm2 (HuggingFace ESM2), or esmc (official ESM-C package)')

    # GO term extraction parameters
    parser.add_argument('--report_go_wang_on_test', action='store_true', default=False, help='On test set (dataloader_idx 1): extract GO from predictions, compute Wang/IA/Jaccard, store extracted_go_terms per row')
    parser.add_argument('--ia_path', type=str, default='evals/tools/IA.txt', help='Path to Information Accretion (IA) file')
    parser.add_argument('--report_go_wang_on_val', action='store_true', default=False, help='On last epoch only: extract GO from val predictions (process_texts_with_api) and compute Wang vs valid_set.json')
    parser.add_argument('--go_files_tsv_path', type=str, default='evals/tools/go_files.tsv', help='Path to go_files.tsv (go_id, aspect) to filter GO terms to molecular_function only')
    return parent_parser
def evaluate_exact_match(predictions, targets):
    """Return the exact-match accuracy of ``predictions`` as a percentage.

    A prediction counts as correct when it equals its target after stripping
    leading/trailing whitespace from both. Pairs are formed with ``zip``, so
    surplus items in the longer sequence are never compared, while the
    denominator is always ``len(predictions)``.

    Args:
        predictions: Sequence of predicted strings.
        targets: Sequence of reference strings, aligned with ``predictions``.

    Returns:
        Accuracy in the range [0, 100], rounded to two decimals. An empty
        ``predictions`` sequence yields ``0.0`` instead of raising
        ``ZeroDivisionError``.
    """
    total = len(predictions)
    if total == 0:
        # Nothing to score: avoid dividing by zero.
        return 0.0
    hits = sum(
        prediction.strip() == target.strip()
        for prediction, target in zip(predictions, targets)
    )
    return round(hits / total * 100, 2)