bio-acdc / tasks.py
AliSaadatV's picture
Upload tasks.py
a0d03b2 verified
"""
Biological sequence task generation and pool management.
Generates synthetic tasks for evaluating biological language models.
"""
import json
import logging
import os
import random
import zlib
from typing import List, Dict, Optional, Any
from dataclasses import dataclass, field
from pathlib import Path

import numpy as np
logger = logging.getLogger(__name__)

# One-letter codes of the 20 standard amino acids, in alphabetical order.
AMINO_ACIDS = [
    "A", "C", "D", "E", "F", "G", "H", "I", "K", "L",
    "M", "N", "P", "Q", "R", "S", "T", "V", "W", "Y",
]

# DNA bases.
NUCLEOTIDES = ["A", "C", "G", "T"]

# RNA bases (U in place of T).
RNA_NUCLEOTIDES = ["A", "C", "G", "U"]
def generate_random_protein(length: int) -> str:
    """Return a string of ``length`` residues sampled from ``AMINO_ACIDS``.

    Residues are drawn with replacement using the module-level ``random``
    generator, so results depend on that generator's global state.
    """
    residues = random.choices(AMINO_ACIDS, k=length)
    return "".join(residues)
def generate_random_dna(length: int) -> str:
    """Return a string of ``length`` bases sampled from ``NUCLEOTIDES``.

    Bases are drawn with replacement using the module-level ``random``
    generator, so results depend on that generator's global state.
    """
    bases = random.choices(NUCLEOTIDES, k=length)
    return "".join(bases)
def generate_random_rna(length: int) -> str:
    """Return a string of ``length`` bases sampled from ``RNA_NUCLEOTIDES``.

    Bases are drawn with replacement using the module-level ``random``
    generator, so results depend on that generator's global state.
    """
    bases = random.choices(RNA_NUCLEOTIDES, k=length)
    return "".join(bases)
def add_motif(sequence: str, motif: str, position: Optional[int] = None) -> str:
    """Overwrite part of ``sequence`` with ``motif`` and return the result.

    The motif replaces the ``len(motif)`` characters that start at
    ``position``; it is not spliced in between existing characters. When the
    motif fits entirely inside the sequence, the returned string therefore has
    the same length as ``sequence``.

    Args:
        sequence: The background sequence.
        motif: The string to write over the background.
        position: 0-based start index for the motif. When ``None``, a start
            index is drawn with ``random.randint`` from the range in which
            the motif fits entirely inside ``sequence``. An explicit value is
            used as given, without bounds checking.

    Returns:
        A new string with the motif written at ``position``.

    Raises:
        ValueError: If ``position`` is ``None`` and ``motif`` is longer than
            ``sequence``, so that no valid start index exists.
    """
    if position is None:
        last_start = len(sequence) - len(motif)
        if last_start < 0:
            # Fail with an explicit message instead of the opaque
            # "empty range" error that random.randint would raise here.
            raise ValueError(
                f"motif of length {len(motif)} does not fit in a sequence "
                f"of length {len(sequence)}"
            )
        position = random.randint(0, last_start)
    return sequence[:position] + motif + sequence[position + len(motif):]
