# ProtTale-demo / model/blip2_stage2.py
# Last commit 431e05a by Mulah: "Sync new model/ code (binary reliability head)"
import os
import torch
from tqdm import tqdm
from model.blip2_opt import Blip2OPT
import pytorch_lightning as pl
from torch import optim
from lavis.common.optims import LinearWarmupCosineLRScheduler, LinearWarmupStepLRScheduler
import json
import ast
import pickle
from evals.tools.InfoAccretion import compute_InfoAccretion_distance
from evals.tools.wang_similarity import compute_wang_similarity
from evals.tools.jaccard_similarity import compute_jaccard_similarity
from evals.tools.extraction import process_texts_with_api
import numpy as np
import torch.distributed as dist
from typing import Any, Dict
from model.help_funcs import (
caption_evaluate,
AttrDict,
_mean_conf,
_json_default,
load_or_process,
load_mf_go_ids_from_tsv,
filter_go_terms_by_set,
build_joint_nonempty_mask,
filter_parallel_by_mask,
)
def _batch_to_device(x, device):
"""Move batch (dict/tensor/list) to device; skip non-tensor values (e.g. lists of ints from ESM-C)."""
if torch.is_tensor(x):
return x.to(device)
if isinstance(x, dict):
return {k: _batch_to_device(v, device) for k, v in x.items()}
if isinstance(x, (list, tuple)):
return type(x)(_batch_to_device(v, device) if torch.is_tensor(v) else v for v in x)
return x
class Blip2Stage2(pl.LightningModule):
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    """Lightning checkpoint hook, overridden as a no-op.

    The ``checkpoint`` dict is left exactly as Lightning built it; nothing
    extra is added or removed here.
    """
    pass
def __init__(self, args):
    """Build the stage-2 model (a ``Blip2OPT`` wrapper) from parsed arguments.

    Args:
        args: argparse-style namespace / ``AttrDict``, or a plain ``dict``
            (converted to ``AttrDict``). Attributes read directly (so they
            must exist): caption_eval_epoch, do_sample, num_beams,
            max_inference_len, min_inference_len, llm_tune, plm_model,
            llm_name, bert_name, num_query_token, cross_attention_freq,
            plm_tune, peft_dir. Everything else falls back to the defaults
            passed to ``getattr`` below.

    Raises:
        ValueError: ``encoder_type`` is 'esm2' / 'esmc' but ``plm_model`` does
            not belong to that encoder family.
        NotImplementedError: ``llm_name`` is not a Galactica model.
    """
    super().__init__()
    if isinstance(args, dict):
        args = AttrDict(**args)
    self.args = args
    self.caption_eval_epoch = args.caption_eval_epoch
    self.do_sample = args.do_sample
    self.num_beams = args.num_beams
    self.max_inference_len = args.max_inference_len
    self.min_inference_len = args.min_inference_len
    self.llm_tune = args.llm_tune
    # GO term extraction on test set (dataloader_idx 1)
    self.report_go_wang_on_test = getattr(args, 'report_go_wang_on_test', False)
    self.ia_path = getattr(args, 'ia_path', 'evals/tools/IA.txt')
    self.test_set_path = getattr(args, 'test_set_path', '') or os.path.join(getattr(args, 'root', 'data/SwissProtV3'), 'test_set.json')
    self.valid_set_path = getattr(args, 'valid_set_path', '') or os.path.join(getattr(args, 'root', 'data/SwissProtV3'), 'valid_set.json')
    self.go_files_tsv_path = getattr(args, 'go_files_tsv_path', 'evals/tools/go_files.tsv')
    # On last epoch: extract GO from val predictions and compute Wang vs valid_set.json (default off)
    self.report_go_wang_on_val = getattr(args, 'report_go_wang_on_val', False)
    # Prediction collection and saving parameters.
    # NOTE(review): this boolean instance attribute shadows the
    # ``save_predictions`` method of this class, so that method cannot be
    # called on instances. Renaming one of the two would fix it, but that
    # also requires updating the flag's reader in on_validation_epoch_end.
    self.save_predictions = getattr(args, 'save_predictions', False)
    self.inference_on_training_data = getattr(args, 'inference_on_training_data', False)
    self.train_reliability_head_only = getattr(args, 'train_reliability_head_only', False)
    # Validate encoder_type and plm_model consistency ('auto' skips the check).
    encoder_type = getattr(args, 'encoder_type', 'auto')
    if encoder_type != 'auto':
        if encoder_type == 'esm2' and not args.plm_model.startswith('facebook/esm2'):
            raise ValueError(f"encoder_type='{encoder_type}' but plm_model='{args.plm_model}' does not start with 'facebook/esm2'")
        elif encoder_type == 'esmc' and not args.plm_model.startswith('esmc_'):
            raise ValueError(f"encoder_type='{encoder_type}' but plm_model='{args.plm_model}' does not start with 'esmc_'")
    if 'galactica' not in args.llm_name:
        raise NotImplementedError(
            f"Only Galactica LLMs are supported by Blip2Stage2, got llm_name='{args.llm_name}'"
        )
    self.blip2 = Blip2OPT(args.bert_name,
                          args.num_query_token,
                          args.cross_attention_freq,
                          args.plm_model,
                          args.plm_tune,
                          args.llm_name,
                          args.llm_tune,
                          args.peft_dir,
                          args)
    # Default training: freeze every parameter whose name contains
    # 'reliability_head' or 'plm' (the protein encoder, incl. any LoRA on it).
    # Nothing else is frozen here.
    # NOTE(review): ln_layer / Q-Former are NOT frozen by this loop; confirm
    # whether Blip2OPT freezes them itself.
    if not self.inference_on_training_data and not self.train_reliability_head_only:
        for name, param in self.blip2.named_parameters():
            if 'reliability_head' in name or 'plm' in name:
                param.requires_grad = False
    self.save_hyperparameters(args)
