offtargeteffect's picture
Deploy mRNA Design Studio (Docker SDK)
99f834c verified
Raw
History Blame Contribute Delete
14.3 kB
"""
Model plugin system.
Users can contribute two types of models:
1. ScoringModel — scores an existing mRNASequence, returns a float.
2. GenerativeModel — generates new mRNASequences from constraints / seeds.
Models are loaded via ModelRegistry which supports:
- Local Python source file (path on disk)
- Remote REST API endpoint (POST sequences → scores/generations)
The API adapter wraps HTTP calls behind the same interface so the UI
code never needs to know whether a model is local or remote.
"""
from __future__ import annotations
import importlib.util
import inspect
import sys
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Type, Union
import pandas as pd
from core.models.sequence import mRNASequence
# ── Abstract base classes ────────────────────────────────────────────────────
class ScoringModel(ABC):
    """
    A model that assigns a numeric score to an mRNASequence.
    Implement name and score(). score_batch() has a default list
    implementation but can be overridden for vectorised inference.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable model name shown in the UI."""
        ...

    @property
    def description(self) -> str:
        """Optional description for the UI."""
        return ""

    @property
    def version(self) -> str:
        """Model version string; "1.0" unless overridden."""
        return "1.0"

    @abstractmethod
    def score(self, sequence: mRNASequence, metadata: Optional[Dict[str, Any]] = None) -> float:
        """
        Score a single sequence.
        Parameters
        ----------
        sequence : mRNASequence
        metadata : dict, optional
            Raw database metadata attached to the sequence (raw_metadata).
        Returns
        -------
        float
            Score value. Convention: higher is better, but models may
            define their own scale — document it in description.
        """
        ...

    def score_batch(
        self,
        sequences: List[mRNASequence],
        metadata: Optional[List[Optional[Dict[str, Any]]]] = None,
    ) -> List[float]:
        """
        Score a list of sequences. Override for vectorised models.

        Parameters
        ----------
        sequences : list of mRNASequence
        metadata : list of (dict or None), optional
            Per-sequence metadata aligned by position with ``sequences``.
            ``None`` or an empty list means "no metadata" for every sequence.

        Returns
        -------
        list of float
            One score per input sequence, in input order.

        Raises
        ------
        ValueError
            If a non-empty ``metadata`` list has a different length than
            ``sequences``.
        """
        if not metadata:
            metas: List[Optional[Dict[str, Any]]] = [None] * len(sequences)
        elif len(metadata) != len(sequences):
            # zip() would otherwise silently truncate to the shorter list and
            # drop scores for the trailing sequences.
            raise ValueError(
                f"metadata has {len(metadata)} entries but "
                f"{len(sequences)} sequences were given."
            )
        else:
            metas = metadata
        return [self.score(seq, meta) for seq, meta in zip(sequences, metas)]
class GenerativeModel(ABC):
    """
    Interface for models that produce brand-new mRNASequences.

    Subclasses must supply ``name`` and ``generate()``; ``description``
    and ``version`` come with defaults ("" and "1.0") that may be overridden.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Human-readable model name shown in the UI."""
        ...

    @property
    def description(self) -> str:
        """Optional free-text description; empty unless overridden."""
        return ""

    @property
    def version(self) -> str:
        """Version label for the model; "1.0" unless overridden."""
        return "1.0"

    @abstractmethod
    def generate(
        self,
        constraints: Dict[str, Any],
        n: int = 10,
        seed: Optional[mRNASequence] = None,
    ) -> List[mRNASequence]:
        """
        Produce ``n`` sequences that satisfy ``constraints``.

        Parameters
        ----------
        constraints : dict
            Model-specific settings (for example a target GC content,
            CAI or organism) — each model defines which keys it reads.
        n : int
            How many sequences to return.
        seed : mRNASequence, optional
            Starting sequence for generators that work by mutation.

        Returns
        -------
        List[mRNASequence]
        """
        ...
