"""
Evaluator for biological language models on synthetic sequence tasks.

Supports masked language models (ESM-2, NT) and autoregressive models.
"""

import re
import logging
from typing import List, Dict, Optional

import numpy as np
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, EsmTokenizer
import torch
from difflib import SequenceMatcher

from .tasks import BioTask

logger = logging.getLogger(__name__)

# Literal mask placeholders that task prompts may contain. Before encoding they
# are rewritten to the tokenizer's own mask token, so a prompt written with
# "[MASK]" still works for a tokenizer whose mask token is "<mask>" (and vice versa).
_MASK_MARKERS = ("<mask>", "[MASK]")

# Special-token strings removed from decoded text before sequence extraction.
# Left in place, e.g. "<pad>" -> "PAD" would be matched as amino-acid letters.
_SPECIAL_TOKEN_RE = re.compile(
    r"</?s>|<(?:mask|cls|eos|bos|pad|unk|sep)>|\[(?:MASK|CLS|SEP|PAD|UNK)\]"
)

# Residue alphabets per sequence type. A candidate sequence is a run of
# alphabet-only tokens separated by spaces/tabs, because some tokenizers decode
# to space-separated residues ("M K T ...") that must be rejoined.
_SEQUENCE_PATTERNS = {
    seq_type: re.compile(rf"[{alphabet}]+(?:[ \t]+[{alphabet}]+)*")
    for seq_type, alphabet in (
        ("protein", "ACDEFGHIKLMNPQRSTVWY"),
        ("dna", "ACGT"),
        ("rna", "ACGU"),
    )
}
_INLINE_WHITESPACE_RE = re.compile(r"[ \t]+")

# Integers or decimals; "3.5" must be captured whole, not as "3" and "5".
_NUMBER_RE = re.compile(r"-?\d+(?:\.\d+)?")

# Runs of dot-bracket characters.
_STRUCTURE_RE = re.compile(r"[\(\)\.]+")