def load_from_stage1_checkpoint(self, path):
    """Load stage-1 weights from a Lightning checkpoint into ``self.blip2``.

    Only entries of ``ckpt['state_dict']`` whose key contains
    ``'blip2qformer.'`` are used. The part after the first occurrence of that
    prefix becomes the key in ``self.blip2``. Other keys are skipped instead
    of raising ``IndexError``. Loading is non-strict
    (``strict=False``), so missing/unexpected keys are tolerated.

    Args:
        path: Path to the stage-1 checkpoint.

    Returns:
        ``self``, so the call can be chained.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(path, map_location='cpu')
    prefix = 'blip2qformer.'
    state_dict = {
        k.split(prefix, 1)[1]: v
        for k, v in ckpt['state_dict'].items()
        if prefix in k
    }
    if not state_dict:
        print(f"[Warning] No '{prefix}' keys found in checkpoint {path}; nothing loaded.")
    self.blip2.load_state_dict(state_dict, strict=False)
    return self
def freeze_for_reliability_finetune(self):
    """Leave only the reliability head trainable.

    Parameters whose qualified name contains ``'reliability_head'`` get
    ``requires_grad=True``; every other parameter gets ``requires_grad=False``.
    Meant to be called after loading a checkpoint when training the
    reliability head alone (``train_reliability_head_only``).
    """
    head_tag = 'reliability_head'
    for qualified_name, weight in self.named_parameters():
        weight.requires_grad = head_tag in qualified_name
@torch.no_grad()
def run_inference_on_training_subset(
    self,
    dm,
    output_path,
    min_go=2,
    sample_size=2000,
    seed=42,
    reliability_label_zero=False,
):
    """Generate text for a training subset and write reliability-labelled rows.

    The subset comes from ``dm.get_inference_training_dataloader`` (rows with
    >= ``min_go`` GO terms, up to ``sample_size`` rows, sampled with ``seed``).
    For every generated sample, the matching row of
    ``dm.train_dataset.data_path`` is copied with:
      * column 1 replaced by the predicted text;
      * column 2 replaced by the reliability label r:
        - ``reliability_label_zero=True``: r = 0.0 for every row and no GO
          extraction is done;
        - otherwise: GO terms are extracted from the predictions with
          ``process_texts_with_api``, and r is the per-sample Wang similarity
          against the GO list in column 3.

    For generation the module is put in eval mode, and moved to CUDA if it is
    on CPU and a GPU is available. The previous train/eval mode is restored
    afterwards, even if generation fails.

    Args:
        dm: Data module providing the dataloader and ``train_dataset.data_path``.
        output_path: File written with one JSON list per line.
        min_go, sample_size, seed: Forwarded to the subset dataloader.
        reliability_label_zero: Write r=0 instead of Wang scores.

    Returns:
        ``output_path``.
    """
    dataloader = dm.get_inference_training_dataloader(min_go=min_go, sample_size=sample_size, seed=seed)
    # Remember the caller's mode so it can be restored once generation is done.
    was_training = self.training
    self.eval()
    device = next(self.parameters()).device
    if device.type == 'cpu' and torch.cuda.is_available():
        self.to('cuda')
        device = next(self.parameters()).device
    idx_to_pred = []
    n_batches = len(dataloader)
    print(f"Inference on training subset: {n_batches} batches (sample_size={sample_size})", flush=True)
    try:
        for batch in tqdm(dataloader, total=n_batches, desc="train_inference", unit="batch"):
            prot_tokens, prompt_tokens, r_tensor, target_dict = batch
            prot_tokens = _batch_to_device(prot_tokens, device)
            # Objects exposing .to() are moved directly; otherwise assume a
            # dict-like of tensors and move it value by value.
            if hasattr(prompt_tokens, 'to'):
                prompt_tokens = prompt_tokens.to(device)
            else:
                prompt_tokens = type(prompt_tokens)({k: v.to(device) for k, v in prompt_tokens.items()})
            r_tensor = r_tensor.to(device)
            samples = {'prot_batch': prot_tokens, 'prompt_batch': prompt_tokens, 'reliability': r_tensor}
            pred_texts, _, _, _, _ = self.blip2.generate(
                samples,
                do_sample=self.do_sample,
                num_beams=self.num_beams,
                max_length=self.max_inference_len,
                min_length=self.min_inference_len,
            )
            indices = target_dict['indices']
            if hasattr(indices, 'tolist'):
                indices = indices.tolist()
            for i, idx in enumerate(indices):
                idx_to_pred.append((idx, pred_texts[i]))
    finally:
        self.train(was_training)
    # Order by dataset index (stable; ties keep generation order).
    idx_to_pred.sort(key=lambda x: x[0])
    sorted_indices = [idx for idx, _ in idx_to_pred]
    pred_texts_ordered = [text for _, text in idx_to_pred]
    # train_rows is built from non-blank lines only, so idx is assumed to
    # index non-blank lines of the training file (verify against the dataset).
    with open(dm.train_dataset.data_path, 'r', encoding='utf-8') as f:
        train_lines = [line.strip() for line in f if line.strip()]
    train_rows = [json.loads(line) for line in train_lines]
    if reliability_label_zero:
        per_scores = [0.0] * len(sorted_indices)
    else:
        gt_go_list = []
        for idx in sorted_indices:
            g = train_rows[idx][3]
            # Column 3 holds the GO terms, either as a list or as its repr string.
            go_list = ast.literal_eval(g) if isinstance(g, str) else g
            gt_go_list.append(go_list)
        print(f"Extracting GO terms from {len(pred_texts_ordered)} predictions (API calls)...", flush=True)
        predicted_go_terms = process_texts_with_api(pred_texts_ordered)
        _, per_scores = compute_wang_similarity(gt_go_list, predicted_go_terms)
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        for i, idx in enumerate(sorted_indices):
            row = list(train_rows[idx])
            row[1] = pred_texts_ordered[i]  # predicted text
            row[2] = per_scores[i]
            f.write(json.dumps(row, ensure_ascii=True) + '\n')
    r_desc = "r=0" if reliability_label_zero else "r (wang score)"
    print(f"Saved {len(sorted_indices)} rows with {r_desc} to {output_path}", flush=True)
    return output_path
@torch.no_grad()
def run_inference_on_validation_set(self, dm, output_path, reliability_label_zero=False):
    """Generate text for the full validation set and write reliability-labelled rows.

    Rows are read from ``dm.valid_set_path`` and written in the same format as
    ``run_inference_on_training_subset``:
      * column 1 is replaced by the prediction;
      * column 2 is replaced by r:
        - ``reliability_label_zero=True``: r = 0.0 for every row, no GO
          extraction;
        - otherwise: the per-sample Wang similarity between the GO terms
          extracted from the prediction and the GO list in column 3.

    The module's previous train/eval mode is restored after generation, even
    if generation fails.

    Args:
        dm: Data module providing ``get_validation_inference_dataloader`` and
            ``valid_set_path``.
        output_path: File written with one JSON list per line.
        reliability_label_zero: Write r=0 instead of Wang scores.

    Returns:
        ``output_path``.

    Raises:
        FileNotFoundError: ``dm.valid_set_path`` is missing or does not exist.
    """
    dataloader = dm.get_validation_inference_dataloader()
    valid_path = getattr(dm, 'valid_set_path', None)
    if not valid_path or not os.path.exists(valid_path):
        raise FileNotFoundError(f"Validation set path not found: {valid_path}")
    # Remember the caller's mode so it can be restored once generation is done.
    was_training = self.training
    self.eval()
    device = next(self.parameters()).device
    if device.type == 'cpu' and torch.cuda.is_available():
        self.to('cuda')
        device = next(self.parameters()).device
    idx_to_pred = []
    n_batches = len(dataloader)
    print(f"Inference on validation set: {n_batches} batches", flush=True)
    try:
        for batch in tqdm(dataloader, total=n_batches, desc="val_inference", unit="batch"):
            prot_tokens, prompt_tokens, r_tensor, target_dict = batch
            prot_tokens = _batch_to_device(prot_tokens, device)
            # Objects exposing .to() are moved directly; otherwise assume a
            # dict-like of tensors and move it value by value.
            if hasattr(prompt_tokens, 'to'):
                prompt_tokens = prompt_tokens.to(device)
            else:
                prompt_tokens = type(prompt_tokens)({k: v.to(device) for k, v in prompt_tokens.items()})
            r_tensor = r_tensor.to(device)
            samples = {'prot_batch': prot_tokens, 'prompt_batch': prompt_tokens, 'reliability': r_tensor}
            pred_texts, _, _, _, _ = self.blip2.generate(
                samples,
                do_sample=self.do_sample,
                num_beams=self.num_beams,
                max_length=self.max_inference_len,
                min_length=self.min_inference_len,
            )
            indices = target_dict['indices']
            if hasattr(indices, 'tolist'):
                indices = indices.tolist()
            for i, idx in enumerate(indices):
                idx_to_pred.append((idx, pred_texts[i]))
    finally:
        self.train(was_training)
    # Order by dataset index (stable; ties keep generation order).
    idx_to_pred.sort(key=lambda x: x[0])
    sorted_indices = [idx for idx, _ in idx_to_pred]
    pred_texts_ordered = [text for _, text in idx_to_pred]
    # valid_rows is built from non-blank lines only, so idx is assumed to
    # index non-blank lines of the validation file (verify against the dataset).
    with open(valid_path, 'r', encoding='utf-8') as f:
        valid_lines = [line.strip() for line in f if line.strip()]
    valid_rows = [json.loads(line) for line in valid_lines]
    if reliability_label_zero:
        per_scores = [0.0] * len(sorted_indices)
    else:
        gt_go_list = []
        for idx in sorted_indices:
            g = valid_rows[idx][3]
            # Column 3 holds the GO terms, either as a list or as its repr string.
            go_list = ast.literal_eval(g) if isinstance(g, str) else g
            gt_go_list.append(go_list)
        print(f"Extracting GO terms from {len(pred_texts_ordered)} predictions (API calls)...", flush=True)
        predicted_go_terms = process_texts_with_api(pred_texts_ordered)
        _, per_scores = compute_wang_similarity(gt_go_list, predicted_go_terms)
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        for i, idx in enumerate(sorted_indices):
            row = list(valid_rows[idx])
            row[1] = pred_texts_ordered[i]  # predicted text
            row[2] = per_scores[i]
            f.write(json.dumps(row, ensure_ascii=True) + '\n')
    r_desc = "r=0" if reliability_label_zero else "r (wang score)"
    print(f"Saved {len(sorted_indices)} validation rows with {r_desc} to {output_path}", flush=True)
    return output_path
def configure_optimizers(self):
    """Create the AdamW optimizer and, optionally, a LAVIS LR scheduler.

    * Reliability-head-only training: a single param group with the
      trainable ``reliability_head`` parameters at ``reliability_lr``
      (falling back to ``init_lr`` when it is None).
    * Otherwise: two groups built from trainable parameters: everything
      except the reliability head at ``init_lr``, and the reliability head
      at ``reliability_lr`` (same fallback).

    The scheduler is stored on ``self.scheduler``; it is stepped manually in
    ``training_step``. Only the optimizer is returned to Lightning.

    NOTE(review): the LAVIS warmup/cosine/step schedulers appear to write a
    single lr into every param group on each step, which would override the
    separate ``reliability_lr``. Confirm this is the intended behavior.

    Raises:
        NotImplementedError: unknown ``args.scheduler`` value.
    """
    cfg = self.args
    # Materialize the train dataloader so its length is known for warmup.
    self.trainer.fit_loop.setup_data()
    warmup_steps = min(len(self.trainer.train_dataloader), cfg.warmup_steps)
    head_lr = cfg.init_lr if cfg.reliability_lr is None else cfg.reliability_lr
    if self.train_reliability_head_only:
        head_params = [p for n, p in self.named_parameters() if 'reliability_head' in n and p.requires_grad]
        optimizer = optim.AdamW(head_params, lr=head_lr, weight_decay=cfg.weight_decay)
    else:
        trainable = [(n, p) for n, p in self.named_parameters() if p.requires_grad]
        body_params = [p for n, p in trainable if 'reliability_head' not in n]
        head_params = [p for n, p in trainable if 'reliability_head' in n]
        optimizer = optim.AdamW([
            {'params': body_params, 'lr': cfg.init_lr, 'weight_decay': cfg.weight_decay},
            {'params': head_params, 'lr': head_lr, 'weight_decay': cfg.weight_decay},
        ])
    kind = cfg.scheduler
    if kind == 'linear_warmup_cosine_lr':
        self.scheduler = LinearWarmupCosineLRScheduler(
            optimizer, cfg.max_epochs, cfg.min_lr, cfg.init_lr, warmup_steps, cfg.warmup_lr)
    elif kind == 'linear_warmup_step_lr':
        self.scheduler = LinearWarmupStepLRScheduler(
            optimizer, cfg.max_epochs, cfg.min_lr, cfg.init_lr, cfg.lr_decay_rate, cfg.warmup_lr, warmup_steps)
    elif kind == 'None':
        self.scheduler = None
    else:
        raise NotImplementedError()
    return optimizer
def save_predictions(self, predictions, targets, q_types=None, log_prefix=''):
    """Write prediction/target pairs as JSON lines into the logger's log dir.

    The file is ``<log_prefix>_predictions.txt`` when ``log_prefix`` is
    non-empty, else ``predictions.txt``. Each line is
    ``{'prediction', 'target'}``, plus ``'q_type'`` when ``q_types`` is given.

    NOTE(review): ``__init__`` assigns
    ``self.save_predictions = getattr(args, 'save_predictions', False)``.
    That instance attribute shadows this method, so calling
    ``self.save_predictions(...)`` on an instance hits a bool, not this
    function. Consider renaming one of the two.
    """
    assert len(predictions) == len(targets)
    file_name = f'{log_prefix}_predictions.txt' if log_prefix else 'predictions.txt'
    out_path = os.path.join(self.logger.log_dir, file_name)
    with open(out_path, 'w', encoding='utf8') as fh:
        if q_types is None:
            records = ({'prediction': p, 'target': t} for p, t in zip(predictions, targets))
        else:
            records = ({'prediction': p, 'target': t, 'q_type': q}
                       for p, t, q in zip(predictions, targets, q_types))
        for record in records:
            fh.write(json.dumps(record, ensure_ascii=True) + '\n')
def on_validation_epoch_start(self) -> None:
    """Reset every per-epoch accumulator to a fresh empty list."""
    for buffer_name in (
        'saved_dict_list',
        'prediction_list0',
        'target_list0',
        'prediction_list1',
        'target_list1',
        '_saved_even_list',
        'val_saved_list_for_go',
    ):
        setattr(self, buffer_name, [])
@torch.no_grad()
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Lightning validation step, dispatched on the parity of ``dataloader_idx``.

    * Even indices (0 = validation set; 2 = training subset when training the
      reliability head): forward pass with losses and predictions; results
      go to ``self.val_saved_list_for_go``.
    * Odd indices (1 = test set): every ``caption_eval_epoch`` epochs, run
      ``generate`` and append predictions, confidences, reliability
      scores/probabilities and per-sample embeddings to
      ``self.saved_dict_list``.
    """
    if dataloader_idx % 2 == 0:
        self._validation_loss_step(batch, dataloader_idx)
    else:
        self._validation_generate_step(batch)