@dataclass
class BioTask:
    """A biological sequence evaluation task.

    Instances are plain records; :meth:`to_dict` and :meth:`from_dict`
    convert to and from a dictionary of the same field names so that tasks can
    be stored as JSON.
    """

    task_id: str
    task_type: str  # "protein", "dna", "rna"
    task_family: str  # e.g., "motif_recognition", "structure_prediction", "function_prediction"

    # Task content
    prompt: str  # Instruction for the model
    context: Optional[str] = None  # Additional context (e.g., partial sequence)
    target: Optional[str] = None  # Expected output

    # Evaluation
    evaluation_metric: str = "sequence_identity"  # How to score responses
    expected_answer: Optional[str] = None

    # Metadata
    difficulty: float = 0.5
    generation_seed: Optional[int] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return every field of the task as a dictionary keyed by field name.

        ``generation_seed`` is included so that a ``to_dict``/``from_dict``
        round trip reproduces an equal task.
        """
        return {
            "task_id": self.task_id,
            "task_type": self.task_type,
            "task_family": self.task_family,
            "prompt": self.prompt,
            "context": self.context,
            "target": self.target,
            "evaluation_metric": self.evaluation_metric,
            "expected_answer": self.expected_answer,
            "difficulty": self.difficulty,
            "generation_seed": self.generation_seed,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "BioTask":
        """Build a task from a dictionary of field names to values.

        Fields that have defaults may be omitted from ``data``.

        Raises:
            TypeError: If ``data`` contains a key that is not a field name or
                lacks one of the required fields.
        """
        return cls(**data)
class BioTaskPool:
    """Manages a pool of biological evaluation tasks.

    The pool is filled with a fixed set of base tasks at construction time and
    can be grown with :meth:`generate_new_tasks`. Every random draw made by the
    pool goes through ``self.rng``, so a given ``seed`` reproduces the same
    sequences, the same sampled tasks, and the same task ids.
    """

    # Known biological motifs for task generation.
    # NOTE: motifs are used as literal strings. Some entries contain characters
    # that are not part of the sequence alphabets (e.g. "X" in "GXGXXG", "R" in
    # "GCCRCC"); they are neither expanded nor treated as wildcards here.
    PROTEIN_MOTIFS = {
        "nuclear_localization": "PKKKRKV",
        "atp_binding": "GXGXXG",
        "zinc_finger": "CCHH",
        "helix_turn_helix": "LXXLL",
        "transmembrane": "AVLIVF",
    }
    DNA_MOTIFS = {
        "tata_box": "TATAAA",
        "gc_rich": "GCGCGC",
        "promoter": "CAAT",
        "terminator": "AATAAA",
    }
    # NOTE(review): RNA_MOTIFS is not referenced by any method of this class;
    # confirm whether RNA motif tasks were intended.
    RNA_MOTIFS = {
        "shine_dalgarno": "AGGAGGU",
        "polya_signal": "AAUAAA",
        "kozak": "GCCRCC",
    }

    # Upper bound on background re-draws when planting a motif (see
    # _plant_random_motif); keeps the retry loop finite.
    _MAX_PLANT_ATTEMPTS = 100

    def __init__(self, seed: int = 42):
        """Create the pool and populate it with the base tasks.

        Args:
            seed: Seed for the pool's private ``numpy.random.RandomState``.
        """
        self.rng = np.random.RandomState(seed)
        self.tasks: List[BioTask] = []
        self._initialize_base_tasks()

    @staticmethod
    def _stable_suffix(text: str) -> int:
        """Return a checksum of ``text`` in ``[0, 10000)`` for use in task ids.

        The built-in ``hash()`` is salted per process for strings, so ids
        derived from it change from run to run; CRC-32 does not.
        """
        return zlib.crc32(text.encode("utf-8")) % 10000

    def _random_sequence(self, alphabet: List[str], length: int) -> str:
        """Draw ``length`` symbols from ``alphabet`` with ``self.rng``."""
        return "".join(self.rng.choice(alphabet, size=length))

    def _plant_random_motif(
        self,
        alphabet: List[str],
        motifs: Dict[str, str],
        min_length: int,
        max_length: int,
    ) -> tuple:
        """Build a random sequence that contains one randomly chosen motif.

        Args:
            alphabet: Symbols used for the random background.
            motifs: Mapping of motif name to motif string to choose from.
            min_length: Smallest sequence length (inclusive).
            max_length: Largest sequence length (inclusive).

        Returns:
            A ``(motif_name, sequence, position)`` tuple, where ``position``
            is the 0-based index at which the motif starts in ``sequence``.
        """
        # RandomState.randint excludes its upper bound, hence the "+ 1" to get
        # inclusive ranges.
        length = int(self.rng.randint(min_length, max_length + 1))
        names = list(motifs)
        motif_name = names[int(self.rng.randint(len(names)))]
        motif = motifs[motif_name]
        pos = int(self.rng.randint(0, length - len(motif) + 1))

        sequence = ""
        for _ in range(self._MAX_PLANT_ATTEMPTS):
            sequence = add_motif(self._random_sequence(alphabet, length), motif, pos)
            # A short motif can also appear by chance in the random
            # background, which would make the expected position ambiguous.
            # Re-draw the background until the planted copy is the only one.
            planted_is_first = sequence.find(motif) == pos
            no_later_copy = sequence.find(motif, pos + 1) == -1
            if planted_is_first and no_later_copy:
                return motif_name, sequence, pos

        # Retries exhausted: fall back to the leftmost occurrence so that the
        # answer key is at least well defined.
        return motif_name, sequence, sequence.find(motif)

    def _initialize_base_tasks(self) -> None:
        """Create initial set of basic tasks."""
        # Motif-recognition tasks: protein motifs first, then DNA motifs.
        # The second tuple element is the label used in the prompt text.
        motif_sources = [
            ("protein", "protein", self.PROTEIN_MOTIFS),
            ("dna", "DNA", self.DNA_MOTIFS),
        ]
        for seq_type, label, motifs in motif_sources:
            for i, (name, motif) in enumerate(motifs.items()):
                self.tasks.append(
                    BioTask(
                        task_id=f"{seq_type}_motif_{name}_{i}",
                        task_type=seq_type,
                        task_family="motif_recognition",
                        prompt=f"Identify if this {label} sequence contains a {name} motif: ",
                        evaluation_metric="contains_substring",
                        expected_answer=motif,
                        difficulty=0.3 + i * 0.1,
                    )
                )

        # Sequence completion tasks: the first half of a random sequence is
        # shown and the whole sequence is the expected answer.
        for seq_type, alphabet in [
            ("protein", AMINO_ACIDS),
            ("dna", NUCLEOTIDES),
            ("rna", RNA_NUCLEOTIDES),
        ]:
            for length in [10, 20, 50]:
                target = self._random_sequence(alphabet, length)
                prefix = target[:length // 2]
                self.tasks.append(
                    BioTask(
                        task_id=f"{seq_type}_complete_len{length}_{self._stable_suffix(target)}",
                        task_type=seq_type,
                        task_family="sequence_completion",
                        prompt=f"Complete this {seq_type} sequence: {prefix}",
                        context=prefix,
                        target=target,
                        expected_answer=target,
                        evaluation_metric="sequence_similarity",
                        difficulty=min(length / 100.0, 0.9),
                    )
                )

        logger.info("Initialized task pool with %d base tasks", len(self.tasks))

    def get_tasks(self, n: int = 10, filter_type: Optional[str] = None) -> List[BioTask]:
        """Sample up to ``n`` tasks from the pool without replacement.

        Args:
            n: Maximum number of tasks to return. Values of zero or less
                yield an empty list.
            filter_type: When given, only tasks whose ``task_type`` equals
                this value are considered.

        Returns:
            A new list of tasks. If no more than ``n`` tasks are available,
            all of them are returned in pool order; otherwise ``n`` tasks are
            chosen with ``self.rng``.
        """
        if filter_type:
            available = [t for t in self.tasks if t.task_type == filter_type]
        else:
            # Copy so that callers cannot mutate the pool through the result.
            available = list(self.tasks)

        if n <= 0:
            return []
        if len(available) <= n:
            return available

        # Sample indices rather than the objects themselves so that numpy
        # never has to convert task objects into an array.
        picked = self.rng.choice(len(available), size=n, replace=False)
        return [available[int(k)] for k in picked]

    def generate_new_tasks(
        self,
        archive: Any,
        num_tasks: int = 5,
        difficulty_weights: Optional[Dict[str, float]] = None,
    ) -> List[BioTask]:
        """Generate new, harder tasks and add them to the pool.

        Each task's sequence type is chosen at random. Protein and DNA tasks
        ask for the position of a motif planted in a random sequence; RNA
        tasks ask for a secondary structure and carry no expected answer.

        Args:
            archive: Accepted for interface compatibility; not consulted by
                the current implementation.
            num_tasks: Number of tasks to generate.
            difficulty_weights: Accepted for interface compatibility; not
                consulted by the current implementation.

        Returns:
            The newly generated tasks, which have also been appended to
            ``self.tasks``.
        """
        new_tasks: List[BioTask] = []
        seq_types = ["protein", "dna", "rna"]

        for i in range(num_tasks):
            seq_type = seq_types[int(self.rng.randint(len(seq_types)))]
            # Index the task will have in the pool once the batch is appended.
            index = len(self.tasks) + i

            # Generate harder tasks (longer sequences, planted motifs)
            if seq_type == "protein":
                motif_name, sequence, pos = self._plant_random_motif(
                    AMINO_ACIDS, self.PROTEIN_MOTIFS, 50, 200
                )
                task = BioTask(
                    task_id=f"protein_gen_{index}_{self._stable_suffix(sequence)}",
                    task_type="protein",
                    task_family="motif_localization",
                    prompt=f"Find the position of the {motif_name} motif in this protein: {sequence}",
                    context=sequence,
                    target=str(pos),
                    expected_answer=str(pos),
                    evaluation_metric="exact_match",
                    difficulty=0.6 + float(self.rng.random_sample()) * 0.3,
                )
            elif seq_type == "dna":
                motif_name, sequence, pos = self._plant_random_motif(
                    NUCLEOTIDES, self.DNA_MOTIFS, 100, 500
                )
                task = BioTask(
                    task_id=f"dna_gen_{index}_{self._stable_suffix(sequence)}",
                    task_type="dna",
                    task_family="regulatory_element_detection",
                    prompt=f"Find the {motif_name} regulatory element in: {sequence}",
                    context=sequence,
                    target=str(pos),
                    expected_answer=str(pos),
                    evaluation_metric="exact_match",
                    difficulty=0.5 + float(self.rng.random_sample()) * 0.3,
                )
            else:  # rna
                length = int(self.rng.randint(50, 301))
                sequence = self._random_sequence(RNA_NUCLEOTIDES, length)
                # No reference structure is computed here, so target and
                # expected_answer are left unset.
                task = BioTask(
                    task_id=f"rna_gen_{index}_{self._stable_suffix(sequence)}",
                    task_type="rna",
                    task_family="structure_prediction",
                    prompt=f"Predict the secondary structure of this RNA: {sequence}",
                    context=sequence,
                    evaluation_metric="rna_structure_similarity",
                    difficulty=0.7 + float(self.rng.random_sample()) * 0.2,
                )
            new_tasks.append(task)

        # Add to pool
        self.tasks.extend(new_tasks)
        logger.info(
            "Generated %d new tasks. Pool size: %d", len(new_tasks), len(self.tasks)
        )
        return new_tasks

    def save(self, path: str) -> None:
        """Write every task in the pool to ``path`` as a JSON array.

        Args:
            path: Destination file; it is overwritten if it exists.
        """
        data = [t.to_dict() for t in self.tasks]
        # Explicit encoding so the file does not depend on the platform default.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        logger.info("Saved task pool to %s", path)

    def load(self, path: str) -> None:
        """Replace the pool's tasks with those stored in the JSON file ``path``.

        Args:
            path: File previously written by :meth:`save`.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.tasks = [BioTask.from_dict(d) for d in data]
        logger.info("Loaded task pool with %d tasks from %s", len(self.tasks), path)
class ProteinTask(BioTask):
    """Marker subclass of ``BioTask`` for protein tasks; adds no fields or behavior."""
class DNATask(BioTask):
    """Marker subclass of ``BioTask`` for DNA tasks; adds no fields or behavior."""
class RNATask(BioTask):
    """Marker subclass of ``BioTask`` for RNA tasks; adds no fields or behavior."""