| """ |
| Evaluator for biological language models on synthetic sequence tasks. |
| Supports masked language models (ESM-2, NT) and autoregressive models. |
| """ |
|
|
| import re |
| import logging |
| from typing import List, Dict, Optional |
| import numpy as np |
| from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmTokenizer |
| import torch |
| from difflib import SequenceMatcher |
|
|
| from .tasks import BioTask |
|
|
# Module-level logger: BioEvaluator reports model/tokenizer loading, per-task
# evaluation failures, and unknown evaluation metrics through it.
logger = logging.getLogger(__name__)
|
|
|
|
class BioEvaluator:
    """Evaluate biological language models on sequence tasks.

    Models and tokenizers are loaded with ``transformers`` and cached per
    ``model_path``. Each task is routed to a scorer by its
    ``evaluation_metric``; scorers return a float, and most return a neutral
    0.5 when the task carries no reference answer to compare against.
    """

    # Patterns are compiled once and shared by every instance.
    _PROTEIN_RE = re.compile(r"[ACDEFGHIKLMNPQRSTVWY]+")  # 20 standard amino acids
    _DNA_RE = re.compile(r"[ACGT]+")
    _RNA_RE = re.compile(r"[ACGU]+")
    _STRUCTURE_RE = re.compile(r"[\(\)\.]+")  # dot-bracket characters
    # Optionally negative integer or decimal, e.g. "-3" or "2.75".
    _NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")
    # Special-token markup removed from model output before sequence extraction.
    _MARKUP_TOKENS = ("<mask>", "[MASK]", "<s>", "</s>", "[CLS]", "[SEP]")

    def __init__(self, device: str = "auto", max_length: int = 1024):
        """Create an evaluator.

        Args:
            device: Torch device string; ``"auto"`` picks ``"cuda"`` when it is
                available and ``"cpu"`` otherwise.
            max_length: Maximum number of tokens kept when encoding inputs;
                longer inputs are truncated.
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.max_length = max_length
        # Keyed by model_path so repeated evaluations reuse loaded weights.
        self._model_cache: Dict[str, torch.nn.Module] = {}
        self._tokenizer_cache: Dict[str, object] = {}

    def _load_model(self, model_path: str):
        """Load the model stored at *model_path*, caching it after the first call.

        Loader classes are tried in order -- masked-LM head, causal-LM head,
        then the bare ``AutoModel`` -- so both masked and autoregressive
        checkpoints get an LM head (and therefore ``logits``) when one exists.
        The model is moved to ``self.device`` and switched to eval mode.

        Raises:
            Exception: the error raised by the last loader if every loader fails.
        """
        if model_path in self._model_cache:
            return self._model_cache[model_path]

        from transformers import AutoModel, AutoModelForCausalLM

        logger.info(f"Loading model from {model_path}")
        errors = []
        for loader in (AutoModelForMaskedLM, AutoModelForCausalLM, AutoModel):
            try:
                model = loader.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )
                break
            except Exception as err:  # deliberate: fall back to the next head type
                logger.warning(f"{loader.__name__} could not load {model_path}: {err}")
                errors.append(err)
        else:
            raise errors[-1]

        model = model.to(self.device)
        model.eval()
        self._model_cache[model_path] = model
        return model

    def _load_tokenizer(self, model_path: str):
        """Load the tokenizer stored at *model_path*, caching it after the first call."""
        if model_path not in self._tokenizer_cache:
            logger.info(f"Loading tokenizer from {model_path}")
            self._tokenizer_cache[model_path] = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True,
            )
        return self._tokenizer_cache[model_path]

    def evaluate_model(
        self,
        model_path: str,
        tasks: List["BioTask"],
    ) -> Dict[str, float]:
        """Evaluate a model on a list of tasks. Returns task_id -> score mapping.

        A task whose evaluation raises is logged (with traceback) and scored
        0.0, so one failing task does not abort the whole run.
        """
        model = self._load_model(model_path)
        tokenizer = self._load_tokenizer(model_path)

        results: Dict[str, float] = {}
        for task in tasks:
            try:
                results[task.task_id] = self._evaluate_single_task(model, tokenizer, task)
            except Exception as e:
                logger.exception(f"Error evaluating task {task.task_id}: {e}")
                results[task.task_id] = 0.0
        return results

    def _evaluate_single_task(
        self,
        model: torch.nn.Module,
        tokenizer,
        task: "BioTask",
    ) -> float:
        """Dispatch *task* to the scorer named by ``task.evaluation_metric``.

        Unknown metrics are logged and scored with sequence similarity.
        """
        scorers = {
            "sequence_identity": self._eval_sequence_identity,
            "sequence_similarity": self._eval_sequence_similarity,
            "contains_substring": self._eval_contains_substring,
            "exact_match": self._eval_exact_match,
            "perplexity": self._eval_perplexity,
            "rna_structure_similarity": self._eval_rna_structure,
        }
        scorer = scorers.get(task.evaluation_metric)
        if scorer is None:
            logger.warning(
                f"Unknown metric: {task.evaluation_metric}, defaulting to sequence similarity"
            )
            scorer = self._eval_sequence_similarity
        return scorer(model, tokenizer, task)

    @staticmethod
    def _build_prompt(task: "BioTask") -> str:
        """Return ``task.prompt``, followed by a space and ``task.context`` if set."""
        if task.context:
            return f"{task.prompt} {task.context}"
        return task.prompt

    def _get_model_output(self, model, tokenizer, prompt: str) -> str:
        """Fill masks if *prompt* contains a mask marker, otherwise generate a continuation."""
        if "<mask>" in prompt or "[MASK]" in prompt:
            return self._predict_masked(model, tokenizer, prompt)
        return self._generate_sequence(model, tokenizer, prompt)

    @staticmethod
    def _join_tokens(tokenizer, tokens: List[str]) -> str:
        """Turn *tokens* back into text.

        ``EsmTokenizer`` inherits the generic ``convert_tokens_to_string``,
        which joins tokens with spaces ("M K T"); ``_extract_sequence`` would
        then only see one-residue runs, so ESM tokens are concatenated
        directly. Other tokenizers use their own ``convert_tokens_to_string``.
        """
        if isinstance(tokenizer, EsmTokenizer):
            return "".join(tokens)
        return tokenizer.convert_tokens_to_string(tokens)

    def _predict_masked(self, model, tokenizer, prompt: str) -> str:
        """Replace each mask in *prompt* with the model's argmax token prediction.

        Both "<mask>" and "[MASK]" are mapped to the tokenizer's own mask token
        before encoding, and masks are located in the *encoded* ids, so any
        special tokens the tokenizer adds are accounted for. The returned text
        omits CLS/BOS/EOS/SEP/PAD framing tokens. If the tokenizer has no mask
        token, or no mask survives encoding/truncation, *prompt* is returned
        unchanged.
        """
        mask_id = tokenizer.mask_token_id
        if mask_id is None:
            return prompt
        mask_token = tokenizer.mask_token
        normalized = prompt.replace("[MASK]", mask_token).replace("<mask>", mask_token)

        input_ids = tokenizer.encode(
            normalized, return_tensors="pt", max_length=self.max_length, truncation=True
        ).to(self.device)
        mask_positions = (input_ids[0] == mask_id).nonzero(as_tuple=True)[0].tolist()
        if not mask_positions:
            return prompt

        with torch.no_grad():
            logits = model(input_ids).logits

        # Logits share positions with input_ids, so no offset is needed.
        predictions = logits[0, mask_positions].argmax(dim=-1).tolist()
        ids = input_ids[0].tolist()
        for pos, token_id in zip(mask_positions, predictions):
            ids[pos] = token_id

        # Drop tokenizer framing (like tokenize() would), but never a filled mask.
        framing_ids = {
            tokenizer.cls_token_id,
            tokenizer.bos_token_id,
            tokenizer.eos_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
        }
        filled = set(mask_positions)
        kept = [tid for pos, tid in enumerate(ids) if pos in filled or tid not in framing_ids]
        return self._join_tokens(tokenizer, tokenizer.convert_ids_to_tokens(kept))

    def _generate_sequence(self, model, tokenizer, prompt: str, max_new_tokens: int = 50) -> str:
        """Greedily extend *prompt* by up to *max_new_tokens* tokens.

        Each step feeds the whole sequence through the model and appends the
        argmax of the last position; generation stops early on
        ``tokenizer.eos_token_id``. Returns prompt plus continuation with
        special tokens removed.

        NOTE(review): for masked-LM checkpoints ``encode`` may already append an
        EOS token, in which case the last-position argmax likely reproduces it
        and the loop stops after one step -- confirm MLM generation is intended.
        """
        input_ids = tokenizer.encode(
            prompt, return_tensors="pt", max_length=self.max_length, truncation=True
        ).to(self.device)
        generated = input_ids.clone()
        eos_id = tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_new_tokens):
                next_logits = model(generated).logits[0, -1, :]
                next_token_id = torch.argmax(next_logits).item()
                next_token = torch.tensor([[next_token_id]], device=self.device)
                generated = torch.cat([generated, next_token], dim=1)
                if next_token_id == eos_id:
                    break

        ids = generated[0].tolist()
        if isinstance(tokenizer, EsmTokenizer):
            tokens = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=True)
            return self._join_tokens(tokenizer, tokens)
        return tokenizer.decode(ids, skip_special_tokens=True)

    @staticmethod
    def _positional_identity(predicted: str, expected: str) -> float:
        """Fraction of aligned positions that agree, over the longer length.

        Returns 0.0 if either string is empty.
        """
        if not predicted or not expected:
            return 0.0
        matches = sum(1 for a, b in zip(predicted, expected) if a == b)
        return matches / max(len(predicted), len(expected))

    def _eval_sequence_identity(self, model, tokenizer, task: "BioTask") -> float:
        """Score position-by-position identity between output and expected sequence."""
        if task.expected_answer is None:
            return 0.5
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_seq = self._extract_sequence(output, task.task_type)
        expected = task.expected_answer.strip().upper()
        return self._positional_identity(output_seq, expected)

    def _eval_sequence_similarity(self, model, tokenizer, task: "BioTask") -> float:
        """Score output vs expected sequence with ``difflib.SequenceMatcher.ratio``."""
        if task.expected_answer is None:
            return 0.5
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_seq = self._extract_sequence(output, task.task_type)
        expected = task.expected_answer.strip().upper()
        if not output_seq or not expected:
            return 0.0
        return SequenceMatcher(None, output_seq, expected).ratio()

    @staticmethod
    def _motif_score(output_seq: str, expected: str) -> float:
        """1.0 if *expected* occurs in *output_seq*, 0.3 if any 3-mer of it does, else 0.0.

        Empty inputs score 0.0.
        NOTE(review): over a 4-letter alphabet almost any output contains some
        3-mer of the motif, so the 0.3 partial credit is easy to earn.
        """
        if not expected or not output_seq:
            return 0.0
        if expected in output_seq:
            return 1.0
        if any(expected[i:i + 3] in output_seq for i in range(len(expected) - 2)):
            return 0.3
        return 0.0

    def _eval_contains_substring(self, model, tokenizer, task: "BioTask") -> float:
        """Check whether the extracted output sequence contains the expected motif."""
        if task.expected_answer is None:
            return 0.5
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_seq = self._extract_sequence(output, task.task_type)
        return self._motif_score(output_seq, task.expected_answer.strip().upper())

    def _eval_exact_match(self, model, tokenizer, task: "BioTask") -> float:
        """Score 1.0 for an exact answer, 0.5 for a numeric answer within 1, else 0.0."""
        if task.expected_answer is None:
            return 0.5
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_answer = self._extract_answer(output)
        expected = task.expected_answer.strip()

        if output_answer == expected:
            return 1.0
        try:
            if abs(float(output_answer) - float(expected)) < 1:
                return 0.5
        except (ValueError, TypeError):
            pass  # non-numeric answers get no partial credit
        return 0.0

    def _eval_perplexity(self, model, tokenizer, task: "BioTask") -> float:
        """Map the model's loss on ``task.target`` to ``1 / (1 + ppl / 10)``.

        NOTE(review): with a masked-LM head and ``labels=input_ids`` no tokens
        are masked, so this is presumably a reconstruction loss rather than a
        true (pseudo-)perplexity -- confirm the metric is what is intended.

        Raises:
            ValueError: if the model returns no loss (e.g. a head without LM loss).
        """
        if task.target is None:
            return 0.5

        input_ids = tokenizer.encode(
            task.target, return_tensors="pt", max_length=self.max_length, truncation=True
        ).to(self.device)
        with torch.no_grad():
            loss = model(input_ids, labels=input_ids).loss
        if loss is None:
            raise ValueError("model returned no loss; cannot compute perplexity")

        # Exponentiate in float32: the model runs in bfloat16.
        perplexity = torch.exp(loss.float()).item()
        return 1.0 / (1.0 + perplexity / 10.0)

    @staticmethod
    def _is_balanced(structure: str) -> bool:
        """True if every ')' closes an earlier '(' and none are left open."""
        depth = 0
        for ch in structure:
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    def _eval_rna_structure(self, model, tokenizer, task: "BioTask") -> float:
        """
        Evaluate RNA structure prediction.
        Uses simplified dot-bracket notation comparison.

        Without a reference answer, a balanced predicted structure scores 0.5
        and an unbalanced one 0.0; otherwise positional identity is returned.
        """
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        predicted = self._extract_structure(output)
        if not predicted:
            return 0.0

        if task.expected_answer is None:
            return 0.5 if self._is_balanced(predicted) else 0.0

        return self._positional_identity(predicted, task.expected_answer.strip())

    def _extract_sequence(self, text: str, seq_type: str) -> str:
        """Return the longest run of *seq_type* alphabet characters in *text*.

        Markup tokens are stripped and text is upper-cased first. For "dna",
        U is converted to T (and for "rna", T to U) *before* matching, so an
        output in the other nucleic-acid alphabet is still recognised. With no
        match, or an unknown *seq_type*, the normalized text is returned.
        """
        for marker in self._MARKUP_TOKENS:
            text = text.replace(marker, "")
        upper = text.upper()

        if seq_type == "protein":
            pattern = self._PROTEIN_RE
        elif seq_type == "dna":
            upper = upper.replace("U", "T")
            pattern = self._DNA_RE
        elif seq_type == "rna":
            upper = upper.replace("T", "U")
            pattern = self._RNA_RE
        else:
            return upper.strip()

        matches = pattern.findall(upper)
        if matches:
            return max(matches, key=len)
        return upper.strip()

    def _extract_answer(self, text: str) -> str:
        """Return the last number in *text* (decimals kept intact), else its last non-blank line."""
        numbers = self._NUMBER_RE.findall(text)
        if numbers:
            return numbers[-1]

        lines = [line.strip() for line in text.split("\n") if line.strip()]
        if lines:
            return lines[-1]
        return text.strip()

    def _extract_structure(self, text: str) -> str:
        """Return the longest dot-bracket run in *text*, or "" if there is none."""
        matches = self._STRUCTURE_RE.findall(text)
        if matches:
            return max(matches, key=len)
        return ""