@staticmethod
def _to_plain_list(values):
    """Convert a tensor (any rank) or an iterable into a Python list."""
    out = values.cpu().tolist() if torch.is_tensor(values) else list(values)
    # A 0-d tensor's tolist() is a scalar; wrap it so callers always get a list.
    return out if isinstance(out, list) else [out]

def _validation_loss_step(self, batch, dataloader_idx):
    """Even dataloaders: log losses and collect predictions (and r labels)."""
    text_batch = batch[1]
    batch_size = text_batch.input_ids.shape[0]
    blip_batch = batch[:4]
    loss, r_loss, pred_texts, r_pred = self.blip2(blip_batch, return_pred=True)
    idx_list = batch[4].detach().cpu().tolist()
    if not isinstance(idx_list, list):
        idx_list = [idx_list]
    saved_dict = {'indices': idx_list, 'predictions': pred_texts}
    if self.train_reliability_head_only:
        r_pred_list = self._to_plain_list(r_pred)
        r_gt_list = self._to_plain_list(batch[3])
        saved_dict['r_pred'] = r_pred_list
        saved_dict['r_gt'] = r_gt_list
        saved_dict['dataloader_idx'] = [dataloader_idx] * len(r_pred_list)
    self.val_saved_list_for_go.append(saved_dict)
    self.log(f"dataloader{dataloader_idx}/val_loss", loss,
             on_step=False, on_epoch=True, prog_bar=True,
             sync_dist=True, batch_size=batch_size)
    self.log(f"dataloader{dataloader_idx}/reliability_loss", r_loss,
             on_step=False, on_epoch=True, prog_bar=False,
             sync_dist=True, batch_size=batch_size)
    if self.train_reliability_head_only:
        self.log("val/reliability_loss", r_loss,
                 on_step=False, on_epoch=True, prog_bar=False,
                 sync_dist=True, batch_size=batch_size)

