# modeling_scBloodClassifier.py
import os
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
class MLPBlock(nn.Module):
"""Single MLP block with optional residual connection."""
def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
super().__init__()
self.use_residual = use_residual and (input_dim == output_dim)
self.linear = nn.Linear(input_dim, output_dim)
self.bn = nn.BatchNorm1d(output_dim)
self.activation = nn.GELU()
self.dropout = nn.Dropout(dropout_rate)
def forward(self, x: torch.Tensor) -> torch.Tensor:
identity = x
x = self.linear(x)
x = self.bn(x)
x = self.activation(x)
x = self.dropout(x)
if self.use_residual:
x = x + identity
return x
class MLPClassifier(nn.Module):
    """Feed-forward classifier: input BatchNorm1d, a stack of MLPBlocks,
    and a final linear projection to class logits.

    Residual connections inside the hidden stack are enabled only where
    consecutive hidden widths are equal. When ``labels`` is given to
    ``forward``, the configured loss (cross-entropy by default) is returned
    alongside the logits.
    """
    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.2,
        use_residual_in_hidden: bool = True,
        loss_fn: Optional[nn.Module] = None
    ):
        super().__init__()
        self.initial_bn = nn.BatchNorm1d(input_dim)
        dims = [input_dim, *hidden_dims]
        blocks = []
        # Build consecutive (in, out) pairs; order matters so that RNG-based
        # parameter initialization is reproducible under a fixed seed.
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            blocks.append(
                MLPBlock(
                    input_dim=in_dim,
                    output_dim=out_dim,
                    dropout_rate=dropout_rate,
                    use_residual=use_residual_in_hidden and (in_dim == out_dim),
                )
            )
        self.hidden_network = nn.Sequential(*blocks)
        self.output_projection = nn.Linear(dims[-1], output_dim)
        self.loss_fn = loss_fn or nn.CrossEntropyLoss()
        self._initialize_weights()
    def forward(self, x: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: bool = True):
        """Return logits (and loss when labels are provided).

        Inputs with more than two dimensions are flattened to (batch, features).
        With ``return_dict=False`` the output is a tuple: ``(logits,)`` or
        ``(logits, loss)``.
        """
        if x.ndim > 2:
            x = x.view(x.size(0), -1)
        features = self.hidden_network(self.initial_bn(x))
        logits = self.output_projection(features)
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)
        if return_dict:
            return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=None, attentions=None)
        return (logits,) if loss is None else (logits, loss)
    def _initialize_weights(self):
        """Kaiming-normal init for linear weights, zero biases; unit/zero affine for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
class scBloodClassifierConfig(PretrainedConfig):
    """Hugging Face configuration for :class:`scBloodClassifier`.

    Holds the main-classifier hyper-parameters, one config dict per
    sub-classifier, and the index->name label mappings used at inference.
    All arguments default to empty containers so the attributes always exist.
    """
    model_type = "scBloodClassifier"
    def __init__(
        self,
        sub_classifier_names: Optional[List[str]] = None,
        main_classifier_config: Optional[Dict] = None,
        sub_classifiers_config: Optional[Dict] = None,
        main_labels: Optional[Dict] = None,
        sub_labels: Optional[Dict] = None,
        macro_to_sub: Optional[Dict] = None,
        **kwargs
    ):
        super().__init__(**kwargs)
        # Substitute empty containers for missing/falsy arguments.
        self.main_classifier_config = main_classifier_config or {}
        self.sub_classifiers_config = sub_classifiers_config or {}
        self.sub_classifier_names = sub_classifier_names or []
        self.main_labels = main_labels or {}
        self.sub_labels = sub_labels or {}
        self.macro_to_sub = macro_to_sub or {}
class scBloodClassifier(PreTrainedModel):
    """Hierarchical classifier for single-cell RNA-seq blood data.

    A main (macro) classifier assigns a coarse cell type; for macro classes
    listed in ``config.macro_to_sub`` a dedicated sub-classifier refines the
    prediction into a fine-grained label of the form ``"<macro>_<sub>"``.
    """
    config_class = scBloodClassifierConfig
    def __init__(self, config: scBloodClassifierConfig):
        super().__init__(config)
        self.config = config
        # Coarse (macro) classifier over all cells.
        self.main_classifier = self._create_classifier(config.main_classifier_config)
        # One refinement classifier per declared sub-classifier name.
        self.sub_classifiers = nn.ModuleDict({
            name: self._create_classifier(config.sub_classifiers_config.get(name, {}))
            for name in config.sub_classifier_names
        })
        # Plain-dict copies of the label mappings (string index -> name).
        self.main_labels = dict(config.main_labels)
        self.sub_labels = dict(config.sub_labels)
        self.macro_to_sub = dict(config.macro_to_sub)
        self.post_init()  # required by transformers (init / tying hooks)
    def _create_classifier(self, cfg: Dict) -> MLPClassifier:
        """Build an MLPClassifier from a config dict.

        Raises:
            KeyError: if ``'input_dim'`` or ``'output_dim'`` is missing.
        """
        return MLPClassifier(
            input_dim=cfg['input_dim'],
            hidden_dims=cfg.get('hidden_dims', []),
            output_dim=cfg['output_dim'],
            dropout_rate=cfg.get('dropout_rate', 0.2),
            use_residual_in_hidden=cfg.get('use_residual_in_hidden', True)
        )
    def forward(self, x: torch.Tensor, return_dict: bool = True, **kwargs):
        """Return the main classifier's output.

        Fix: ``labels`` passed via **kwargs (e.g. by a HF Trainer) is now
        forwarded so a loss is computed; previously it was silently dropped.
        Other extra kwargs are still ignored.
        """
        labels = kwargs.pop("labels", None)
        return self.main_classifier(x, labels=labels, return_dict=return_dict)
    def predict_labels(self, x: torch.Tensor, return_probabilities: bool = False) -> Dict[str, Any]:
        """Predict hierarchical labels for a batch of inputs.

        NOTE: puts the model in eval mode as a side effect (preserved from
        the original behavior); callers that were training must restore it.

        Returns a dict with ``"final_predictions"`` (list of label strings)
        and, when ``return_probabilities`` is True, ``"macro_probabilities"``
        (softmax over macro classes) and ``"sub_probabilities"`` (per-sample
        sub-class softmax, or None where no sub-classifier applied).
        """
        self.eval()
        with torch.no_grad():
            main_out = self.main_classifier(x, return_dict=True)
            main_logits = main_out.logits
            main_probs = torch.softmax(main_logits, dim=-1)
            main_pred = torch.argmax(main_logits, dim=-1)
            final_predictions = []
            sub_probs_list = [] if return_probabilities else None
            for i in range(x.shape[0]):
                macro_idx = str(int(main_pred[i].item()))
                macro_label = self.main_labels.get(macro_idx, f"unknown_{macro_idx}")
                # Refine only when a registered sub-classifier exists for
                # this macro class (merges the previously duplicated else
                # branches into one).
                sub_name = self.macro_to_sub.get(macro_idx)
                if sub_name is not None and sub_name in self.sub_classifiers:
                    sub_out = self.sub_classifiers[sub_name](x[i:i+1], return_dict=True)
                    sub_logits = sub_out.logits
                    sub_idx = str(int(torch.argmax(sub_logits, dim=-1).item()))
                    sub_label = self.sub_labels.get(sub_name, {}).get(sub_idx, f"unknown_{sub_idx}")
                    final_label = f"{macro_label}_{sub_label}"
                    if return_probabilities:
                        sub_probs_list.append(torch.softmax(sub_logits, dim=-1)[0])
                else:
                    final_label = macro_label
                    if return_probabilities:
                        sub_probs_list.append(None)
                final_predictions.append(final_label)
        out = {"final_predictions": final_predictions}
        if return_probabilities:
            out["macro_probabilities"] = main_probs
            out["sub_probabilities"] = sub_probs_list
        return out
    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model and config in Hugging Face format.

        Fix: forwards **kwargs (``safe_serialization``, ``state_dict``,
        ``push_to_hub``, ...) to the base implementation instead of
        silently dropping them.
        """
        os.makedirs(save_directory, exist_ok=True)
        # Sync possibly-updated label mappings back into the config first.
        self.config.main_labels = self.main_labels
        self.config.sub_labels = self.sub_labels
        self.config.macro_to_sub = self.macro_to_sub
        super().save_pretrained(save_directory, **kwargs)
        # Write a stub README only if one doesn't already exist.
        readme_path = os.path.join(save_directory, "README.md")
        if not os.path.exists(readme_path):
            with open(readme_path, "w") as f:
                f.write("# scBloodClassifier\nSaved model and config.")
# Register with the transformers Auto* factories so that
# AutoConfig/AutoModel.from_pretrained can resolve model_type "scBloodClassifier".
AutoConfig.register("scBloodClassifier", scBloodClassifierConfig)
AutoModel.register(scBloodClassifierConfig, scBloodClassifier)