# bio-acdc / evaluator.py
# Uploaded by AliSaadatV -- commit message: "Upload evaluator.py"
# Commit 2b88f43 (verified)
"""
Evaluator for biological language models on synthetic sequence tasks.
Supports masked language models (ESM-2, NT) and autoregressive models.
"""
import re
import logging
from typing import List, Dict, Optional
import numpy as np
from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmTokenizer
import torch
from difflib import SequenceMatcher
from .tasks import BioTask
logger = logging.getLogger(__name__)
class BioEvaluator:
    """Evaluates biological language models on sequence tasks.

    Models and tokenizers are loaded through ``transformers`` on first use and
    cached per ``model_path`` on the instance. Every metric returns a float;
    ``0.5`` is returned as a neutral score when a task carries no reference
    value to compare against.
    """

    # Placeholder spellings that mark a prompt as a masked-prediction task.
    _MASK_PLACEHOLDERS = ("<mask>", "[MASK]")

    # Literal special-token strings stripped before sequence extraction.
    _SPECIAL_TOKENS = ("<mask>", "[MASK]", "<s>", "</s>", "[CLS]", "[SEP]")

    # Alphabet of each supported sequence type; the longest run wins.
    _SEQUENCE_PATTERNS = {
        "protein": re.compile(r'[ACDEFGHIKLMNPQRSTVWY]+'),
        "dna": re.compile(r'[ACGT]+'),
        "rna": re.compile(r'[ACGU]+'),
    }

    # (old, new) substitution applied when no run of the alphabet is found.
    _FALLBACK_SUBSTITUTION = {"dna": ("U", "T"), "rna": ("T", "U")}

    # Whitespace sitting between two single-letter tokens ("M K T" -> "MKT").
    _SPACED_RESIDUES = re.compile(r'(?<=\b[A-Za-z])\s+(?=[A-Za-z]\b)')

    # Integer or decimal literal with optional leading minus sign.
    _NUMBER = re.compile(r'-?\d+(?:\.\d+)?')

    # Run of dot-bracket structure characters.
    _DOT_BRACKET = re.compile(r'[\(\)\.]+')

    def __init__(self, device: str = "auto", max_length: int = 1024):
        """Create an evaluator.

        Args:
            device: Torch device string. ``"auto"`` selects ``"cuda"`` when
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
            max_length: Truncation length passed to ``tokenizer.encode``.
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.max_length = max_length
        self._model_cache = {}
        self._tokenizer_cache = {}

    def _load_model(self, model_path: str):
        """Load the model at ``model_path``, move it to the device and cache it.

        ``AutoModelForMaskedLM`` is tried first; if that raises, the generic
        ``AutoModel`` loader is used instead.

        Args:
            model_path: Identifier or path accepted by ``from_pretrained``.

        Returns:
            The cached model, in eval mode.
        """
        if model_path not in self._model_cache:
            logger.info("Loading model from %s", model_path)
            # SECURITY: trust_remote_code=True executes Python shipped with the
            # model repository. Only evaluate repositories you trust.
            load_kwargs = {
                "torch_dtype": torch.bfloat16,
                "trust_remote_code": True,
            }
            try:
                model = AutoModelForMaskedLM.from_pretrained(
                    model_path, **load_kwargs
                )
            except Exception as exc:
                # Not a bare ``except``: KeyboardInterrupt / SystemExit must
                # still propagate while a large checkpoint is loading.
                logger.warning(
                    "Could not load %s as a masked LM (%s); falling back to AutoModel",
                    model_path,
                    exc,
                )
                from transformers import AutoModel

                # NOTE(review): the prediction helpers below read
                # ``outputs.logits``; confirm the fallback model provides it.
                model = AutoModel.from_pretrained(model_path, **load_kwargs)
            model = model.to(self.device)
            model.eval()
            self._model_cache[model_path] = model
        return self._model_cache[model_path]

    def _load_tokenizer(self, model_path: str):
        """Load the tokenizer for ``model_path`` and cache it.

        Args:
            model_path: Identifier or path accepted by ``from_pretrained``.

        Returns:
            The cached tokenizer.
        """
        if model_path not in self._tokenizer_cache:
            logger.info("Loading tokenizer from %s", model_path)
            # SECURITY: see the trust_remote_code note in ``_load_model``.
            self._tokenizer_cache[model_path] = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True,
            )
        return self._tokenizer_cache[model_path]

    def evaluate_model(
        self,
        model_path: str,
        tasks: List["BioTask"],
    ) -> Dict[str, float]:
        """Evaluate a model on a list of tasks.

        A task whose evaluation raises is logged and scored ``0.0`` so that a
        single failure does not abort the run.

        Args:
            model_path: Identifier or path of the model to evaluate.
            tasks: Tasks to score.

        Returns:
            Mapping of ``task.task_id`` to score.
        """
        model = self._load_model(model_path)
        tokenizer = self._load_tokenizer(model_path)
        results = {}
        for task in tasks:
            try:
                results[task.task_id] = self._evaluate_single_task(
                    model, tokenizer, task
                )
            except Exception as e:
                # Deliberate best-effort boundary: record the failure, go on.
                logger.error("Error evaluating task %s: %s", task.task_id, e)
                results[task.task_id] = 0.0
        return results

    def _evaluate_single_task(
        self,
        model: torch.nn.Module,
        tokenizer,
        task: "BioTask",
    ) -> float:
        """Score one task with the metric named by ``task.evaluation_metric``.

        Unknown metric names are logged and scored with sequence similarity.
        """
        handlers = {
            "sequence_identity": self._eval_sequence_identity,
            "sequence_similarity": self._eval_sequence_similarity,
            "contains_substring": self._eval_contains_substring,
            "exact_match": self._eval_exact_match,
            "perplexity": self._eval_perplexity,
            "rna_structure_similarity": self._eval_rna_structure,
        }
        handler = handlers.get(task.evaluation_metric)
        if handler is None:
            logger.warning(
                "Unknown metric: %s, defaulting to sequence similarity",
                task.evaluation_metric,
            )
            handler = self._eval_sequence_similarity
        return handler(model, tokenizer, task)

    @staticmethod
    def _build_prompt(task: "BioTask") -> str:
        """Return ``task.prompt`` followed by ``task.context`` when it is set."""
        prompt = task.prompt
        if task.context:
            prompt += f" {task.context}"
        return prompt

    def _get_model_output(self, model, tokenizer, prompt: str) -> str:
        """Return the model's text output for ``prompt``.

        Prompts containing a mask placeholder (``<mask>``, ``[MASK]`` or the
        tokenizer's own mask token) go to masked prediction; everything else
        goes to greedy continuation.
        """
        placeholders = set(self._MASK_PLACEHOLDERS)
        tokenizer_mask = getattr(tokenizer, "mask_token", None)
        if tokenizer_mask:
            placeholders.add(tokenizer_mask)
        if any(placeholder in prompt for placeholder in placeholders):
            return self._predict_masked(model, tokenizer, prompt)
        return self._generate_sequence(model, tokenizer, prompt)

    def _predict_masked(self, model, tokenizer, prompt: str) -> str:
        """Fill every mask position in ``prompt`` with the argmax prediction.

        Args:
            model: Model whose forward output exposes ``.logits``.
            tokenizer: Tokenizer providing ``mask_token`` / ``mask_token_id``.
            prompt: Text containing one or more mask placeholders.

        Returns:
            The decoded sequence with masks replaced, or ``prompt`` unchanged
            when the tokenizer has no mask token or no mask survives encoding.
        """
        mask_token = tokenizer.mask_token
        mask_id = tokenizer.mask_token_id
        if mask_token is None or mask_id is None:
            return prompt

        # Accept either placeholder spelling regardless of tokenizer family.
        normalised = prompt
        for placeholder in self._MASK_PLACEHOLDERS:
            if placeholder != mask_token:
                normalised = normalised.replace(placeholder, mask_token)

        input_ids = tokenizer.encode(
            normalised,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        ).to(self.device)

        # Locate masks in the encoded ids themselves, so the indices line up
        # with the logits whether or not special tokens were added and whether
        # or not truncation dropped part of the prompt.
        mask_positions = (input_ids[0] == mask_id).nonzero(as_tuple=True)[0]
        if mask_positions.numel() == 0:
            return prompt

        with torch.no_grad():
            logits = model(input_ids).logits

        filled = input_ids[0].clone()
        filled[mask_positions] = logits[0, mask_positions].argmax(dim=-1)
        return tokenizer.decode(filled, skip_special_tokens=True)

    def _generate_sequence(
        self, model, tokenizer, prompt: str, max_new_tokens: int = 50
    ) -> str:
        """Greedily extend ``prompt`` one token at a time.

        Each step re-runs the model on the whole sequence and appends the
        argmax of the logits at the last position. Generation stops after
        ``max_new_tokens`` steps or once ``tokenizer.eos_token_id`` is emitted.

        Args:
            model: Model whose forward output exposes ``.logits``.
            tokenizer: Tokenizer used to encode and decode.
            prompt: Text to continue.
            max_new_tokens: Upper bound on appended tokens.

        Returns:
            Decoded prompt plus continuation, special tokens skipped.
        """
        generated = tokenizer.encode(
            prompt,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        ).to(self.device)
        eos_id = tokenizer.eos_token_id
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits = model(generated).logits
                next_token_id = torch.argmax(logits[0, -1, :]).item()
                next_token = torch.tensor([[next_token_id]], device=self.device)
                generated = torch.cat([generated, next_token], dim=1)
                if next_token_id == eos_id:
                    break
        return tokenizer.decode(generated[0], skip_special_tokens=True)

    def _eval_sequence_identity(self, model, tokenizer, task: "BioTask") -> float:
        """Position-wise identity between the output and the expected sequence.

        Returns:
            ``0.5`` when the task has no expected answer, ``0.0`` when either
            sequence is empty, otherwise matching positions divided by the
            longer of the two lengths.
        """
        if task.expected_answer is None:
            return 0.5  # Nothing to compare against; skip inference.
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_seq = self._extract_sequence(output, task.task_type)
        expected = task.expected_answer.strip().upper()
        if not output_seq or not expected:
            return 0.0
        matches = sum(1 for a, b in zip(output_seq, expected) if a == b)
        return matches / max(len(output_seq), len(expected))

    def _eval_sequence_similarity(self, model, tokenizer, task: "BioTask") -> float:
        """``difflib.SequenceMatcher`` ratio between output and expected.

        Returns:
            ``0.5`` when the task has no expected answer, ``0.0`` when either
            sequence is empty, otherwise the matcher ratio.
        """
        if task.expected_answer is None:
            return 0.5  # Nothing to compare against; skip inference.
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_seq = self._extract_sequence(output, task.task_type)
        expected = task.expected_answer.strip().upper()
        if not output_seq or not expected:
            return 0.0
        return SequenceMatcher(None, output_seq, expected).ratio()

    def _eval_contains_substring(self, model, tokenizer, task: "BioTask") -> float:
        """Check whether the output contains the expected motif.

        Returns:
            ``0.5`` when the task has no expected answer, ``1.0`` when the
            whole motif occurs in the output, ``0.3`` when any 3-character
            window of it occurs, otherwise ``0.0``.
        """
        if task.expected_answer is None:
            return 0.5  # Nothing to compare against; skip inference.
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        expected = task.expected_answer.strip().upper()
        output_seq = self._extract_sequence(output, task.task_type)
        # An empty motif is "contained" in every string; score it like the
        # other metrics score an empty reference instead of awarding 1.0.
        if not expected or not output_seq:
            return 0.0
        if expected in output_seq:
            return 1.0
        # Partial credit for any 3-mer of the motif.
        if any(expected[i:i + 3] in output_seq for i in range(len(expected) - 2)):
            return 0.3
        return 0.0

    def _eval_exact_match(self, model, tokenizer, task: "BioTask") -> float:
        """Exact string match with partial credit for close numbers.

        Returns:
            ``0.5`` when the task has no expected answer, ``1.0`` on an exact
            string match, ``0.5`` when both values parse as floats and differ
            by less than 1, otherwise ``0.0``.
        """
        if task.expected_answer is None:
            return 0.5  # Nothing to compare against; skip inference.
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_answer = self._extract_answer(output)
        expected = task.expected_answer.strip()
        if output_answer == expected:
            return 1.0
        try:
            is_close = abs(float(output_answer) - float(expected)) < 1
        except (ValueError, TypeError):
            return 0.0
        return 0.5 if is_close else 0.0

    def _eval_perplexity(self, model, tokenizer, task: "BioTask") -> float:
        """Map the model's loss on ``task.target`` to a score in ``[0, 1]``.

        The loss is the one the model returns when called with
        ``labels=input_ids``; the score is ``1 / (1 + exp(loss) / 10)``.

        Returns:
            ``0.5`` when the task has no target, ``0.0`` when the loss is NaN,
            otherwise the score above (lower perplexity gives a higher score).
        """
        if task.target is None:
            return 0.5
        input_ids = tokenizer.encode(
            task.target,
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
        ).to(self.device)
        with torch.no_grad():
            loss = model(input_ids, labels=input_ids).loss
        if torch.isnan(loss).item():
            return 0.0  # A NaN loss would otherwise surface as a NaN score.
        # The model is loaded in bfloat16; exponentiate in float32.
        perplexity = torch.exp(loss.float()).item()
        # Original author's note: typical perplexity for protein LMs is 5-20.
        return 1.0 / (1.0 + perplexity / 10.0)

    @staticmethod
    def _is_balanced(structure: str) -> bool:
        """Return True when every ``)`` closes an earlier ``(`` and none stay open."""
        depth = 0
        for char in structure:
            if char == '(':
                depth += 1
            elif char == ')':
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    def _eval_rna_structure(self, model, tokenizer, task: "BioTask") -> float:
        """Evaluate RNA structure prediction in dot-bracket notation.

        Returns:
            ``0.0`` when no structure is found in the output. Without an
            expected answer: ``0.5`` for balanced brackets, else ``0.0``.
            With one: matching positions divided by the longer length.
        """
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        predicted = self._extract_structure(output)
        if not predicted:
            return 0.0
        if task.expected_answer is None:
            # No reference: only check that the structure is well formed.
            return 0.5 if self._is_balanced(predicted) else 0.0
        expected = task.expected_answer
        matches = sum(1 for a, b in zip(predicted, expected) if a == b)
        return matches / max(len(predicted), len(expected))

    def _extract_sequence(self, text: str, seq_type: str) -> str:
        """Extract a biological sequence from model output.

        For ``"protein"``, ``"dna"`` and ``"rna"`` the result is the longest
        run of that alphabet in the upper-cased text, after whitespace between
        single-letter tokens has been removed (so ``"M K T"`` is read as
        ``"MKT"``). Any other ``seq_type`` returns the upper-cased, stripped
        text.

        Args:
            text: Raw model output.
            seq_type: Sequence type name.

        Returns:
            The extracted sequence, upper-cased.
        """
        for token in self._SPECIAL_TOKENS:
            text = text.replace(token, "")

        pattern = self._SEQUENCE_PATTERNS.get(seq_type)
        if pattern is None:
            return text.upper().strip()

        # Outputs decoded one residue per token arrive space-separated; left
        # as-is, the "longest run" would be a single letter.
        upper = self._SPACED_RESIDUES.sub("", text).upper()
        runs = pattern.findall(upper)
        if runs:
            return max(runs, key=len)
        substitution = self._FALLBACK_SUBSTITUTION.get(seq_type)
        if substitution is not None:
            return upper.replace(*substitution)
        return upper

    def _extract_answer(self, text: str) -> str:
        """Extract a short answer from model output.

        Returns:
            The last integer or decimal number in ``text`` if there is one,
            else the last non-empty line, else the stripped text.
        """
        numbers = self._NUMBER.findall(text)
        if numbers:
            return numbers[-1]  # Last number is often the answer.
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        if lines:
            return lines[-1]
        return text.strip()

    def _extract_structure(self, text: str) -> str:
        """Return the longest run of ``(``, ``)`` and ``.`` in ``text``, or ``""``."""
        runs = self._DOT_BRACKET.findall(text)
        return max(runs, key=len) if runs else ""