def _validation_generate_step(self, batch):
    """Odd dataloaders (test set): generate and collect results on eval epochs."""
    # Test set: collect predictions for BLEU/ROUGE and (if --report_go_wang_on_test) GO extraction + Wang
    if (self.current_epoch + 1) % self.caption_eval_epoch != 0:
        return
    prot_batch, prompt_batch, r_tensor, target_dict = batch
    samples = {'prot_batch': prot_batch, 'prompt_batch': prompt_batch, 'reliability': r_tensor}
    predictions, r_pred, avg_conf, emb_out, r_probs = self.blip2.generate(
        samples,
        do_sample=self.do_sample,
        num_beams=self.num_beams,
        max_length=self.max_inference_len,
        min_length=self.min_inference_len
    )
    target_dict['predictions'] = predictions
    target_dict['confidences'] = [round(float(x), 4) for x in avg_conf]
    target_dict['reliability'] = [round(float(x), 4) for x in r_pred]
    if getattr(self.blip2, 'reliability_binary', False):
        # r_probs shape [B, 2]: index 0 = negative, 1 = positive (r==1).
        target_dict['reliability_prob_class1'] = [round(float(x), 4) for x in r_probs[:, 1]]
    else:
        from model.blip2_opt import RELIABILITY_VAL_TO_IDX
        for cls_val, idx in RELIABILITY_VAL_TO_IDX.items():
            target_dict[f'reliability_prob_class{cls_val}'] = [round(float(x), 4) for x in r_probs[:, idx]]
    # Keep independent CPU copies of the per-sample embeddings. GPU views would
    # (a) pin the whole batch's activations for the rest of the epoch and
    # (b) be unpickled by all_gather_object onto each sender's own CUDA device,
    # so the later torch.stack on rank 0 would mix devices.
    n_samples = len(predictions)
    for key in ('plm_mean_fp16', 'qformer_feats_fp16', 'llm_last_fp16'):
        target_dict[key] = [emb_out[key][i].detach().to('cpu', copy=True) for i in range(n_samples)]
    self.saved_dict_list.append(target_dict)
def gather_dict_results(self, dict_list):
    """Gather per-batch result dicts from all ranks and split them per sample.

    Each input dict maps a field name to a per-batch value: a list/tuple with
    one entry per sample, or a single value for a one-sample batch. After
    gathering, each field is concatenated across batches/ranks and the data
    is re-split into one dict per sample. The sample count is taken from the
    ``'predictions'`` field.

    When ``torch.distributed`` is initialized, every calling rank takes part
    in the collective even if its own ``dict_list`` is empty. This way ranks
    with uneven result counts do not deadlock. Without an initialized process
    group (single process), no collective is performed.

    Args:
        dict_list: This rank's list of per-batch dicts (may be empty).

    Returns:
        list[dict]: One dict per sample; empty when no rank has results.
    """
    if dist.is_available() and dist.is_initialized():
        per_rank = [None] * dist.get_world_size()
        dist.all_gather_object(per_rank, dict_list)
        dict_list = [d for rank_list in per_rank for d in (rank_list or [])]
    if not dict_list:
        return []
    keys = list(dict_list[0].keys())

    def _flatten_field(v):
        """Expand batch field to a list (handles scalar tolist() / single int / str)."""
        if isinstance(v, (list, tuple)):
            return list(v)
        return [v]

    gathered = {key: [x for d in dict_list for x in _flatten_field(d[key])] for key in keys}
    return [{k: gathered[k][i] for k in keys} for i in range(len(gathered['predictions']))]
def load_ground_truth_go_from_test_set(self, result_list):
    """Read ground-truth GO terms from ``self.test_set_path``.

    Each non-blank line is a JSON row; column 3 holds the GO terms, either as
    a list or as a string starting with ``'['`` that is parsed with
    ``ast.literal_eval`` (falling back to ``[]``). Rows with fewer than 4
    columns get no entry. Rows that fail to parse get ``[]`` and a warning
    is printed.

    NOTE(review): keys are 0-based line numbers that also count blank lines,
    whereas the run_inference_* methods index non-blank lines only. Confirm
    this matches the dataset's ``indices``.

    Args:
        result_list: Unused; kept for API compatibility.

    Returns:
        dict[int, list]: line index -> GO terms; ``{}`` if the file is missing.
    """
    path = self.test_set_path
    if not (path and os.path.exists(path)):
        print(f"[Warning] Test set path not provided or not found: {self.test_set_path}")
        return {}

    def _go_terms_of(row):
        # None means "row too short, record nothing"; any exception propagates.
        if len(row) < 4:
            return None
        column = row[3]
        if isinstance(column, list):
            return column
        if isinstance(column, str) and column.startswith('['):
            try:
                return ast.literal_eval(column)
            except Exception:
                return []
        return []

    go_dict = {}
    with open(path, 'r', encoding='utf-8') as fh:
        for line_no, raw in enumerate(fh):
            text = raw.strip()
            if not text:
                continue
            try:
                terms = _go_terms_of(json.loads(text))
            except Exception as e:
                print(f"[Warning] Error parsing line {line_no}: {e}")
                go_dict[line_no] = []
                continue
            if terms is not None:
                go_dict[line_no] = terms
    return go_dict
def load_ground_truth_text_from_valid_set(self):
    """Map line index -> stripped target text (column 1) of ``valid_set_path``.

    Blank lines are skipped but still counted in the index. Rows with fewer
    than two columns, or rows that fail to parse, map to ``''``. Returns
    ``{}`` when the path is unset or missing.
    """
    src = self.valid_set_path
    if not src or not os.path.exists(src):
        return {}
    texts = {}
    with open(src, 'r', encoding='utf-8') as fh:
        for line_no, raw in enumerate(fh):
            stripped = raw.strip()
            if not stripped:
                continue
            try:
                row = json.loads(stripped)
                texts[line_no] = str(row[1]).strip() if len(row) >= 2 else ''
            except Exception:
                texts[line_no] = ''
    return texts