# Either plugin interface. Used as the type of RegisteredModel.model and of the
# instances returned by ModelRegistry.load_local.
ModelType = Union[ScoringModel, GenerativeModel]
# ── API Adapter ──────────────────────────────────────────────────────────────
class APIScoringModel(ScoringModel):
    """
    Wraps a remote REST API behind the ScoringModel interface.
    Expected API contract:
    POST {endpoint}/score
    Body: {"sequences": [{"id": ..., "name": ..., "sequence": ...,
                          "metadata": dict | null}, ...]}
    Response: {"scores": [{"id": ..., "score": float}, ...]}
    Scores are matched back to the inputs by ``id``: an input whose id is
    missing from the response scores NaN, and inputs that share an id all
    receive the same score.
    """

    def __init__(
        self,
        endpoint: str,
        model_name: str,
        api_key: Optional[str] = None,
        description: str = "",
        version: str = "1.0",
        timeout: float = 30.0,
    ) -> None:
        """
        Parameters
        ----------
        endpoint : str
            Base URL; trailing slashes are stripped before "/score" is appended.
        model_name : str
            Name reported by ``name`` (and used as the registry key).
        api_key : str, optional
            Sent as ``Authorization: Bearer <api_key>`` when set.
        description, version : str
            Values reported by the ``description`` / ``version`` properties.
        timeout : float
            Per-request timeout in seconds passed to ``httpx.post``.
        """
        self._endpoint = endpoint.rstrip("/")
        self._name = model_name
        self._api_key = api_key
        self._description = description
        self._version = version
        self._timeout = timeout

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def version(self) -> str:
        return self._version

    def _headers(self) -> Dict[str, str]:
        """JSON content-type header plus a bearer token when an API key is set."""
        h = {"Content-Type": "application/json"}
        if self._api_key:
            h["Authorization"] = f"Bearer {self._api_key}"
        return h

    def score(self, sequence: mRNASequence, metadata: Optional[Dict[str, Any]] = None) -> float:
        """Score one sequence via a single-element batch request."""
        results = self.score_batch([sequence], [metadata])
        return results[0]

    def score_batch(
        self,
        sequences: List[mRNASequence],
        metadata: Optional[List[Optional[Dict[str, Any]]]] = None,
    ) -> List[float]:
        """
        Score ``sequences`` with one POST to ``{endpoint}/score``.

        Parameters
        ----------
        sequences : list of mRNASequence
        metadata : list of (dict or None), optional
            Aligned by position with ``sequences``; ``None`` or an empty list
            sends ``"metadata": null`` for every sequence.

        Returns
        -------
        list of float
            One score per input, in input order (NaN for ids the API omitted).

        Raises
        ------
        ValueError
            If a non-empty ``metadata`` list has a different length than
            ``sequences``.
        httpx.HTTPStatusError
            If the API responds with an error status.
        """
        if not sequences:
            # Nothing to score: skip the network round trip entirely.
            return []
        if metadata and len(metadata) != len(sequences):
            raise ValueError(
                f"metadata has {len(metadata)} entries but "
                f"{len(sequences)} sequences were given."
            )
        import httpx

        metas = metadata if metadata else [None] * len(sequences)
        payload = {
            "sequences": [
                {
                    "id": seq.id,
                    "name": seq.name,
                    "sequence": seq.assembled_sequence,
                    "metadata": meta,
                }
                for seq, meta in zip(sequences, metas)
            ]
        }
        response = httpx.post(
            f"{self._endpoint}/score",
            json=payload,
            headers=self._headers(),
            timeout=self._timeout,
        )
        response.raise_for_status()
        data = response.json()
        score_map = {item["id"]: item["score"] for item in data["scores"]}
        return [score_map.get(seq.id, float("nan")) for seq in sequences]
