vincenzocivale commited on
Commit
b4ea423
·
1 Parent(s): 9099ade

Refactor: update import paths and model type for scBloodClassifier; remove unified model implementation

Browse files
class_registration.py CHANGED
@@ -1,5 +1,5 @@
1
  # Import the custom classes
2
- from .modeling_unified import scBloodClassifierConfig, scBloodClassifier
3
 
4
  # Import the necessary Auto classes from transformers
5
  from transformers import AutoConfig, AutoModel
 
1
  # Import the custom classes
2
+ from .modeling_scBloodClassifier import scBloodClassifierConfig, scBloodClassifier
3
 
4
  # Import the necessary Auto classes from transformers
5
  from transformers import AutoConfig, AutoModel
config.json CHANGED
@@ -32,7 +32,7 @@
32
  "7": "TRAV1-2- CD8+ T cells",
33
  "8": "gd T cells"
34
  },
35
- "model_type": "unified-cell-classifier",
36
  "sub_classifier_names": [
37
  "B_cells_classifier",
38
  "CD4plus_T_cells_classifier",
 
32
  "7": "TRAV1-2- CD8+ T cells",
33
  "8": "gd T cells"
34
  },
35
+ "model_type": "scBloodClassifier",
36
  "sub_classifier_names": [
37
  "B_cells_classifier",
38
  "CD4plus_T_cells_classifier",
modeling_unified.py → modeling_scBloodClassifier.py RENAMED
@@ -5,6 +5,7 @@ import torch
5
  import torch.nn as nn
6
  from transformers import PretrainedConfig, PreTrainedModel
7
  from transformers.modeling_outputs import SequenceClassifierOutput
 
8
 
9
 
10
  class MLPBlock(nn.Module):
@@ -202,3 +203,9 @@ class scBloodClassifier(PreTrainedModel):
202
  if not os.path.exists(readme_path):
203
  with open(readme_path, "w") as f:
204
  f.write("# scBloodClassifier\nSaved model and config.")
 
 
 
 
 
 
 
5
  import torch.nn as nn
6
  from transformers import PretrainedConfig, PreTrainedModel
7
  from transformers.modeling_outputs import SequenceClassifierOutput
8
+ from transformers import AutoConfig, AutoModel
9
 
10
 
11
  class MLPBlock(nn.Module):
 
203
  if not os.path.exists(readme_path):
204
  with open(readme_path, "w") as f:
205
  f.write("# scBloodClassifier\nSaved model and config.")
206
+
207
+
208
+
209
+ AutoConfig.register("scBloodClassifier", scBloodClassifierConfig)
210
+
211
+ AutoModel.register(scBloodClassifierConfig, scBloodClassifier)