def load_ground_truth_go_from_valid_set(self):
    """Map line index -> GO-term list (column 3) of ``valid_set_path``.

    Same column-3 parsing as the test-set loader (a list, or a ``'['``-string
    parsed with ``ast.literal_eval`` falling back to ``[]``), but errors are
    silent: unparsable rows map to ``[]``. Rows shorter than 4 columns get no
    entry. Blank lines are skipped but still counted in the index. Returns
    ``{}`` when the path is unset or missing.
    """
    source = self.valid_set_path
    if not source or not os.path.exists(source):
        return {}
    result = {}
    with open(source, 'r', encoding='utf-8') as fh:
        for line_no, raw in enumerate(fh):
            payload = raw.strip()
            if not payload:
                continue
            try:
                row = json.loads(payload)
                if len(row) < 4:
                    continue
                column = row[3]
                if isinstance(column, list):
                    terms = column
                elif isinstance(column, str) and column.startswith('['):
                    try:
                        terms = ast.literal_eval(column)
                    except Exception:
                        terms = []
                else:
                    terms = []
                result[line_no] = terms
            except Exception:
                result[line_no] = []
    return result
def save_results(self, dict_list, log_prefix=""):
    """Write each result dict as one JSON line into the logger's log dir.

    The file is ``<log_prefix>_predictions.txt`` (or ``predictions.txt`` when
    no prefix is given). Values that ``json`` cannot serialize natively are
    passed to ``_json_default``.
    """
    file_name = f'{log_prefix}_predictions.txt' if log_prefix else 'predictions.txt'
    out_path = os.path.join(self.logger.log_dir, file_name)
    with open(out_path, 'w', encoding='utf8') as fh:
        fh.writelines(
            json.dumps(record, ensure_ascii=True, default=_json_default) + '\n'
            for record in dict_list
        )
