"""Biomedical and chemical embedding extraction utilities. Production goals: - deterministic embedding cache for large corpora - safe CPU/GPU fallback - batch extraction for both text and SMILES - simple integration with training and inference pipelines """ from __future__ import annotations import hashlib import json import logging from pathlib import Path from typing import Any, Dict, Iterable, List, Optional import numpy as np try: from transformers import AutoModel, AutoTokenizer import torch except Exception: # pragma: no cover - informative fallback AutoModel = None # type: ignore AutoTokenizer = None # type: ignore torch = None # type: ignore logger = logging.getLogger("medcare_ddi.embeddings") DEFAULT_TEXT_MODELS: Dict[str, str] = { 'biobert': 'dmis-lab/biobert-base-cased-v1.1', 'pubmedbert': 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext', 'sapbert': 'cambridgeltl/SapBERT-from-PubMedBERT-fulltext', } DEFAULT_CHEM_MODEL = 'seyonec/ChemBERTa-zinc-base-v1' def _hash_inputs(inputs: List[str]) -> str: m = hashlib.sha256() for s in inputs: m.update(s.encode('utf-8')) return m.hexdigest()[:32] def _stable_model_name(model_name: str) -> str: return model_name.replace('/', '_').replace(':', '_') def _normalize_texts(items: Iterable[str]) -> List[str]: return [' '.join(str(x).strip().split()) for x in items] def _safe_text(value: str | None) -> str: text = str(value or '').strip() return text if text and text.lower() != 'nan' else '' class EmbeddingService: def __init__(self, cache_dir: Optional[Path] = None, device: Optional[str] = None): self.cache_dir = Path(cache_dir or Path.cwd() / "embeddings_cache") self.cache_dir.mkdir(parents=True, exist_ok=True) self._models: Dict[str, tuple[Any, Any]] = {} if device: self.device = device else: self.device = "cuda" if torch and torch.cuda.is_available() else "cpu" def resolve_model(self, model_name: str) -> str: key = model_name.strip().lower() return DEFAULT_TEXT_MODELS.get(key, model_name) def _load_model(self, model_name: str): if AutoModel is None or AutoTokenizer is None: raise RuntimeError("transformers not available; install 'transformers' and 'torch'") resolved = self.resolve_model(model_name) if resolved in self._models: return self._models[resolved] tokenizer = AutoTokenizer.from_pretrained(resolved, use_fast=True) model = AutoModel.from_pretrained(resolved) model.to(self.device) model.eval() self._models[resolved] = (tokenizer, model) return tokenizer, model def _batch_tokenize(self, tokenizer, texts: List[str], max_length: int = 128): return tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors='pt') def _cache_file(self, modality: str, model_name: str, values: List[str]) -> Path: key_meta = { 'modality': modality, 'model': self.resolve_model(model_name), 'n': len(values), 'hash': _hash_inputs(values), } key = hashlib.sha256(json.dumps(key_meta, sort_keys=True).encode('utf-8')).hexdigest()[:40] return self.cache_dir / f'{modality}_{_stable_model_name(self.resolve_model(model_name))}_{key}.npy' def _mean_pool(self, output, attention_mask): last = output.last_hidden_state mask = attention_mask.unsqueeze(-1) summed = (last * mask).sum(1) counts = mask.sum(1).clamp(min=1) return summed / counts def get_text_embeddings(self, texts: List[str], model_name: str, batch_size: int = 32) -> np.ndarray: """Return embeddings for list of texts using model_name. Embeddings are the mean pooled token embeddings (last hidden state). Results are cached under cache_dir/model_name/.npy """ values = _normalize_texts(texts) if not values: return np.zeros((0, 0), dtype=np.float32) cache_path = self._cache_file('text', model_name, values) if cache_path.exists(): logger.info(f"Loading cached text embeddings: {cache_path}") return np.load(cache_path) tokenizer, model = self._load_model(model_name) all_embs = [] for i in range(0, len(values), batch_size): batch = values[i:i+batch_size] toks = self._batch_tokenize(tokenizer, batch) input_ids = toks['input_ids'].to(self.device) attention_mask = toks['attention_mask'].to(self.device) with torch.no_grad(): out = model(input_ids=input_ids, attention_mask=attention_mask) pooled = self._mean_pool(out, attention_mask).cpu().numpy() all_embs.append(pooled) embs = np.vstack(all_embs).astype(np.float32) np.save(cache_path, embs) logger.info(f"Saved embeddings to cache: {cache_path}") return embs def get_smiles_embeddings(self, smiles: List[str], model_name: str = 'seyonec/ChemBERTa-zinc-base-v1', batch_size: int = 32) -> np.ndarray: """Get SMILES embeddings using a chemistry transformer (ChemBERTa). Model must be compatible with HuggingFace AutoModel. """ values = _normalize_texts(smiles) if not values: return np.zeros((0, 0), dtype=np.float32) resolved = model_name if model_name else DEFAULT_CHEM_MODEL cache_path = self._cache_file('smiles', resolved, values) if cache_path.exists(): logger.info(f"Loading cached SMILES embeddings: {cache_path}") return np.load(cache_path) tokenizer, model = self._load_model(resolved) all_embs = [] for i in range(0, len(values), batch_size): batch = values[i:i+batch_size] toks = tokenizer(batch, padding=True, truncation=True, return_tensors='pt') input_ids = toks['input_ids'].to(self.device) attention_mask = toks['attention_mask'].to(self.device) with torch.no_grad(): out = model(input_ids=input_ids, attention_mask=attention_mask) pooled = self._mean_pool(out, attention_mask).cpu().numpy() all_embs.append(pooled) embs = np.vstack(all_embs).astype(np.float32) np.save(cache_path, embs) logger.info(f"Saved SMILES embeddings to cache: {cache_path}") return embs def get_drug_profile_embeddings( self, names: List[str], active_ingredients: Optional[List[str]] = None, descriptions: Optional[List[str]] = None, smiles: Optional[List[str]] = None, text_model: str = 'pubmedbert', smiles_model: str = DEFAULT_CHEM_MODEL, batch_size: int = 32, ) -> Dict[str, np.ndarray]: """Build multimodal per-drug embeddings from name/ingredient/description/SMILES fields.""" n = len(names) active_ingredients = active_ingredients or [''] * n descriptions = descriptions or [''] * n smiles = smiles or [''] * n if not (len(active_ingredients) == len(descriptions) == len(smiles) == n): raise ValueError('All profile fields must have equal length') names_clean = [_safe_text(x) for x in names] ingredients_clean = [_safe_text(x) for x in active_ingredients] descriptions_clean = [_safe_text(x) for x in descriptions] smiles_clean = [_safe_text(x) for x in smiles] name_emb = self.get_text_embeddings(names_clean, model_name=text_model, batch_size=batch_size) ingredient_emb = self.get_text_embeddings(ingredients_clean, model_name=text_model, batch_size=batch_size) description_emb = self.get_text_embeddings(descriptions_clean, model_name=text_model, batch_size=batch_size) any_smiles = any(bool(s) for s in smiles_clean) if any_smiles: smiles_emb = self.get_smiles_embeddings(smiles_clean, model_name=smiles_model, batch_size=batch_size) else: smiles_emb = np.zeros((n, 0), dtype=np.float32) text_concat = np.hstack([name_emb, ingredient_emb, description_emb]).astype(np.float32) fused = np.hstack([text_concat, smiles_emb]).astype(np.float32) if smiles_emb.size else text_concat return { 'name': name_emb, 'active_ingredient': ingredient_emb, 'description': description_emb, 'smiles': smiles_emb, 'text_concat': text_concat, 'fused': fused, } def benchmark_embedding_models( service: EmbeddingService, texts: List[str], labels: List[int], model_names: List[str], batch_size: int = 64, ) -> Dict[str, Any]: """Quick embedding quality benchmark using 1-NN leave-one-out accuracy.""" if len(texts) != len(labels): raise ValueError('texts and labels length mismatch') if len(texts) < 3: raise ValueError('Need at least 3 samples for benchmark') y = np.array(labels, dtype=np.int64) report: Dict[str, Any] = {'num_samples': len(texts), 'models': {}} for model_name in model_names: emb = service.get_text_embeddings(texts, model_name=model_name, batch_size=batch_size) sim = emb @ emb.T norms = np.linalg.norm(emb, axis=1, keepdims=True) denom = np.clip(norms @ norms.T, 1e-9, None) cosine = sim / denom np.fill_diagonal(cosine, -1.0) nn_idx = np.argmax(cosine, axis=1) pred = y[nn_idx] score = float((pred == y).mean()) report['models'][model_name] = { 'embedding_dim': int(emb.shape[1]), 'nn_accuracy': score, } return report # Convenience initializer _default_service: Optional[EmbeddingService] = None def init_embedding_service(cache_dir: Optional[str] = None, device: Optional[str] = None) -> EmbeddingService: global _default_service if _default_service is None: _default_service = EmbeddingService(cache_dir=Path(cache_dir) if cache_dir else None, device=device) return _default_service if __name__ == '__main__': svc = init_embedding_service() try: em = svc.get_text_embeddings(['Aspirin', 'Warfarin'], 'biobert', batch_size=2) print('Embeddings shape:', em.shape) except Exception as e: print('Embedding test failed:', e)