Yuto2007 committed on
Commit
7bbf57b
·
verified ·
1 Parent(s): 8cd3669

Update unified_cell_classifier.py

Browse files
Files changed (1) hide show
  1. unified_cell_classifier.py +108 -23
unified_cell_classifier.py CHANGED
@@ -1,21 +1,110 @@
1
  import torch
2
  import torch.nn as nn
 
3
  import json
4
  import os
5
- from typing import Dict, Optional, Tuple
6
  from huggingface_hub import hf_hub_download
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class UnifiedCellClassifier(nn.Module):
9
  def __init__(self,
 
10
  main_classifier_config: Dict = None,
11
  sub_classifiers_config: Dict = None,
12
- sub_classifier_names: list = None,
13
  **kwargs):
14
  """
15
  Args:
 
16
  main_classifier_config: Configurazione per il classificatore principale
17
  sub_classifiers_config: Configurazioni per i sub-classificatori
18
- sub_classifier_names: Lista nomi sub-classificatori
19
  """
20
  super().__init__()
21
 
@@ -40,23 +129,19 @@ class UnifiedCellClassifier(nn.Module):
40
  "1": "CD4plus_T_cells_classifier",
41
  "4": "Myeloid_cells_classifier",
42
  "5": "NK_cells_classifier",
43
- "7": "TRAV1_2_CD8plus_T_cells",
44
  "8": "gd_T_cells_classfier"
45
  }
46
 
47
  def _create_classifier_from_config(self, config: Dict):
48
- """Crea un classificatore dalla configurazione"""
49
- # Esempio di configurazione base - adatta secondo i tuoi modelli
50
- input_dim = config.get('input_dim', 512)
51
- hidden_dim = config.get('hidden_dim', 256)
52
- num_classes = config.get('num_classes', 10)
53
- dropout = config.get('dropout', 0.1)
54
-
55
- return nn.Sequential(
56
- nn.Linear(input_dim, hidden_dim),
57
- nn.ReLU(),
58
- nn.Dropout(dropout),
59
- nn.Linear(hidden_dim, num_classes)
60
  )
61
 
62
  def forward(self, x: torch.Tensor, return_probabilities: bool = False):
@@ -73,9 +158,10 @@ class UnifiedCellClassifier(nn.Module):
73
  if self.main_classifier is None:
74
  raise RuntimeError("Modello non caricato. Usa from_pretrained() per caricare il modello.")
75
 
76
- # Classificazione principale
77
  with torch.no_grad():
78
- main_logits = self.main_classifier(x)
 
79
  main_probs = torch.softmax(main_logits, dim=-1)
80
  main_pred = torch.argmax(main_logits, dim=-1)
81
 
@@ -94,7 +180,8 @@ class UnifiedCellClassifier(nn.Module):
94
  if sub_classifier_name in self.sub_classifiers:
95
  # Usa sub-classificatore
96
  with torch.no_grad():
97
- sub_logits = self.sub_classifiers[sub_classifier_name](x[i:i+1])
 
98
  sub_probs = torch.softmax(sub_logits, dim=-1)
99
  sub_pred = torch.argmax(sub_logits, dim=-1)
100
 
@@ -159,8 +246,7 @@ class UnifiedCellClassifier(nn.Module):
159
  model = cls(**config)
160
 
161
  # 3. Carica il classificatore principale
162
- # Crea l'architettura del main classifier
163
- main_config = config.get('main_classifier_config', {})
164
  model.main_classifier = model._create_classifier_from_config(main_config)
165
 
166
  # Carica i pesi del main classifier
@@ -180,7 +266,7 @@ class UnifiedCellClassifier(nn.Module):
180
  for sub_name in model.sub_classifier_names:
181
  try:
182
  # Crea l'architettura del sub-classificatore
183
- sub_config = config.get('sub_classifiers_config', {}).get(sub_name, {})
184
  model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)
185
 
186
  # Carica i pesi del sub-classificatore
@@ -246,4 +332,3 @@ class UnifiedCellClassifier(nn.Module):
246
  # Salva mapping
247
  with open(os.path.join(save_directory, "macro_to_sub.json"), 'w') as f:
248
  json.dump(self.macro_to_sub, f, indent=2)
249
-
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
  import json
5
  import os
6
+ from typing import Dict, Optional, Tuple, List
7
  from huggingface_hub import hf_hub_download