def on_validation_epoch_end(self):
    """Aggregate validation/test results at the end of a validation epoch.

    Every epoch (skipped during Lightning's sanity check):
      * gathers the even-dataloader records collected in validation_step and,
        when targets exist in ``valid_set_path``, logs val/* BLEU, ROUGE,
        METEOR and exact-match;
      * if ``train_reliability_head_only``, logs reliability accuracy,
        per-class precision/recall/F1 and macro-F1 for dataloader 0 ("val")
        and dataloader 2 ("train").

    Only on epochs where ``(current_epoch + 1) % caption_eval_epoch == 0``:
      * optionally dumps ``_saved_even_list`` (when ``save_predictions`` is set);
      * gathers the test-set records (odd dataloader). On rank 0 in the last
        epoch it:
          - saves the stacked embeddings under ``saved_results/<filename>/``;
          - attaches ground-truth GO terms;
          - optionally extracts GO terms and logs Wang / IA / Jaccard metrics
            (``report_go_wang_on_test``);
          - writes ``dataset_predictions.txt`` in the log dir;
          - logs dataset/* caption metrics;
      * optionally logs val/go_wang_similarity in the last epoch
        (``report_go_wang_on_val``).
    """
    if getattr(self.trainer, "sanity_checking", False):
        return
    # Validation set BLEU/ROUGE: compute every epoch (val_go_gathered is always collected)
    # NOTE(review): gather_dict_results performs a collective; this call is
    # skipped on ranks whose val_saved_list_for_go is empty, which could hang
    # the other ranks if data is ever uneven.
    val_go_gathered = self.gather_dict_results(self.val_saved_list_for_go) if self.val_saved_list_for_go else []
    if val_go_gathered:
        # Sort by dataset index so targets can be looked up by index.
        val_sorted = sorted(val_go_gathered, key=lambda d: d['indices'] if isinstance(d['indices'], (int, float)) else d['indices'][0])
        val_indices = [d['indices'] for d in val_sorted]
        val_predictions = [d['predictions'] for d in val_sorted]
        # NOTE(review): records from every even dataloader are used here,
        # including dataloader 2 (training subset) in reliability-head mode.
        # Their indices are still looked up in valid_set_path; confirm intended.
        text_dict_val = self.load_ground_truth_text_from_valid_set()
        val_targets = [text_dict_val.get(i, '') for i in val_indices]
        if val_targets and any(t for t in val_targets):
            bleu2_val, bleu4_val, rouge_1_val, rouge_2_val, rouge_l_val, meteor_val = caption_evaluate(
                val_predictions, val_targets, self.blip2.llm_tokenizer, self.max_inference_len,
                verbose=(self.global_rank == 0))
            # NOTE(review): evaluate_exact_match is not among this file's
            # imports; confirm it is defined elsewhere in this module.
            acc_val = evaluate_exact_match(val_predictions, val_targets)
            self.log("val/acc", acc_val, sync_dist=False)
            self.log("val/bleu2", bleu2_val, sync_dist=False)
            self.log("val/bleu4", bleu4_val, sync_dist=False)
            self.log("val/rouge_1", rouge_1_val, sync_dist=False)
            self.log("val/rouge_2", rouge_2_val, sync_dist=False)
            self.log("val/rouge_l", rouge_l_val, sync_dist=False)
            self.log("val/meteor_score", meteor_val, sync_dist=False)
            if self.global_rank == 0:
                print(f'[Validation set] BLEU-2: {bleu2_val:.2f} BLEU-4: {bleu4_val:.2f} ROUGE-L: {rouge_l_val:.2f}')
    # Reliability head training: report accuracy and per-class precision/recall/F1 every epoch (val and train)
    if self.train_reliability_head_only and val_go_gathered and 'r_pred' in val_go_gathered[0]:
        def _compute_accuracy(gathered, dl_idx, log_prefix, display_name):
            # Log reliability metrics under "<log_prefix>/" for the records
            # of dataloader `dl_idx`; values are compared with atol=1e-4.
            subset = [d for d in gathered if d.get('dataloader_idx', 0) == dl_idx]
            if not subset:
                return
            all_r_pred, all_r_gt = [], []
            for d in subset:
                rp, rg = d['r_pred'], d['r_gt']
                if isinstance(rp, (list, tuple)):
                    all_r_pred.extend(float(x) for x in rp)
                    all_r_gt.extend(float(x) for x in rg)
                else:
                    all_r_pred.append(float(rp))
                    all_r_gt.append(float(rg))
            if not all_r_pred:
                return
            pred_arr = np.array(all_r_pred, dtype=np.float64)
            gt_arr = np.array(all_r_gt, dtype=np.float64)
            binary_mode = getattr(self.blip2, 'reliability_binary', False)
            if binary_mode:
                # Binary head: a value within 1e-4 of 1.0 is the positive class.
                gt_pos = np.isclose(gt_arr, 1.0, atol=1e-4)
                pred_pos = np.isclose(pred_arr, 1.0, atol=1e-4)
                acc = float((gt_pos == pred_pos).mean())
                self.log(f"{log_prefix}/reliability_accuracy", acc, sync_dist=False)
                f1_per_class = {}
                per_class_lines = []
                for tag, gt_is, pred_is in [("pos", gt_pos, pred_pos), ("neg", ~gt_pos, ~pred_pos)]:
                    tp = int((gt_is & pred_is).sum())
                    fp = int((~gt_is & pred_is).sum())
                    fn = int((gt_is & ~pred_is).sum())
                    n_cls = int(gt_is.sum())
                    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
                    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
                    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
                    self.log(f"{log_prefix}/reliability_{tag}_recall", recall, sync_dist=False)
                    self.log(f"{log_prefix}/reliability_{tag}_precision", precision, sync_dist=False)
                    self.log(f"{log_prefix}/reliability_{tag}_f1", f1, sync_dist=False)
                    # Macro-F1 only averages classes that occur in the ground truth.
                    if n_cls > 0:
                        f1_per_class[tag] = f1
                    per_class_lines.append(f"{tag}: P={precision:.3f} R={recall:.3f} F1={f1:.3f} n={n_cls}")
                macro_f1 = float(np.mean(list(f1_per_class.values()))) if f1_per_class else 0.0
                self.log(f"{log_prefix}/reliability_macro_f1", macro_f1, sync_dist=False)
                if self.global_rank == 0:
                    print(f'[{display_name}] Reliability (binary) accuracy: {acc:.4f} (n={len(pred_arr)}), pos-F1: {f1_per_class.get("pos", 0.0):.4f}, macro-F1: {macro_f1:.4f}')
                    for line in per_class_lines:
                        print(f' {line}')
                return
            # Multi-class head: exact-value accuracy over r in {-1, 0, 0.5, 1}.
            acc = float(np.isclose(pred_arr, gt_arr, atol=1e-4).mean())
            self.log(f"{log_prefix}/reliability_accuracy", acc, sync_dist=False)
            # Per-class precision/recall/F1 + macro-F1 over classes present in GT.
            class_values = [-1.0, 0.0, 0.5, 1.0]
            f1_per_class = {}
            per_class_lines = []
            for cls_val in class_values:
                gt_is = np.isclose(gt_arr, cls_val, atol=1e-4)
                pred_is = np.isclose(pred_arr, cls_val, atol=1e-4)
                tp = int((gt_is & pred_is).sum())
                fp = int((~gt_is & pred_is).sum())
                fn = int((gt_is & ~pred_is).sum())
                n_cls = int(gt_is.sum())
                recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
                precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
                f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0
                # e.g. -1.0 -> "classneg1p0", 0.5 -> "class0p5" (metric-name safe).
                tag = f"class{cls_val}".replace('-', 'neg').replace('.', 'p')
                self.log(f"{log_prefix}/reliability_{tag}_recall", recall, sync_dist=False)
                self.log(f"{log_prefix}/reliability_{tag}_precision", precision, sync_dist=False)
                self.log(f"{log_prefix}/reliability_{tag}_f1", f1, sync_dist=False)
                if n_cls > 0:
                    f1_per_class[cls_val] = f1
                per_class_lines.append(f"{cls_val}: P={precision:.3f} R={recall:.3f} F1={f1:.3f} n={n_cls}")
            macro_f1 = float(np.mean(list(f1_per_class.values()))) if f1_per_class else 0.0
            self.log(f"{log_prefix}/reliability_macro_f1", macro_f1, sync_dist=False)
            # Backward-compat: keep old class-1 metric names.
            cls1_f1 = f1_per_class.get(1.0, 0.0)
            self.log(f"{log_prefix}/reliability_class1_f1", cls1_f1, sync_dist=False)
            if self.global_rank == 0:
                print(f'[{display_name}] Reliability accuracy: {acc:.4f} (n={len(pred_arr)}), macro-F1: {macro_f1:.4f}')
                for line in per_class_lines:
                    print(f' {line}')
        # Validation set (dataloader_idx 0)
        _compute_accuracy(val_go_gathered, 0, "val", "Validation set")
        # Training set (dataloader_idx 2)
        _compute_accuracy(val_go_gathered, 2, "train", "Training set")
    self.val_saved_list_for_go = []
    # Everything below runs only on caption-evaluation epochs.
    if (self.current_epoch+1) % self.caption_eval_epoch != 0:
        return
    if self.save_predictions and hasattr(self, '_saved_even_list') and self._saved_even_list:
        even_list = self.gather_dict_results(self._saved_even_list)
        self._saved_even_list = []
        if self.global_rank == 0:
            out_even = os.path.join(self.logger.log_dir, f"val_epoch_end_{self.current_epoch+1}.json")
            with open(out_even, "w", encoding="utf-8") as f:
                for d in even_list:
                    f.write(json.dumps({
                        'indices': d.get('indices'),
                        'predictions': d.get('predictions'),
                    }, ensure_ascii=True) + "\n")
    # result_list is from test set only (saved_dict_list filled in validation_step when dataloader_idx==1)
    result_list = self.gather_dict_results(self.saved_dict_list)
    self.saved_dict_list = []
    last_epoch = (self.current_epoch + 1) == self.trainer.max_epochs
    # val_go_gathered already gathered at top of this function (every epoch)
    if self.global_rank == 0:
        # Test/dataset eval: only on last epoch
        if last_epoch:
            print('Store the result.')
            result_list_sorted = sorted(result_list, key=lambda x: x.get('indices', float('inf')))
            result_list = result_list_sorted
            run_name = getattr(self.args, 'filename', 'run')
            out_dir = os.path.join("saved_results", run_name)
            # Save the per-sample embeddings collected in validation_step as stacked tensors.
            if result_list and 'plm_mean_fp16' in result_list[0]:
                os.makedirs(out_dir, exist_ok=True)
                plm_stack = torch.stack([d['plm_mean_fp16'] for d in result_list])
                qformer_stack = torch.stack([d['qformer_feats_fp16'] for d in result_list])
                torch.save(plm_stack, os.path.join(out_dir, f"plm_mean_fp16_epoch{self.current_epoch+1}.pt"))
                torch.save(qformer_stack, os.path.join(out_dir, f"qformer_feats_fp16_epoch{self.current_epoch+1}.pt"))
                if 'llm_last_fp16' in result_list[0]:
                    llm_last_stack = torch.stack([d['llm_last_fp16'] for d in result_list])
                    torch.save(llm_last_stack, os.path.join(out_dir, f"llm_last_fp16_epoch{self.current_epoch+1}.pt"))
                saved_names = ['plm_mean_fp16', 'qformer_feats_fp16'] + (['llm_last_fp16'] if 'llm_last_fp16' in result_list[0] else [])
                print(f'[Test eval] Saved {", ".join(saved_names)} to {out_dir}')
            # print(f'result_list sample: {result_list[0] if result_list else "empty"}')
            all_predictions = [i['predictions'] for i in result_list]
            all_targets = [i['targets'] for i in result_list]
            all_confidences = [i['confidences'] for i in result_list]
            all_reliability = [i['reliability'] for i in result_list]
            all_indices = [i.get('indices', idx) for idx, i in enumerate(result_list)]
            # Attach ground-truth GO terms (None when an index has no entry in the test file).
            ground_truth_go_dict = self.load_ground_truth_go_from_test_set(result_list)
            all_ground_truth_go = [ground_truth_go_dict.get(idx) for idx in all_indices]
            for idx, (result_idx, gt_go) in enumerate(zip(all_indices, all_ground_truth_go)):
                result_list[idx]['gt_go'] = gt_go
            if self.report_go_wang_on_test:
                print("Starting GO term extraction from test set predictions and references...")
                cache_key = "go_extraction"
                os.makedirs("saved_results", exist_ok=True)
                run_name = getattr(self.args, 'filename', 'run')
                prediction_file = os.path.join("saved_results", f"go_terms_from_predictions_{run_name}_epoch{self.current_epoch+1}.pkl")
                # Reference cache name has no run_name, so it is shared across runs.
                reference_file = os.path.join("saved_results", f"go_terms_from_references_epoch{self.current_epoch+1}.pkl")
                # Predictions: always extract (do not reuse cache), different models give different results
                extracted_go_terms = process_texts_with_api(all_predictions)
                # References: reuse cache from saved_results when available
                reference_go_terms = load_or_process(reference_file, all_targets, "reference", cache_key)
                # Filter all GO terms to molecular_function only using evals/tools/go_files.tsv
                mf_go_ids = set()
                if os.path.exists(self.go_files_tsv_path):
                    mf_go_ids = load_mf_go_ids_from_tsv(self.go_files_tsv_path, 'molecular_function')
                    print(f"Filtering GO terms to molecular_function only: {len(mf_go_ids)} MF terms from {self.go_files_tsv_path}")
                else:
                    print(f"[Warning] go_files.tsv not found at {self.go_files_tsv_path}, skipping MF filter")
                gt_go_raw = all_ground_truth_go
                ref_go_raw = reference_go_terms
                pred_go_raw = extracted_go_terms
                if mf_go_ids:
                    gt_go_raw = filter_go_terms_by_set(gt_go_raw, mf_go_ids)
                    ref_go_raw = filter_go_terms_by_set(ref_go_raw, mf_go_ids)
                    pred_go_raw = filter_go_terms_by_set(pred_go_raw, mf_go_ids)
                assert len(gt_go_raw) == len(pred_go_raw) == len(ref_go_raw) == len(result_list), (
                    f"Length mismatch: gt={len(gt_go_raw)} pred={len(pred_go_raw)} ref={len(ref_go_raw)} result_list={len(result_list)}"
                )
                for idx in range(len(result_list)):
                    result_list[idx]['gt_go'] = gt_go_raw[idx]
                    result_list[idx]['go_terms_from_predictions'] = pred_go_raw[idx]
                    result_list[idx]['go_terms_from_references'] = ref_go_raw[idx]
                with open(prediction_file, 'wb') as f:
                    pickle.dump(pred_go_raw, f)
                # NOTE(review): this is truthy whenever result_list is non-empty,
                # even if every ground-truth entry is None.
                if all_ground_truth_go:
                    print("Computing ontology metrics (ground truth vs reference GO, ground truth vs prediction GO, MF only)...")
                    try:
                        assert len(gt_go_raw) == len(ref_go_raw) == len(pred_go_raw), (
                            f"GO list length mismatch: gt={len(gt_go_raw)} ref={len(ref_go_raw)} pred={len(pred_go_raw)}"
                        )
                        gt_go = gt_go_raw
                        ref_go = ref_go_raw
                        pred_go = pred_go_raw
                        ref_wang_similarity, _ = compute_wang_similarity(gt_go, ref_go)
                        ref_ia_distance = compute_InfoAccretion_distance(gt_go, ref_go, ia_file=self.ia_path, k=2)
                        ref_jaccard_similarity = compute_jaccard_similarity(gt_go, ref_go)
                        pred_wang_similarity, _ = compute_wang_similarity(gt_go, pred_go)
                        pred_ia_distance = compute_InfoAccretion_distance(gt_go, pred_go, ia_file=self.ia_path, k=2)
                        pred_jaccard_similarity = compute_jaccard_similarity(gt_go, pred_go)
                        self.log("dataset/go_wang_similarity_reference", ref_wang_similarity, sync_dist=False)
                        self.log("dataset/go_ia_distance_reference", ref_ia_distance, sync_dist=False)
                        self.log("dataset/go_jaccard_similarity_reference", ref_jaccard_similarity, sync_dist=False)
                        self.log("dataset/go_wang_similarity_prediction", pred_wang_similarity, sync_dist=False)
                        self.log("dataset/go_ia_distance_prediction", pred_ia_distance, sync_dist=False)
                        self.log("dataset/go_jaccard_similarity_prediction", pred_jaccard_similarity, sync_dist=False)
                        print(f'Reference vs GT: Wang {ref_wang_similarity:.4f} IA {ref_ia_distance:.4f} Jaccard {ref_jaccard_similarity:.4f}')
                        print(f'Prediction vs GT: Wang {pred_wang_similarity:.4f} IA {pred_ia_distance:.4f} Jaccard {pred_jaccard_similarity:.4f}')
                    except Exception as e:
                        # Best-effort: metric failures must not abort the epoch.
                        print(f"[Warning] Failed to compute ontology metrics: {e}")
            self.save_results(result_list, 'dataset')
            log_prefix = 'dataset'
            mean_confidences = _mean_conf(all_confidences)
            # Note: BLEU/ROUGE/Wang above are on the *test* set (dataloader_idx 1), not validation set.
            mean_reliability = _mean_conf(all_reliability)
            # NOTE(review): the label printed below says "Inference training
            # subset", but result_list holds test-set (dataloader 1) records.
            print('[Inference training subset (dataset)] BLEU/ROUGE/Meteor:')
            bleu2, bleu4, rouge_1, rouge_2, rouge_l, meteor_score = \
                caption_evaluate(all_predictions, all_targets, self.blip2.llm_tokenizer, self.max_inference_len)
            acc = evaluate_exact_match(all_predictions, all_targets)
            self.log(f"{log_prefix}/acc", acc, sync_dist=False)
            self.log(f"{log_prefix}/bleu2", bleu2, sync_dist=False)
            self.log(f"{log_prefix}/bleu4", bleu4, sync_dist=False)
            self.log(f"{log_prefix}/rouge_1", rouge_1, sync_dist=False)
            self.log(f"{log_prefix}/rouge_2", rouge_2, sync_dist=False)
            self.log(f"{log_prefix}/rouge_l", rouge_l, sync_dist=False)
            self.log(f"{log_prefix}/meteor_score", meteor_score, sync_dist=False)
            self.log(f"{log_prefix}/avg_confidences", mean_confidences, sync_dist=False)
            self.log(f"{log_prefix}/avg_reliability", mean_reliability, sync_dist=False)
            # print('avg_confidences', mean_confidences)
            # print('mean_reliability', mean_reliability)
        # Validation set BLEU/ROUGE already computed every epoch at top of this function
        # Last-epoch only (and only when --report_go_wang_on_val): validation set GO extraction + Wang vs valid_set.json
        # Runs on rank 0 only; makes GO-extraction API calls for every validation prediction.
        if self.report_go_wang_on_val and last_epoch and val_go_gathered:
            val_sorted = sorted(val_go_gathered, key=lambda d: d['indices'] if isinstance(d['indices'], (int, float)) else d['indices'][0])
            val_indices = [d['indices'] for d in val_sorted]
            val_predictions = [d['predictions'] for d in val_sorted]
            go_dict_val = self.load_ground_truth_go_from_valid_set()
            gt_go_list = [go_dict_val.get(i, []) for i in val_indices]
            try:
                pred_go_list = process_texts_with_api(val_predictions)
                val_wang_mean, _ = compute_wang_similarity(gt_go_list, pred_go_list)
                self.log("val/go_wang_similarity", val_wang_mean, sync_dist=False)
                print(f'[Validation set] Last epoch GO Wang similarity (pred vs valid_set.json): {val_wang_mean:.4f}')
            except Exception as e:
                print(f'[Warning] Validation set GO Wang failed: {e}')
    self.val_saved_list_for_go = []