class APIGenerativeModel(GenerativeModel):
    """
    GenerativeModel adapter that delegates generation to a remote REST API.

    Expected API contract:
    POST {endpoint}/generate
    Body: {"constraints": {...}, "n": int, "seed_sequence": str | null}
    Response: {"sequences": [{"name": ..., "cds": ..., ...}, ...]}
    """

    def __init__(
        self,
        endpoint: str,
        model_name: str,
        api_key: Optional[str] = None,
        description: str = "",
        version: str = "1.0",
        timeout: float = 60.0,
    ) -> None:
        # Identity / display fields.
        self._name = model_name
        self._description = description
        self._version = version
        # Connection settings; trailing slashes are dropped so the
        # "/generate" suffix is never doubled.
        self._endpoint = endpoint.rstrip("/")
        self._api_key = api_key
        self._timeout = timeout

    @property
    def name(self) -> str:
        return self._name

    @property
    def description(self) -> str:
        return self._description

    @property
    def version(self) -> str:
        return self._version

    def _headers(self) -> Dict[str, str]:
        """Request headers: JSON content type, plus bearer auth if a key is set."""
        headers: Dict[str, str] = {"Content-Type": "application/json"}
        if not self._api_key:
            return headers
        headers["Authorization"] = f"Bearer {self._api_key}"
        return headers

    def generate(
        self,
        constraints: Dict[str, Any],
        n: int = 10,
        seed: Optional[mRNASequence] = None,
    ) -> List[mRNASequence]:
        """POST the request to ``{endpoint}/generate`` and parse the returned sequences."""
        import httpx

        seed_sequence = None
        if seed:
            seed_sequence = seed.assembled_sequence
        request_body = {
            "constraints": constraints,
            "n": n,
            "seed_sequence": seed_sequence,
        }
        url = f"{self._endpoint}/generate"
        response = httpx.post(
            url,
            json=request_body,
            headers=self._headers(),
            timeout=self._timeout,
        )
        response.raise_for_status()
        generated: List[mRNASequence] = []
        for record in response.json()["sequences"]:
            # NOTE(review): sequences from a remote API are tagged
            # source="local", overriding any "source" the server sent —
            # confirm this is the intended provenance value.
            generated.append(mRNASequence.from_dict({**record, "source": "local"}))
        return generated
# ── Model Registry ───────────────────────────────────────────────────────────
@dataclass
class RegisteredModel:
    """A registry entry: a model instance plus where it came from."""

    # The loaded model instance.
    model: ModelType
    # Either "scoring" or "generative".
    model_type: str
    # One of "local", "api", "builtin", "catalog".
    source: str
    # File path (local models) or endpoint URL (API models).
    source_path: str = field(default="")
    # Display provenance, e.g. "github.com/ViennaRNA".
    repository: str = field(default="")
    # Model category used for display grouping.
    category: str = field(default="")