8
+ from transformers.modeling_outputs import SequenceClassifierOutput
9
+
10
class MLPBlock(nn.Module):
    """One MLP stage: Linear -> BatchNorm1d -> GELU -> Dropout, plus an optional skip.

    The residual (skip) connection is enabled only when it is requested *and*
    shape-compatible, i.e. ``input_dim == output_dim``.
    """

    def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
        super().__init__()
        # A residual add is only valid when the block preserves the feature width.
        self.use_residual = use_residual and (input_dim == output_dim)
        self.linear = nn.Linear(input_dim, output_dim)
        self.bn = nn.BatchNorm1d(output_dim)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the stage to a (batch, input_dim) tensor; returns (batch, output_dim)."""
        out = self.dropout(self.activation(self.bn(self.linear(x))))
        return out + x if self.use_residual else out
28
+
29
class AdvancedMLPClassifier(nn.Module):
    """MLP classifier exposing a HuggingFace-style sequence-classification interface.

    Pipeline: BatchNorm1d on the raw input features -> a stack of ``MLPBlock``
    stages (one per width transition in ``hidden_dims``) -> a final linear
    projection to ``output_dim`` logits. When ``labels`` are given, the
    configured loss (CrossEntropyLoss by default) is computed as well.

    Args:
        input_dim: Width of the input feature vector.
        hidden_dims: Widths of the hidden stages, in order.
        output_dim: Number of output classes (logit width).
        dropout_rate: Dropout probability inside each hidden stage.
        use_residual_in_hidden: Enable skip connections where widths match.
        loss_fn: Loss module; defaults to ``nn.CrossEntropyLoss``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        dropout_rate: float = 0.2,
        use_residual_in_hidden: bool = True,
        loss_fn: Optional[nn.Module] = None
    ):
        super().__init__()
        self.initial_bn = nn.BatchNorm1d(input_dim)

        # Pair consecutive widths and build one MLPBlock per transition.
        widths = [input_dim] + hidden_dims
        stages = [
            MLPBlock(
                input_dim=fan_in,
                output_dim=fan_out,
                dropout_rate=dropout_rate,
                # Residual only when requested and the stage keeps its width.
                use_residual=use_residual_in_hidden and (fan_in == fan_out),
            )
            for fan_in, fan_out in zip(widths, widths[1:])
        ]
        self.hidden_network = nn.Sequential(*stages)
        self.output_projection = nn.Linear(widths[-1], output_dim)
        self.loss_fn = nn.CrossEntropyLoss() if loss_fn is None else loss_fn
        self._initialize_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = True,
        **kwargs
    ) -> SequenceClassifierOutput:
        """Classify feature vectors; ``attention_mask``/``token_type_ids`` are
        accepted for interface compatibility but ignored.

        Returns a ``SequenceClassifierOutput`` (or, when ``return_dict`` is
        False, a ``(logits, loss)`` / ``(logits,)`` tuple).
        """
        # Collapse any extra trailing dimensions into one flat feature vector
        # per sample before the BatchNorm/MLP pipeline.
        if input_ids.ndim > 2:
            input_ids = input_ids.view(input_ids.size(0), -1)

        hidden = self.hidden_network(self.initial_bn(input_ids))
        logits = self.output_projection(hidden)

        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)

        if not return_dict:
            # NOTE: tuple order is (logits, loss) — kept as in the original.
            return (logits, loss) if loss is not None else (logits,)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )

    def _initialize_weights(self):
        """Kaiming-init all linear layers; unit scale / zero shift for BatchNorm."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                # NOTE(review): nonlinearity='relu' is kept even though the
                # hidden stages use GELU — matches the original initialisation.
                nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
96
 
97
  class UnifiedCellClassifier(nn.Module):
98
  def __init__(self,
99
+ sub_classifier_names: list = None,
100
  main_classifier_config: Dict = None,
101
  sub_classifiers_config: Dict = None,
 
102
  **kwargs):
103
  """
104
  Args:
105
+ sub_classifier_names: Lista nomi sub-classificatori
106
  main_classifier_config: Configurazione per il classificatore principale
107
  sub_classifiers_config: Configurazioni per i sub-classificatori
 
108
  """
109
  super().__init__()
110
 
 
129
  "1": "CD4plus_T_cells_classifier",
130
  "4": "Myeloid_cells_classifier",
131
  "5": "NK_cells_classifier",
132
+ "7": "TRAV1_2_CD8plus_T_cells_classifier",
133
  "8": "gd_T_cells_classfier"
134
  }
135
 
136
  def _create_classifier_from_config(self, config: Dict):
137
+ """Crea un AdvancedMLPClassifier dalla configurazione"""
138
+ return AdvancedMLPClassifier(
139
+ input_dim=config['input_dim'],
140
+ hidden_dims=config['hidden_dims'],
141
+ output_dim=config['output_dim'],
142
+ dropout_rate=config.get('dropout_rate', 0.2),
143
+ use_residual_in_hidden=config.get('use_residual_in_hidden', True),
144
+ loss_fn=nn.CrossEntropyLoss()
 
 
 
 
145
  )
146
 
147
  def forward(self, x: torch.Tensor, return_probabilities: bool = False):
 
158
  if self.main_classifier is None:
159
  raise RuntimeError("Modello non caricato. Usa from_pretrained() per caricare il modello.")
160
 
161
+ # Classificazione principale - usa il metodo del classificatore per avere solo logits
162
  with torch.no_grad():
163
+ main_output = self.main_classifier(x, return_dict=True)
164
+ main_logits = main_output.logits
165
  main_probs = torch.softmax(main_logits, dim=-1)
166
  main_pred = torch.argmax(main_logits, dim=-1)
167
 
 
180
  if sub_classifier_name in self.sub_classifiers:
181
  # Usa sub-classificatore
182
  with torch.no_grad():
183
+ sub_output = self.sub_classifiers[sub_classifier_name](x[i:i+1], return_dict=True)
184
+ sub_logits = sub_output.logits
185
  sub_probs = torch.softmax(sub_logits, dim=-1)
186
  sub_pred = torch.argmax(sub_logits, dim=-1)
187
 
 
246
  model = cls(**config)
247
 
248
  # 3. Carica il classificatore principale
249
+ main_config = config['main_classifier_config']
 
250
  model.main_classifier = model._create_classifier_from_config(main_config)
251
 
252
  # Carica i pesi del main classifier
 
266
  for sub_name in model.sub_classifier_names:
267
  try:
268
  # Crea l'architettura del sub-classificatore
269
+ sub_config = config['sub_classifiers_config'][sub_name]
270
  model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)
271
 
272
  # Carica i pesi del sub-classificatore
 
332
  # Salva mapping
333
  with open(os.path.join(save_directory, "macro_to_sub.json"), 'w') as f:
334
  json.dump(self.macro_to_sub, f, indent=2)