ddi / src /training /embeddings.py
github-actions[bot]
Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
d29b763
"""Biomedical and chemical embedding extraction utilities.
Production goals:
- deterministic embedding cache for large corpora
- safe CPU/GPU fallback
- batch extraction for both text and SMILES
- simple integration with training and inference pipelines
"""
from __future__ import annotations
import hashlib
import json
import logging
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional
import numpy as np
# Optional heavy dependencies. The module must stay importable without them;
# callers get a RuntimeError from EmbeddingService._load_model instead.
try:
    from transformers import AutoModel, AutoTokenizer
    import torch
except Exception:  # pragma: no cover - informative fallback
    # Any failure in either import disables the whole stack: all three names
    # are nulled together so later availability checks stay consistent.
    AutoModel = AutoTokenizer = torch = None  # type: ignore
# Module logger; uses a fixed dotted name rather than __name__.
logger = logging.getLogger("medcare_ddi.embeddings")
# Short aliases for text encoders, mapped to the model identifiers handed to
# AutoTokenizer/AutoModel.from_pretrained. EmbeddingService.resolve_model
# looks names up here after strip().lower(), so the keys must stay lowercase.
DEFAULT_TEXT_MODELS: Dict[str, str] = {
'biobert': 'dmis-lab/biobert-base-cased-v1.1',
'pubmedbert': 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext',
'sapbert': 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext',
}
# Default model identifier for SMILES embeddings (default of `smiles_model` in
# EmbeddingService.get_drug_profile_embeddings).
DEFAULT_CHEM_MODEL = 'seyonec/ChemBERTa-zinc-base-v1'
def _hash_inputs(inputs: List[str]) -> str:
m = hashlib.sha256()
for s in inputs:
m.update(s.encode('utf-8'))
return m.hexdigest()[:32]
def _stable_model_name(model_name: str) -> str:
return model_name.replace('/', '_').replace(':', '_')
def _normalize_texts(items: Iterable[str]) -> List[str]:
return [' '.join(str(x).strip().split()) for x in items]
def _safe_text(value: str | None) -> str:
text = str(value or '').strip()
return text if text and text.lower() != 'nan' else ''
class EmbeddingService:
    """Extract and cache transformer embeddings for text and SMILES strings.

    Tokenizers and models are loaded on first use through ``AutoTokenizer`` /
    ``AutoModel`` and kept in memory, keyed by resolved model name. Embeddings
    are attention-mask-weighted means of the last hidden state, returned as
    float32 arrays and persisted as ``.npy`` files directly inside
    ``cache_dir``.
    """

    def __init__(self, cache_dir: Optional[Path] = None, device: Optional[str] = None):
        """Create the cache directory and choose the compute device.

        Args:
            cache_dir: Directory for cached ``.npy`` files. Defaults to
                ``<cwd>/embeddings_cache``. Created if missing.
            device: Device string handed to ``model.to``. When omitted,
                ``"cuda"`` is used if torch reports it available, else ``"cpu"``.
        """
        self.cache_dir = Path(cache_dir or Path.cwd() / "embeddings_cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # resolved model name -> (tokenizer, model)
        self._models: Dict[str, tuple[Any, Any]] = {}
        if device:
            self.device = device
        else:
            self.device = "cuda" if torch and torch.cuda.is_available() else "cpu"

    def resolve_model(self, model_name: str) -> str:
        """Map a short alias from ``DEFAULT_TEXT_MODELS`` to its full model name.

        The lookup is case-insensitive and ignores surrounding whitespace;
        names that are not aliases are returned unchanged.
        """
        key = model_name.strip().lower()
        return DEFAULT_TEXT_MODELS.get(key, model_name)

    def _load_model(self, model_name: str):
        """Return ``(tokenizer, model)`` for *model_name*, loading it once.

        Raises:
            RuntimeError: If the optional transformers/torch imports failed.
        """
        if AutoModel is None or AutoTokenizer is None:
            raise RuntimeError("transformers not available; install 'transformers' and 'torch'")
        resolved = self.resolve_model(model_name)
        if resolved in self._models:
            return self._models[resolved]
        tokenizer = AutoTokenizer.from_pretrained(resolved, use_fast=True)
        model = AutoModel.from_pretrained(resolved)
        model.to(self.device)
        model.eval()
        self._models[resolved] = (tokenizer, model)
        return tokenizer, model

    def _batch_tokenize(self, tokenizer, texts: List[str], max_length: Optional[int] = 128):
        """Tokenize *texts* into padded, truncated PyTorch tensors.

        When *max_length* is ``None`` the argument is not passed at all, so
        the tokenizer applies its own limit.
        """
        if max_length is None:
            return tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
        return tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt')

    def _cache_file(self, modality: str, model_name: str, values: List[str]) -> Path:
        """Return the cache path for this (modality, model, inputs) combination.

        The file name is ``<modality>_<model>_<key>.npy`` where ``<key>`` is a
        40-hex-character SHA-256 over the modality, the resolved model name,
        the number of inputs and the input hash.
        """
        resolved = self.resolve_model(model_name)
        key_meta = {
            'modality': modality,
            'model': resolved,
            'n': len(values),
            'hash': _hash_inputs(values),
        }
        key = hashlib.sha256(json.dumps(key_meta, sort_keys=True).encode('utf-8')).hexdigest()[:40]
        return self.cache_dir / f'{modality}_{_stable_model_name(resolved)}_{key}.npy'

    def _mean_pool(self, output, attention_mask):
        """Average ``output.last_hidden_state`` over the positions kept by the mask.

        The per-row count is clamped to at least 1 so an all-zero mask row
        yields zeros instead of a division by zero.
        """
        last = output.last_hidden_state
        mask = attention_mask.unsqueeze(-1)
        summed = (last * mask).sum(1)
        counts = mask.sum(1).clamp(min=1)
        return summed / counts

    def _load_cached(self, cache_path: Path, expected_rows: int) -> Optional[np.ndarray]:
        """Load a cached matrix, or return ``None`` if it cannot be trusted.

        A file that fails to parse (for example a truncated write) or whose
        row count does not match the request is reported and ignored, so the
        caller recomputes instead of crashing on every later call.
        """
        try:
            cached = np.load(cache_path)
        except (OSError, ValueError, EOFError) as exc:
            logger.warning("Ignoring unreadable embedding cache %s: %s", cache_path, exc)
            return None
        if not isinstance(cached, np.ndarray) or cached.ndim != 2 or cached.shape[0] != expected_rows:
            logger.warning("Ignoring embedding cache with unexpected shape: %s", cache_path)
            return None
        return cached

    def _save_cached(self, cache_path: Path, embs: np.ndarray) -> None:
        """Write *embs* to *cache_path* through a temporary file and a rename.

        Readers therefore never see a partially written ``.npy`` file. Errors
        from writing or renaming propagate to the caller.
        """
        import os  # local import: only needed for the per-process temp name

        tmp_path = cache_path.with_name(f'{cache_path.name}.{os.getpid()}.tmp')
        try:
            with open(tmp_path, 'wb') as handle:
                np.save(handle, embs)
            tmp_path.replace(cache_path)
        finally:
            # Only still present if the write or the rename failed.
            if tmp_path.exists():
                tmp_path.unlink()

    def _encode(self, values: List[str], model_name: str, batch_size: int, max_length: Optional[int]) -> np.ndarray:
        """Run the model over *values* in batches and return mean-pooled float32 rows.

        Raises:
            ValueError: If *batch_size* is not positive.
            RuntimeError: If transformers/torch are unavailable.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')
        tokenizer, model = self._load_model(model_name)
        pooled_batches = []
        for start in range(0, len(values), batch_size):
            batch = values[start:start + batch_size]
            toks = self._batch_tokenize(tokenizer, batch, max_length=max_length)
            input_ids = toks['input_ids'].to(self.device)
            attention_mask = toks['attention_mask'].to(self.device)
            with torch.no_grad():
                out = model(input_ids=input_ids, attention_mask=attention_mask)
            pooled_batches.append(self._mean_pool(out, attention_mask).cpu().numpy())
        return np.vstack(pooled_batches).astype(np.float32)

    def get_text_embeddings(self, texts: List[str], model_name: str, batch_size: int = 32) -> np.ndarray:
        """Return mean-pooled embeddings for *texts* using *model_name*.

        Texts are whitespace-normalized first and truncated to 128 tokens.
        Results are cached in ``cache_dir`` as
        ``text_<model>_<key>.npy``; an unreadable cache file is recomputed.

        Args:
            texts: Input strings.
            model_name: Alias from ``DEFAULT_TEXT_MODELS`` or a full model name.
            batch_size: Number of texts per forward pass.

        Returns:
            A float32 array with one row per text, or an array of shape
            ``(0, 0)`` when *texts* is empty.
        """
        values = _normalize_texts(texts)
        if not values:
            return np.zeros((0, 0), dtype=np.float32)
        cache_path = self._cache_file('text', model_name, values)
        if cache_path.exists():
            logger.info("Loading cached text embeddings: %s", cache_path)
            cached = self._load_cached(cache_path, len(values))
            if cached is not None:
                return cached
        embs = self._encode(values, model_name, batch_size, max_length=128)
        self._save_cached(cache_path, embs)
        logger.info("Saved embeddings to cache: %s", cache_path)
        return embs

    def get_smiles_embeddings(self, smiles: List[str], model_name: str = DEFAULT_CHEM_MODEL, batch_size: int = 32) -> np.ndarray:
        """Return mean-pooled embeddings for SMILES strings.

        The model must be loadable through HuggingFace ``AutoModel``. An empty
        *model_name* falls back to ``DEFAULT_CHEM_MODEL``. No explicit token
        limit is passed, so truncation uses the tokenizer's own maximum.
        Results are cached in ``cache_dir`` as ``smiles_<model>_<key>.npy``.

        Returns:
            A float32 array with one row per input, or an array of shape
            ``(0, 0)`` when *smiles* is empty.
        """
        values = _normalize_texts(smiles)
        if not values:
            return np.zeros((0, 0), dtype=np.float32)
        resolved = model_name if model_name else DEFAULT_CHEM_MODEL
        cache_path = self._cache_file('smiles', resolved, values)
        if cache_path.exists():
            logger.info("Loading cached SMILES embeddings: %s", cache_path)
            cached = self._load_cached(cache_path, len(values))
            if cached is not None:
                return cached
        embs = self._encode(values, resolved, batch_size, max_length=None)
        self._save_cached(cache_path, embs)
        logger.info("Saved SMILES embeddings to cache: %s", cache_path)
        return embs

    def get_drug_profile_embeddings(
        self,
        names: List[str],
        active_ingredients: Optional[List[str]] = None,
        descriptions: Optional[List[str]] = None,
        smiles: Optional[List[str]] = None,
        text_model: str = 'pubmedbert',
        smiles_model: str = DEFAULT_CHEM_MODEL,
        batch_size: int = 32,
    ) -> Dict[str, np.ndarray]:
        """Build multimodal per-drug embeddings from name/ingredient/description/SMILES fields.

        Missing or empty optional fields are replaced by a list of empty
        strings as long as *names*. SMILES are only embedded when at least one
        cleaned SMILES string is non-empty; otherwise the ``'smiles'`` entry
        has shape ``(n, 0)`` and ``'fused'`` equals ``'text_concat'``.

        Returns:
            A dict with keys ``'name'``, ``'active_ingredient'``,
            ``'description'``, ``'smiles'``, ``'text_concat'`` (the three text
            blocks side by side) and ``'fused'`` (text plus SMILES).

        Raises:
            ValueError: If the provided fields differ in length.
        """
        n = len(names)
        active_ingredients = active_ingredients or [''] * n
        descriptions = descriptions or [''] * n
        smiles = smiles or [''] * n
        if not (len(active_ingredients) == len(descriptions) == len(smiles) == n):
            raise ValueError('All profile fields must have equal length')

        names_clean = [_safe_text(x) for x in names]
        ingredients_clean = [_safe_text(x) for x in active_ingredients]
        descriptions_clean = [_safe_text(x) for x in descriptions]
        smiles_clean = [_safe_text(x) for x in smiles]

        name_emb = self.get_text_embeddings(names_clean, model_name=text_model, batch_size=batch_size)
        ingredient_emb = self.get_text_embeddings(ingredients_clean, model_name=text_model, batch_size=batch_size)
        description_emb = self.get_text_embeddings(descriptions_clean, model_name=text_model, batch_size=batch_size)

        if any(smiles_clean):
            smiles_emb = self.get_smiles_embeddings(smiles_clean, model_name=smiles_model, batch_size=batch_size)
        else:
            # Zero-width placeholder keeps the row count without loading the chemistry model.
            smiles_emb = np.zeros((n, 0), dtype=np.float32)

        text_concat = np.hstack([name_emb, ingredient_emb, description_emb]).astype(np.float32)
        fused = np.hstack([text_concat, smiles_emb]).astype(np.float32) if smiles_emb.size else text_concat
        return {
            'name': name_emb,
            'active_ingredient': ingredient_emb,
            'description': description_emb,
            'smiles': smiles_emb,
            'text_concat': text_concat,
            'fused': fused,
        }
def benchmark_embedding_models(
    service: EmbeddingService,
    texts: List[str],
    labels: List[int],
    model_names: List[str],
    batch_size: int = 64,
) -> Dict[str, Any]:
    """Quick embedding quality benchmark using 1-NN leave-one-out accuracy.

    For each model, every sample is assigned the label of its most
    cosine-similar *other* sample; the score is the fraction of samples whose
    assigned label matches their own.

    Args:
        service: Object providing ``get_text_embeddings(texts, model_name=,
            batch_size=)`` that returns one embedding row per text.
        texts: Samples to embed.
        labels: Integer class label per sample.
        model_names: Models to evaluate.
        batch_size: Batch size forwarded to the service.

    Returns:
        ``{'num_samples': int, 'models': {name: {'embedding_dim': int,
        'nn_accuracy': float}}}``.

    Raises:
        ValueError: If *texts* and *labels* differ in length or fewer than
            three samples are given.
    """
    if len(texts) != len(labels):
        raise ValueError('texts and labels length mismatch')
    if len(texts) < 3:
        raise ValueError('Need at least 3 samples for benchmark')
    y = np.array(labels, dtype=np.int64)
    report: Dict[str, Any] = {'num_samples': len(texts), 'models': {}}
    for model_name in model_names:
        emb = service.get_text_embeddings(texts, model_name=model_name, batch_size=batch_size)
        sim = emb @ emb.T
        norms = np.linalg.norm(emb, axis=1, keepdims=True)
        # Clip so zero-norm rows give a similarity of 0 rather than NaN.
        denom = np.clip(norms @ norms.T, 1e-9, None)
        cosine = sim / denom
        # Exclude self-matches with -inf. A finite sentinel such as -1.0 is a
        # reachable cosine value, so a sample whose other similarities are all
        # -1 would pick itself and be scored as correct.
        np.fill_diagonal(cosine, -np.inf)
        nn_idx = np.argmax(cosine, axis=1)
        pred = y[nn_idx]
        score = float((pred == y).mean())
        report['models'][model_name] = {
            'embedding_dim': int(emb.shape[1]),
            'nn_accuracy': score,
        }
    return report
# Convenience initializer: process-wide shared service, created on first use.
_default_service: Optional[EmbeddingService] = None


def init_embedding_service(cache_dir: Optional[str] = None, device: Optional[str] = None) -> EmbeddingService:
    """Return the shared ``EmbeddingService``, creating it on the first call.

    Only the first call's arguments configure the service. Later calls return
    the existing instance; if they pass a different ``cache_dir`` or
    ``device``, a warning is logged because those values are not applied.

    Args:
        cache_dir: Cache directory for the service (first call only).
        device: Device string for the service (first call only).

    Returns:
        The shared service instance.
    """
    global _default_service
    if _default_service is None:
        _default_service = EmbeddingService(cache_dir=Path(cache_dir) if cache_dir else None, device=device)
        return _default_service
    # Already initialised: make ignored arguments visible instead of silent.
    if cache_dir and Path(cache_dir) != _default_service.cache_dir:
        logger.warning(
            "Embedding service already initialised with cache_dir=%s; ignoring cache_dir=%s",
            _default_service.cache_dir,
            cache_dir,
        )
    if device and device != _default_service.device:
        logger.warning(
            "Embedding service already initialised with device=%s; ignoring device=%s",
            _default_service.device,
            device,
        )
    return _default_service
if __name__ == '__main__':
    # Manual smoke test: embed two drug names with the 'biobert' alias and
    # print the resulting shape. Any failure is reported, not raised.
    service = init_embedding_service()
    try:
        vectors = service.get_text_embeddings(['Aspirin', 'Warfarin'], 'biobert', batch_size=2)
        shape = vectors.shape
        print('Embeddings shape:', shape)
    except Exception as err:
        print('Embedding test failed:', err)