def training_step(self, batch, batch_idx):
    """Run one Lightning training step and return the loss to optimise.

    Behaviour visible here:
      * if ``self.scheduler`` is truthy, it is stepped with the trainer's
        current epoch and global step before the forward pass;
      * the batch size is read from ``batch[1].input_ids.size(0)``
        (assumes ``batch[1]`` is a tokenized batch exposing ``input_ids`` --
        TODO confirm against the datamodule);
      * in reliability-head-only mode only ``batch[:4]`` is passed to
        ``self.blip2``; otherwise every element except the last is passed;
      * ``loss``, ``reliability_loss`` and the first param group's ``lr``
        are logged with ``sync_dist=True``.

    Returns:
        The reliability loss when ``self.train_reliability_head_only`` is
        set, otherwise the generation loss returned by ``self.blip2``.
    """
    sched = self.scheduler
    if sched:
        sched.step(self.trainer.current_epoch, self.trainer.global_step)

    n_samples = batch[1].input_ids.size(0)
    head_only = self.train_reliability_head_only
    # Head-only training feeds a fixed 4-element prefix of the batch;
    # full training drops only the trailing element.
    if head_only:
        model_inputs = batch[:4]
    else:
        model_inputs = batch[:-1]

    gen_loss, rel_loss = self.blip2(model_inputs, return_pred=False)

    self.log("loss", gen_loss, sync_dist=True, batch_size=n_samples)
    self.log("reliability_loss", rel_loss, batch_size=n_samples, sync_dist=True)
    current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
    self.log("lr", current_lr, batch_size=n_samples, sync_dist=True)

    # Back-propagate either the subset reliability loss or the full generation loss.
    return rel_loss if head_only else gen_loss