class ModelRegistry:
    """
    Manages loaded scoring and generative models.
    Models are registered either by loading a local Python file/module
    or by configuring an API endpoint. Entries are keyed by the model's
    ``name``; registering a name that already exists replaces the old entry.
    """

    def __init__(self) -> None:
        # model name -> registry entry
        self._models: Dict[str, RegisteredModel] = {}

    # ── Loading ──────────────────────────────────────────────────────────────
    def load_local(self, path: str) -> List[ModelType]:
        """
        Dynamically import a Python file and register every concrete
        ScoringModel / GenerativeModel subclass defined in it.

        Each such class is instantiated with no arguments. Classes merely
        imported into the file (whose ``__module__`` is another module) and
        abstract classes (e.g. shared base classes) are skipped.

        Parameters
        ----------
        path : str
            Filesystem path of the Python source file.

        Returns
        -------
        List[ModelType]
            The model instances that were registered.

        Raises
        ------
        ImportError
            If no module spec/loader can be created for ``path``.
        ValueError
            If the file defines no concrete model subclasses.

        Exceptions raised while executing the file or constructing a model
        propagate; in that case no model from this file is registered.
        """
        import hashlib
        import os

        # Give each plugin file its own module name (stable per absolute path)
        # and register it in sys.modules before executing, as the importlib
        # "import a source file directly" recipe does. Without this, every
        # plugin shared the name "_user_model" and code inside the plugin that
        # looks up its own module (dataclasses, typing, pickle) could fail.
        abs_path = os.path.abspath(path)
        digest = hashlib.sha256(abs_path.encode("utf-8")).hexdigest()[:16]
        module_name = f"_user_model_{digest}"

        spec = importlib.util.spec_from_file_location(module_name, path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module from: {path}")
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            # Don't leave a half-initialised module behind.
            sys.modules.pop(module_name, None)
            raise

        # Instantiate everything first so a failing constructor leaves the
        # registry unchanged.
        pending: List[tuple] = []
        skipped_abstract: List[str] = []
        for _, obj in inspect.getmembers(module, inspect.isclass):
            if obj.__module__ != module.__name__:
                continue  # imported into the file, not defined by it
            if issubclass(obj, ScoringModel):
                model_type = "scoring"
            elif issubclass(obj, GenerativeModel):
                model_type = "generative"
            else:
                continue
            if inspect.isabstract(obj):
                # Covers shared base classes and incomplete implementations;
                # instantiating them would raise TypeError.
                skipped_abstract.append(obj.__name__)
                continue
            pending.append((obj(), model_type))

        if not pending:
            detail = (
                f" (skipped abstract classes: {', '.join(skipped_abstract)})"
                if skipped_abstract
                else ""
            )
            raise ValueError(
                f"No ScoringModel or GenerativeModel subclasses found in {path}.{detail}"
            )

        loaded: List[ModelType] = []
        for instance, model_type in pending:
            self._register(instance, model_type, "local", path)
            loaded.append(instance)
        return loaded

    def register_api_scorer(
        self,
        endpoint: str,
        model_name: str,
        api_key: Optional[str] = None,
        description: str = "",
    ) -> APIScoringModel:
        """Register a remote scoring API under ``model_name`` and return its adapter."""
        model = APIScoringModel(
            endpoint=endpoint,
            model_name=model_name,
            api_key=api_key,
            description=description,
        )
        self._register(model, "scoring", "api", endpoint)
        return model

    def register_api_generator(
        self,
        endpoint: str,
        model_name: str,
        api_key: Optional[str] = None,
        description: str = "",
    ) -> APIGenerativeModel:
        """Register a remote generative API under ``model_name`` and return its adapter."""
        model = APIGenerativeModel(
            endpoint=endpoint,
            model_name=model_name,
            api_key=api_key,
            description=description,
        )
        self._register(model, "generative", "api", endpoint)
        return model

    # ── Running ──────────────────────────────────────────────────────────────
    def run_scoring(
        self,
        model_name: str,
        sequences: List[mRNASequence],
    ) -> pd.DataFrame:
        """
        Run a scoring model against sequences and return a DataFrame.
        Columns: id, name, score

        Raises KeyError if the model is unknown and TypeError if it is not
        a scoring model.
        """
        reg = self._get(model_name, "scoring")
        scorer: ScoringModel = reg.model  # type: ignore[assignment]
        scores = scorer.score_batch(sequences)
        return pd.DataFrame({
            "id": [s.id for s in sequences],
            "name": [s.name for s in sequences],
            "score": scores,
        })

    def run_generation(
        self,
        model_name: str,
        constraints: Dict[str, Any],
        n: int = 10,
        seed: Optional[mRNASequence] = None,
    ) -> List[mRNASequence]:
        """
        Run a generative model and return new sequences.

        Raises KeyError if the model is unknown and TypeError if it is not
        a generative model.
        """
        reg = self._get(model_name, "generative")
        generator: GenerativeModel = reg.model  # type: ignore[assignment]
        return generator.generate(constraints, n=n, seed=seed)

    # ── Query ────────────────────────────────────────────────────────────────
    @property
    def scoring_models(self) -> List[RegisteredModel]:
        """All registered scoring-model entries, in registration order."""
        return [r for r in self._models.values() if r.model_type == "scoring"]

    @property
    def generative_models(self) -> List[RegisteredModel]:
        """All registered generative-model entries, in registration order."""
        return [r for r in self._models.values() if r.model_type == "generative"]

    @property
    def all_models(self) -> List[RegisteredModel]:
        """Every registered entry, in registration order."""
        return list(self._models.values())

    def unregister(self, model_name: str) -> bool:
        """Remove ``model_name``; return True if it was registered."""
        return self._models.pop(model_name, None) is not None

    # ── Internal ─────────────────────────────────────────────────────────────
    def _register(
        self,
        model: ModelType,
        model_type: str,
        source: str,
        source_path: str,
    ) -> None:
        """Store ``model`` under its ``name``, replacing any existing entry."""
        self._models[model.name] = RegisteredModel(
            model=model,
            model_type=model_type,
            source=source,
            source_path=source_path,
        )

    def _get(self, name: str, expected_type: str) -> RegisteredModel:
        """Look up ``name`` and check it is of ``expected_type``."""
        try:
            reg = self._models[name]
        except KeyError:
            raise KeyError(f"Model '{name}' not found in registry.") from None
        if reg.model_type != expected_type:
            raise TypeError(
                f"Model '{name}' is a {reg.model_type} model, not {expected_type}."
            )
        return reg