class BioEvaluator:
    """Evaluates biological language models on sequence tasks.

    Models and tokenizers are loaded lazily and cached per ``model_path``.
    """

    def __init__(self, device: str = "auto", max_length: int = 1024):
        """
        Args:
            device: Torch device string, or "auto" to use "cuda" when available
                and "cpu" otherwise.
            max_length: Maximum number of token ids fed to the model. Inputs are
                truncated to it and greedy generation stops when reaching it.
        """
        self.device = device if device != "auto" else ("cuda" if torch.cuda.is_available() else "cpu")
        self.max_length = max_length
        self._model_cache = {}
        self._tokenizer_cache = {}

    def _load_model(self, model_path: str):
        """Load a bfloat16 model onto ``self.device`` in eval mode, with caching.

        Tries ``AutoModelForMaskedLM`` first and falls back to ``AutoModel``.
        NOTE(review): a bare ``AutoModel`` may have no LM head; then
        ``outputs.logits`` is missing and each task fails (scored 0.0 by
        ``evaluate_model``).
        """
        if model_path not in self._model_cache:
            logger.info(f"Loading model from {model_path}")
            try:
                model = AutoModelForMaskedLM.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )
            except Exception as err:
                # Not a standard masked LM. Catch Exception (not a bare except)
                # so KeyboardInterrupt/SystemExit still propagate, and record why.
                logger.warning(
                    f"AutoModelForMaskedLM failed for {model_path} ({err}); falling back to AutoModel"
                )
                model = AutoModel.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )
            model = model.to(self.device)
            model.eval()
            self._model_cache[model_path] = model
        return self._model_cache[model_path]

    def _load_tokenizer(self, model_path: str):
        """Load the tokenizer for ``model_path``, with caching."""
        if model_path not in self._tokenizer_cache:
            logger.info(f"Loading tokenizer from {model_path}")
            tokenizer = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True,
            )
            self._tokenizer_cache[model_path] = tokenizer
        return self._tokenizer_cache[model_path]

    def evaluate_model(
        self,
        model_path: str,
        tasks: List[BioTask],
    ) -> Dict[str, float]:
        """Evaluate a model on a list of tasks.

        Returns task_id -> score mapping. A task whose evaluation raises is
        logged (with traceback) and scored 0.0 so one bad task does not abort
        the whole run.
        """
        model = self._load_model(model_path)
        tokenizer = self._load_tokenizer(model_path)

        results = {}
        for task in tasks:
            try:
                score = self._evaluate_single_task(model, tokenizer, task)
                results[task.task_id] = score
            except Exception as e:
                logger.exception(f"Error evaluating task {task.task_id}: {e}")
                results[task.task_id] = 0.0

        return results

    def _evaluate_single_task(
        self,
        model: torch.nn.Module,
        tokenizer,
        task: BioTask,
    ) -> float:
        """Dispatch ``task`` to the scorer named by ``task.evaluation_metric``.

        Unknown metrics log a warning and fall back to sequence similarity.
        """
        scorers = {
            "sequence_identity": self._eval_sequence_identity,
            "sequence_similarity": self._eval_sequence_similarity,
            "contains_substring": self._eval_contains_substring,
            "exact_match": self._eval_exact_match,
            "perplexity": self._eval_perplexity,
            "rna_structure_similarity": self._eval_rna_structure,
        }
        scorer = scorers.get(task.evaluation_metric)
        if scorer is None:
            logger.warning(f"Unknown metric: {task.evaluation_metric}, defaulting to sequence similarity")
            scorer = self._eval_sequence_similarity
        return scorer(model, tokenizer, task)

    @staticmethod
    def _build_prompt(task: BioTask) -> str:
        """Return ``task.prompt``, followed by a space and ``task.context`` when set."""
        prompt = task.prompt
        if task.context:
            prompt += f" {task.context}"
        return prompt

    def _get_model_output(self, model, tokenizer, prompt: str) -> str:
        """Get model output for a prompt.

        Prompts containing a mask placeholder ("<mask>", "[MASK]" or the
        tokenizer's own mask token) go through masked prediction; all others
        go through greedy continuation.
        """
        markers = set(_MASK_MARKERS)
        mask_token = getattr(tokenizer, "mask_token", None)
        if mask_token:
            markers.add(mask_token)

        if any(marker in prompt for marker in markers):
            return self._predict_masked(model, tokenizer, prompt)
        return self._generate_sequence(model, tokenizer, prompt)

    def _predict_masked(self, model, tokenizer, prompt: str) -> str:
        """Fill every mask in ``prompt`` with the model's argmax prediction.

        Mask positions are located directly in the encoded ids (the same ids
        passed to the model), so no special-token offset has to be guessed
        and truncation cannot push a position out of range. Returns the
        prompt unchanged when the tokenizer has no mask token or no mask
        survives encoding/truncation.
        """
        mask_token = getattr(tokenizer, "mask_token", None)
        mask_token_id = getattr(tokenizer, "mask_token_id", None)
        if mask_token is None or mask_token_id is None:
            return prompt

        # Normalize generic placeholders to this tokenizer's mask token.
        text = prompt
        for marker in _MASK_MARKERS:
            text = text.replace(marker, mask_token)

        input_ids = tokenizer.encode(text, return_tensors="pt", max_length=self.max_length, truncation=True)
        input_ids = input_ids.to(self.device)

        mask_positions = (input_ids[0] == mask_token_id).nonzero(as_tuple=True)[0]
        if mask_positions.numel() == 0:
            return prompt

        with torch.no_grad():
            logits = model(input_ids).logits

        filled = input_ids[0].clone()
        filled[mask_positions] = logits[0, mask_positions].argmax(dim=-1)
        return tokenizer.decode(filled, skip_special_tokens=True)

    def _generate_sequence(self, model, tokenizer, prompt: str, max_new_tokens: int = 50) -> str:
        """Greedily extend ``prompt`` by up to ``max_new_tokens`` tokens.

        Each step appends the argmax of the last position's logits, stopping
        at the tokenizer's EOS id or once the sequence reaches
        ``self.max_length`` tokens.
        NOTE(review): for a masked LM the last position's logits reconstruct
        the last input token rather than predict a next one, so this is only
        a true continuation for autoregressive models.
        """
        input_ids = tokenizer.encode(prompt, return_tensors="pt", max_length=self.max_length, truncation=True)
        generated = input_ids.to(self.device)
        eos_token_id = tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Do not feed the model more tokens than inputs are truncated to.
                if generated.shape[1] >= self.max_length:
                    break
                logits = model(generated).logits
                next_token_id = int(torch.argmax(logits[0, -1, :]))
                next_token = torch.tensor([[next_token_id]], device=self.device)
                generated = torch.cat([generated, next_token], dim=1)
                if eos_token_id is not None and next_token_id == eos_token_id:
                    break

        return tokenizer.decode(generated[0], skip_special_tokens=True)

    def _eval_sequence_identity(self, model, tokenizer, task: BioTask) -> float:
        """Positional identity: matching aligned positions / longer length.

        Returns 0.5 when the task has no expected answer, 0.0 when either the
        extracted output or the expected sequence is empty.
        """
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))

        if task.expected_answer is None:
            return 0.5  # Default if no expected answer

        output_seq = self._extract_sequence(output, task.task_type)
        expected = task.expected_answer.strip().upper()
        if not output_seq or not expected:
            return 0.0

        matches = sum(1 for a, b in zip(output_seq, expected) if a == b)
        return matches / max(len(output_seq), len(expected))

    def _eval_sequence_similarity(self, model, tokenizer, task: BioTask) -> float:
        """``difflib.SequenceMatcher`` ratio between extracted and expected sequence.

        Returns 0.5 when the task has no expected answer, 0.0 when either side
        is empty.
        """
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))

        if task.expected_answer is None:
            return 0.5

        output_seq = self._extract_sequence(output, task.task_type)
        expected = task.expected_answer.strip().upper()
        if not output_seq or not expected:
            return 0.0

        # A full local alignment (e.g. Biopython) could replace this ratio.
        return SequenceMatcher(None, output_seq, expected).ratio()

    def _eval_contains_substring(self, model, tokenizer, task: BioTask) -> float:
        """Score whether the extracted sequence contains the expected motif.

        Returns 1.0 for a full match, 0.3 if any 3-mer of the motif occurs,
        otherwise 0.0 (also when either side is empty, consistent with the
        other sequence metrics). Returns 0.5 when there is no expected answer.
        """
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))

        if task.expected_answer is None:
            return 0.5

        expected = task.expected_answer.strip().upper()
        output_seq = self._extract_sequence(output, task.task_type)
        # An empty motif is trivially "contained" in anything; don't reward it.
        if not expected or not output_seq:
            return 0.0

        if expected in output_seq:
            return 1.0

        # Partial credit for any shared 3-mer.
        if any(expected[i:i + 3] in output_seq for i in range(len(expected) - 2)):
            return 0.3
        return 0.0

    def _eval_exact_match(self, model, tokenizer, task: BioTask) -> float:
        """Exact match on a short extracted answer.

        1.0 for an exact string match or numerically equal values, 0.5 when
        both parse as numbers within 1 of each other, else 0.0. Returns 0.5
        when the task has no expected answer.
        """
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))

        if task.expected_answer is None:
            return 0.5

        output_answer = self._extract_answer(output)
        expected = task.expected_answer.strip()

        if output_answer == expected:
            return 1.0

        try:
            output_num = float(output_answer)
            expected_num = float(expected)
        except (ValueError, TypeError):
            return 0.0

        # "3" and "3.0" are the same answer.
        if output_num == expected_num:
            return 1.0
        if abs(output_num - expected_num) < 1:
            return 0.5
        return 0.0

    def _eval_perplexity(self, model, tokenizer, task: BioTask) -> float:
        """Map the model's loss on ``task.target`` to a score in (0, 1].

        score = 1 / (1 + exp(loss) / 10), so lower perplexity scores higher.
        Returns 0.5 when the task has no target and 0.0 when the model
        returns no loss.
        NOTE(review): for a masked LM, ``labels=input_ids`` scores each token
        while it is visible in the input, so this is not a pseudo-perplexity
        and is likely optimistic.
        """
        if task.target is None:
            return 0.5

        input_ids = tokenizer.encode(task.target, return_tensors="pt", max_length=self.max_length, truncation=True)
        input_ids = input_ids.to(self.device)

        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids)

        loss = getattr(outputs, "loss", None)
        if loss is None:
            logger.warning(f"Model returned no loss for task {task.task_id}; cannot compute perplexity")
            return 0.0

        # Exponentiate in float32: bfloat16 has too little mantissa for exp().
        perplexity = torch.exp(loss.float()).item()

        # Typical perplexity for protein LMs is 5-20
        return 1.0 / (1.0 + perplexity / 10.0)

    def _eval_rna_structure(self, model, tokenizer, task: BioTask) -> float:
        """
        Evaluate RNA structure prediction using dot-bracket notation.

        Without an expected structure, a balanced prediction scores 0.5 and
        an unbalanced one 0.0. Otherwise the score is positional identity
        against the expected structure over the longer length.
        """
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))

        predicted = self._extract_structure(output)
        if not predicted:
            return 0.0

        if task.expected_answer is None:
            return 0.5 if self._is_balanced(predicted) else 0.0

        expected = task.expected_answer.strip()
        matches = sum(1 for a, b in zip(predicted, expected) if a == b)
        # predicted is non-empty here, so the denominator is positive.
        return matches / max(len(predicted), len(expected))

    @staticmethod
    def _is_balanced(structure: str) -> bool:
        """True when every ')' closes an earlier '(' and none remain open."""
        depth = 0
        for char in structure:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    def _extract_sequence(self, text: str, seq_type: str) -> str:
        """Extract the longest biological sequence of ``seq_type`` from ``text``.

        Special tokens are stripped first. Residues separated by spaces/tabs
        are joined into one candidate. If no candidate is found, the
        upper-cased text is returned (T->U for RNA, U->T for DNA). Unknown
        ``seq_type`` values return the stripped upper-cased text.
        """
        upper = _SPECIAL_TOKEN_RE.sub("", text).upper()

        pattern = _SEQUENCE_PATTERNS.get(seq_type)
        if pattern is None:
            return upper.strip()

        candidates = [_INLINE_WHITESPACE_RE.sub("", match) for match in pattern.findall(upper)]
        if candidates:
            return max(candidates, key=len)

        if seq_type == "dna":
            return upper.replace("U", "T")
        if seq_type == "rna":
            return upper.replace("T", "U")
        return upper

    def _extract_answer(self, text: str) -> str:
        """Extract a short answer: the last number (integer or decimal) in
        ``text``, else its last non-empty line, else the stripped text."""
        numbers = _NUMBER_RE.findall(text)
        if numbers:
            return numbers[-1]  # Last number is often the answer

        lines = [line.strip() for line in text.splitlines() if line.strip()]
        if lines:
            return lines[-1]
        return text.strip()

    def _extract_structure(self, text: str) -> str:
        """Return the longest run of dot-bracket characters in ``text``, or ""."""
        matches = _STRUCTURE_RE.findall(text)
        if matches:
            return max(matches, key=len)
        return ""