@staticmethod
def add_model_specific_args(parent_parser):
parser = parent_parser.add_argument_group("ProtBlip2")
# train mode
parser.add_argument('--save_every_n_epochs', type=int, default=0)
# Bert
parser.add_argument('--bert_name', type=str, default='microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract')
parser.add_argument('--cross_attention_freq', type=int, default=2)
parser.add_argument('--num_query_token', type=int, default=8)
# OPT
parser.add_argument('--llm_name', type=str, default="facebook/galactica-1.3b")
parser.add_argument('--num_beams', type=int, default=3)
parser.add_argument('--do_sample', action='store_true', default=False)
parser.add_argument('--max_inference_len', type=int, default=256)
parser.add_argument('--min_inference_len', type=int, default=1)
parser.add_argument('--llm_tune', type=str, default='freeze')
parser.add_argument('--peft_config', type=str, default='')
parser.add_argument('--peft_dir', type=str, default='')
## plm model
parser.add_argument('--plm_model', type=str, default='facebook/esm2_t30_150M_UR50D')
parser.add_argument('--plm_tune', type=str, default='freeze')
parser.add_argument('--plm_lora_r', type=int, default=8)
parser.add_argument('--plm_lora_alpha', type=int, default=8)
parser.add_argument('--plm_lora_dropout', type=int, default=0.1)
## lora config
parser.add_argument('--lora_r', type=int, default=16)
parser.add_argument('--lora_alpha', type=int, default=16)
parser.add_argument('--lora_dropout', type=int, default=0.1)
parser.add_argument('--enbale_gradient_checkpointing', action='store_true', default=False)
# optimization
parser.add_argument('--weight_decay', type=float, default=0.05, help='optimizer weight decay')
parser.add_argument('--init_lr', type=float, default=1e-4, help='optimizer init learning rate')
parser.add_argument('--min_lr', type=float, default=1e-5, help='optimizer min learning rate')
parser.add_argument('--warmup_lr', type=float, default=1e-6, help='optimizer warmup learning rate')
parser.add_argument('--warmup_steps', type=int, default=1000, help='optimizer warmup steps')
parser.add_argument('--lr_decay_rate', type=float, default=0.9, help='optimizer lr decay rate')
parser.add_argument('--scheduler', type=str, default='linear_warmup_cosine_lr', help='type of scheduler')
# Reliability head specific optimization parameters
parser.add_argument('--reliability_lr', type=float, default=1e-4, help='learning rate for reliability head (if None, uses init_lr)')
parser.add_argument('--stage1_path', type=str, default='')
parser.add_argument('--stage2_path', type=str, default='')
parser.add_argument('--init_checkpoint', type=str, default='')
parser.add_argument('--caption_eval_epoch', type=int, default=10)
parser.add_argument('--save_predictions', action='store_true', default=False,
help='Save training and validation predictions to JSON files')
# Encoder selection (automatically inferred from plm_model but can be explicit)
parser.add_argument('--encoder_type', type=str, default='auto', choices=['auto', 'esm2', 'esmc'], help='Protein encoder type: auto (infer from plm_model), esm2 (HuggingFace ESM2), or esmc (official ESM-C package)')
# GO term extraction parameters
parser.add_argument('--report_go_wang_on_test', action='store_true', default=False, help='On test set (dataloader_idx 1): extract GO from predictions, compute Wang/IA/Jaccard, store extracted_go_terms per row')
parser.add_argument('--ia_path', type=str, default='evals/tools/IA.txt', help='Path to Information Accretion (IA) file')
parser.add_argument('--report_go_wang_on_val', action='store_true', default=False, help='On last epoch only: extract GO from val predictions (process_texts_with_api) and compute Wang vs valid_set.json')
parser.add_argument('--go_files_tsv_path', type=str, default='evals/tools/go_files.tsv', help='Path to go_files.tsv (go_id, aspect) to filter GO terms to molecular_function only')
return parent_parser
def evaluate_exact_match(predictions, targets):
    """Return the exact-match accuracy of ``predictions`` against ``targets``.

    A prediction is a hit when it equals its paired target after stripping
    leading/trailing whitespace from both. Pairs are formed with ``zip``, so
    surplus items in the longer sequence are ignored. The denominator is
    always ``len(predictions)``.

    Args:
        predictions: sequence of predicted strings.
        targets: sequence of reference strings aligned with ``predictions``.

    Returns:
        Accuracy as a percentage in [0, 100], rounded to 2 decimals. An empty
        ``predictions`` returns 0.0 (previously this raised ZeroDivisionError).
    """
    total = len(predictions)
    if total == 0:
        return 0.0
    hits = sum(pred.strip() == tgt.strip() for pred, tgt in zip(predictions, targets))
    return round(hits / total * 100, 2)