AliSaadatV committed on
Commit
2b88f43
·
verified ·
1 Parent(s): a0d03b2

Upload evaluator.py

Browse files
Files changed (1) hide show
  1. evaluator.py +409 -0
evaluator.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluator for biological language models on synthetic sequence tasks.
3
+ Supports masked language models (ESM-2, NT) and autoregressive models.
4
+ """
5
+
6
+ import re
7
+ import logging
8
+ from typing import List, Dict, Optional
9
+ import numpy as np
10
+ from transformers import AutoModelForMaskedLM, AutoTokenizer, EsmTokenizer
11
+ import torch
12
+ from difflib import SequenceMatcher
13
+
14
+ from .tasks import BioTask
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class BioEvaluator:
    """Evaluates biological language models on sequence tasks.

    Models and tokenizers are loaded once per ``model_path`` and cached on the
    instance.  Each task is scored by the handler registered for its
    ``evaluation_metric``; every handler returns a float score.
    """

    # Special-token spellings removed from model output before a sequence is
    # extracted.  The lowercase ``<cls>``/``<eos>``/``<pad>``/``<unk>`` forms
    # are included because their letters would otherwise be upper-cased and
    # matched as residues (e.g. "<pad>" -> "PAD").
    _SPECIAL_TOKENS = (
        "<mask>", "[MASK]",
        "<s>", "</s>",
        "[CLS]", "[SEP]",
        "<cls>", "<eos>", "<pad>", "<unk>",
    )

    # Alphabet of each supported sequence type, compiled once.
    _SEQUENCE_PATTERNS = {
        "protein": re.compile(r"[ACDEFGHIKLMNPQRSTVWY]+"),
        "dna": re.compile(r"[ACGT]+"),
        "rna": re.compile(r"[ACGU]+"),
    }

    # Signed integer or decimal number.
    _NUMBER_PATTERN = re.compile(r"-?\d+(?:\.\d+)?")

    # Run of dot-bracket structure characters.
    _STRUCTURE_PATTERN = re.compile(r"[\(\)\.]+")

    # evaluation_metric -> name of the method that scores it.
    _METRIC_HANDLERS = {
        "sequence_identity": "_eval_sequence_identity",
        "sequence_similarity": "_eval_sequence_similarity",
        "contains_substring": "_eval_contains_substring",
        "exact_match": "_eval_exact_match",
        "perplexity": "_eval_perplexity",
        "rna_structure_similarity": "_eval_rna_structure",
    }

    def __init__(self, device: str = "auto", max_length: int = 1024):
        """Create an evaluator.

        Args:
            device: Torch device string.  ``"auto"`` selects ``"cuda"`` when
                ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.
            max_length: Maximum number of tokens kept when encoding text
                (longer inputs are truncated).
        """
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.max_length = max_length
        self._model_cache = {}
        self._tokenizer_cache = {}

    def _load_model(self, model_path: str):
        """Return the model for ``model_path``, loading and caching it on first use.

        ``AutoModelForMaskedLM`` is tried first; if that raises, the plain
        ``AutoModel`` class is used instead.  The model is moved to
        ``self.device`` and put in eval mode.
        """
        if model_path not in self._model_cache:
            logger.info(f"Loading model from {model_path}")
            try:
                model = AutoModelForMaskedLM.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )
            except Exception as exc:
                # Catch Exception (not a bare except) so Ctrl-C / SystemExit
                # still propagate, and record why the first attempt failed.
                logger.warning(
                    "Could not load %s as a masked LM (%s); falling back to AutoModel",
                    model_path,
                    exc,
                )
                from transformers import AutoModel

                # NOTE(review): the scoring code reads ``outputs.logits``; a
                # headless AutoModel may not provide it - confirm this
                # fallback is usable for the models in scope.
                model = AutoModel.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    trust_remote_code=True,
                )

            model = model.to(self.device)
            model.eval()
            self._model_cache[model_path] = model

        return self._model_cache[model_path]

    def _load_tokenizer(self, model_path: str):
        """Return the tokenizer for ``model_path``, loading and caching it on first use."""
        if model_path not in self._tokenizer_cache:
            logger.info(f"Loading tokenizer from {model_path}")
            self._tokenizer_cache[model_path] = AutoTokenizer.from_pretrained(
                model_path,
                trust_remote_code=True,
            )

        return self._tokenizer_cache[model_path]

    def evaluate_model(
        self,
        model_path: str,
        tasks: List["BioTask"],
    ) -> Dict[str, float]:
        """Evaluate a model on a list of tasks.

        Args:
            model_path: Identifier passed to ``from_pretrained``.
            tasks: Tasks to score.

        Returns:
            Mapping ``task_id -> score``.  A task whose evaluation raises is
            logged and scored ``0.0`` so one failure does not abort the run.
        """
        model = self._load_model(model_path)
        tokenizer = self._load_tokenizer(model_path)

        results = {}
        for task in tasks:
            try:
                results[task.task_id] = self._evaluate_single_task(model, tokenizer, task)
            except Exception as e:
                # exc_info keeps the traceback; the message text is unchanged.
                logger.error(f"Error evaluating task {task.task_id}: {e}", exc_info=True)
                results[task.task_id] = 0.0

        return results

    def _evaluate_single_task(
        self,
        model: torch.nn.Module,
        tokenizer,
        task: "BioTask",
    ) -> float:
        """Score one task with the handler matching ``task.evaluation_metric``.

        Unknown metrics are logged and scored with sequence similarity.
        """
        handler_name = self._METRIC_HANDLERS.get(task.evaluation_metric)
        if handler_name is None:
            logger.warning(f"Unknown metric: {task.evaluation_metric}, defaulting to sequence similarity")
            handler_name = "_eval_sequence_similarity"
        return getattr(self, handler_name)(model, tokenizer, task)

    @staticmethod
    def _build_prompt(task) -> str:
        """Return ``task.prompt`` followed by a space and ``task.context`` when context is truthy."""
        prompt = task.prompt
        if task.context:
            prompt += f" {task.context}"
        return prompt

    def _get_model_output(self, model, tokenizer, prompt: str) -> str:
        """Return the model's text output for ``prompt``.

        Prompts containing ``<mask>`` or ``[MASK]`` go through masked-token
        prediction; all other prompts go through greedy continuation.
        """
        if "<mask>" in prompt or "[MASK]" in prompt:
            return self._predict_masked(model, tokenizer, prompt)
        return self._generate_sequence(model, tokenizer, prompt)

    def _predict_masked(self, model, tokenizer, prompt: str) -> str:
        """Fill every mask token in ``prompt`` with the model's top prediction.

        Mask positions are located in the *encoded* ids, i.e. the exact tensor
        fed to the model.  This keeps positions aligned with the logits
        whatever special tokens the tokenizer adds, and ignores masks that
        were cut off by truncation.

        Returns:
            The decoded sequence with masks replaced, or ``prompt`` unchanged
            when the tokenizer has no mask token or no mask survives encoding.
        """
        mask_id = tokenizer.mask_token_id
        if mask_id is None:
            return prompt

        input_ids = tokenizer.encode(
            prompt, return_tensors="pt", max_length=self.max_length, truncation=True
        )
        input_ids = input_ids.to(self.device)

        mask_positions = (input_ids[0] == mask_id).nonzero(as_tuple=True)[0]
        if mask_positions.numel() == 0:
            return prompt

        with torch.no_grad():
            logits = model(input_ids).logits

        # Replace each mask id with the arg-max vocabulary id at that position.
        filled = input_ids.clone()
        filled[0, mask_positions] = logits[0, mask_positions].argmax(dim=-1)

        return tokenizer.decode(filled[0], skip_special_tokens=True)

    def _generate_sequence(self, model, tokenizer, prompt: str, max_new_tokens: int = 50) -> str:
        """Greedily extend ``prompt`` one token at a time.

        At each step the arg-max of the logits at the last position is
        appended; generation stops after ``max_new_tokens`` tokens or once the
        tokenizer's EOS id is produced.

        NOTE(review): new tokens are appended after whatever special tokens
        ``encode`` added to the prompt - confirm this is intended for the
        masked-LM checkpoints in use.
        """
        input_ids = tokenizer.encode(
            prompt, return_tensors="pt", max_length=self.max_length, truncation=True
        )
        generated = input_ids.to(self.device).clone()
        eos_id = tokenizer.eos_token_id

        with torch.no_grad():
            for _ in range(max_new_tokens):
                step_logits = model(generated).logits[0, -1, :]
                next_id = torch.argmax(step_logits).item()

                next_token = torch.tensor([[next_id]], device=self.device)
                generated = torch.cat([generated, next_token], dim=1)

                if next_id == eos_id:
                    break

        return tokenizer.decode(generated[0], skip_special_tokens=True)

    def _eval_sequence_identity(self, model, tokenizer, task: "BioTask") -> float:
        """Score position-wise identity between the output and expected sequences.

        Returns:
            ``0.5`` when the task has no expected answer, ``0.0`` when either
            sequence is empty, otherwise matching positions divided by the
            longer sequence's length.
        """
        if task.expected_answer is None:
            # Nothing to compare against: skip the forward pass.
            return 0.5

        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_seq = self._extract_sequence(output, task.task_type)
        expected = task.expected_answer.strip().upper()

        if not output_seq or not expected:
            return 0.0

        matches = sum(1 for a, b in zip(output_seq, expected) if a == b)
        return matches / max(len(output_seq), len(expected))

    def _eval_sequence_similarity(self, model, tokenizer, task: "BioTask") -> float:
        """Score similarity of output and expected sequences via ``SequenceMatcher.ratio()``.

        Returns:
            ``0.5`` when the task has no expected answer, ``0.0`` when either
            sequence is empty, otherwise the matcher ratio.
        """
        if task.expected_answer is None:
            return 0.5

        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        output_seq = self._extract_sequence(output, task.task_type)
        expected = task.expected_answer.strip().upper()

        if not output_seq or not expected:
            return 0.0

        return SequenceMatcher(None, output_seq, expected).ratio()

    def _eval_contains_substring(self, model, tokenizer, task: "BioTask") -> float:
        """Score whether the output sequence contains the expected motif.

        Returns:
            ``0.5`` when the task has no expected answer; ``0.0`` when the
            motif or the output is empty (consistent with the other sequence
            metrics); ``1.0`` for a full match; ``0.3`` when any 3-character
            window of the motif occurs in the output; ``0.0`` otherwise.
        """
        if task.expected_answer is None:
            return 0.5

        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        expected = task.expected_answer.strip().upper()
        output_seq = self._extract_sequence(output, task.task_type)

        if not output_seq or not expected:
            return 0.0

        if expected in output_seq:
            return 1.0

        # Partial credit: any 3-mer of the motif is present.
        if any(expected[i:i + 3] in output_seq for i in range(len(expected) - 2)):
            return 0.3

        return 0.0

    def _eval_exact_match(self, model, tokenizer, task: "BioTask") -> float:
        """Score an exact (string or numeric) match of the extracted answer.

        Returns:
            ``0.5`` when the task has no expected answer; ``1.0`` when the
            extracted answer equals the expected string or both parse to the
            same number; ``0.5`` when both are numbers less than 1 apart;
            ``0.0`` otherwise.
        """
        if task.expected_answer is None:
            return 0.5

        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        answer = self._extract_answer(output)
        expected = task.expected_answer.strip()

        if answer == expected:
            return 1.0

        try:
            delta = abs(float(answer) - float(expected))
        except (ValueError, TypeError):
            return 0.0

        if delta == 0:
            # Same value written differently, e.g. "5" vs "5.0".
            return 1.0
        return 0.5 if delta < 1 else 0.0

    def _eval_perplexity(self, model, tokenizer, task: "BioTask") -> float:
        """Score a sequence by the model loss on it, mapped into (0, 1].

        The loss is obtained by passing the encoded ``task.target`` as both
        input and labels; the score is ``1 / (1 + exp(loss) / 10)`` so lower
        perplexity gives a higher score.  Returns ``0.5`` when ``task.target``
        is ``None``.

        NOTE(review): this is the only metric reading ``task.target`` rather
        than ``task.expected_answer`` - confirm ``BioTask`` defines it.
        """
        if task.target is None:
            return 0.5

        input_ids = tokenizer.encode(
            task.target, return_tensors="pt", max_length=self.max_length, truncation=True
        )
        input_ids = input_ids.to(self.device)

        with torch.no_grad():
            loss = model(input_ids, labels=input_ids).loss

        perplexity = torch.exp(loss).item()
        return 1.0 / (1.0 + perplexity / 10.0)

    @staticmethod
    def _is_balanced(structure: str) -> bool:
        """Return True when the parentheses in ``structure`` are balanced and never close early."""
        depth = 0
        for char in structure:
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth < 0:
                    return False
        return depth == 0

    def _eval_rna_structure(self, model, tokenizer, task: "BioTask") -> float:
        """Score a predicted RNA structure in dot-bracket notation.

        Returns:
            ``0.0`` when no structure is found in the output.  Without an
            expected answer: ``0.5`` if the predicted brackets are balanced,
            else ``0.0``.  Otherwise the fraction of positions that agree,
            relative to the longer structure.
        """
        output = self._get_model_output(model, tokenizer, self._build_prompt(task))
        predicted = self._extract_structure(output)

        if not predicted:
            return 0.0

        if task.expected_answer is None:
            # No reference structure: only validity can be checked.
            return 0.5 if self._is_balanced(predicted) else 0.0

        expected = task.expected_answer
        matches = sum(1 for a, b in zip(predicted, expected) if a == b)
        return matches / max(len(predicted), len(expected))

    def _extract_sequence(self, text: str, seq_type: str) -> str:
        """Extract the longest biological sequence of ``seq_type`` from ``text``.

        Known special tokens are removed, the text is upper-cased, and the
        longest run of characters from the type's alphabet is returned.  When
        no run is found the upper-cased text is returned (with ``U -> T`` for
        DNA and ``T -> U`` for RNA).  Unknown types return the stripped,
        upper-cased text.

        NOTE(review): runs are split at whitespace, so output decoded with a
        space between tokens yields short runs - confirm decode format.
        """
        for token in self._SPECIAL_TOKENS:
            text = text.replace(token, "")
        upper = text.upper()

        pattern = self._SEQUENCE_PATTERNS.get(seq_type)
        if pattern is None:
            return upper.strip()

        runs = pattern.findall(upper)
        if runs:
            return max(runs, key=len)

        if seq_type == "dna":
            return upper.replace("U", "T")
        if seq_type == "rna":
            return upper.replace("T", "U")
        return upper

    def _extract_answer(self, text: str) -> str:
        """Extract a short answer from model output.

        Returns the last number (integer or decimal) found in ``text``; if
        there is none, the last non-empty line; otherwise the stripped text.
        """
        numbers = self._NUMBER_PATTERN.findall(text)
        if numbers:
            return numbers[-1]  # Last number is often the answer

        lines = [line.strip() for line in text.split("\n") if line.strip()]
        if lines:
            return lines[-1]

        return text.strip()

    def _extract_structure(self, text: str) -> str:
        """Return the longest run of ``(``, ``)`` and ``.`` in ``text``, or ``""`` if none."""
        runs = self._STRUCTURE_PATTERN.findall(text)
        return max(runs, key=len) if runs else ""