vincenzocivale
commited on
Commit
·
649a4ee
1
Parent(s):
690e94b
update Model definition
Browse files- __pycache__/unified_cell_classifier.cpython-39.pyc +0 -0
- main_classifier/{main_classifier.bin → main_classifier.safetensors} +2 -2
- sub_classifiers/B_cells_classifier/{B_cells_classifier.bin → B_cells_classifier.safetensors} +2 -2
- sub_classifiers/CD4plus_T_cells_classifier/{CD4plus_T_cells_classifier.bin → CD4plus_T_cells_classifier.safetensors} +2 -2
- sub_classifiers/Myeloid_cells_classifier/{Myeloid_cells_classifier.bin → Myeloid_cells_classifier.safetensors} +2 -2
- sub_classifiers/NK_cells_classifier/NK_cells_classifier.bin +0 -3
- sub_classifiers/NK_cells_classifier/NK_cells_classifier.safetensors +3 -0
- sub_classifiers/TRAV1_2_CD8plus_T_cells_classifier/TRAV1_2_CD8plus_T_cells_classifier.bin +0 -3
- sub_classifiers/TRAV1_2_CD8plus_T_cells_classifier/TRAV1_2_CD8plus_T_cells_classifier.safetensors +3 -0
- sub_classifiers/gd_T_cells_classfier/gd_T_cells_classfier.bin +0 -3
- sub_classifiers/gd_T_cells_classfier/gd_T_cells_classfier.safetensors +3 -0
- unified_cell_classifier.py +5 -5
__pycache__/unified_cell_classifier.cpython-39.pyc
ADDED
|
Binary file (9.8 kB). View file
|
|
|
main_classifier/{main_classifier.bin → main_classifier.safetensors}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fb582f198c991edc4e69bf06fb9e4290f632b560035ca5676de2097908a03935
|
| 3 |
+
size 6891444
|
sub_classifiers/B_cells_classifier/{B_cells_classifier.bin → B_cells_classifier.safetensors}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fcbdf02bca494f0d37ec11e5479a26425318bd126d57adb7e097531dc1749e15
|
| 3 |
+
size 61528716
|
sub_classifiers/CD4plus_T_cells_classifier/{CD4plus_T_cells_classifier.bin → CD4plus_T_cells_classifier.safetensors}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0490f29c9b3afd6c75e3fc351e79c0ed763ab69aba69bb28d506ce16f38391df
|
| 3 |
+
size 61550248
|
sub_classifiers/Myeloid_cells_classifier/{Myeloid_cells_classifier.bin → Myeloid_cells_classifier.safetensors}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e22fcbad636628b6ea69bd09a8732e92913379e4dd65f19e799f8d09963892b2
|
| 3 |
+
size 61513336
|
sub_classifiers/NK_cells_classifier/NK_cells_classifier.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:69cfb480ec85a7aa202eb76b543594f9d606e4d7450d18907d1656950455df10
|
| 3 |
-
size 61526594
|
|
|
|
|
|
|
|
|
|
|
|
sub_classifiers/NK_cells_classifier/NK_cells_classifier.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c7d24b289cf592458e92e27b56c9c7e48f369bb75fa7fe99d1c7f35ba2f823b2
|
| 3 |
+
size 61519488
|
sub_classifiers/TRAV1_2_CD8plus_T_cells_classifier/TRAV1_2_CD8plus_T_cells_classifier.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:bcc845ffadc317a0910451a0f8170882d02b181bd87ba0a3d0a28b93b090aacd
|
| 3 |
-
size 61545634
|
|
|
|
|
|
|
|
|
|
|
|
sub_classifiers/TRAV1_2_CD8plus_T_cells_classifier/TRAV1_2_CD8plus_T_cells_classifier.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3fb30d1c53a4c15ba723bac351c35b242528807d10fcf32e91316d5d51e9178
|
| 3 |
+
size 61537944
|
sub_classifiers/gd_T_cells_classfier/gd_T_cells_classfier.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:4ab79c1c4bbce277559974482e6b6911af8869d33cc3e3557ff81b5dc1393954
|
| 3 |
-
size 61523554
|
|
|
|
|
|
|
|
|
|
|
|
sub_classifiers/gd_T_cells_classfier/gd_T_cells_classfier.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1ce3165452a3bba01d8981f2c5ae4ba0bda8a41f7f6b34682e10dabf6d9ea7a9
|
| 3 |
+
size 61516412
|
unified_cell_classifier.py
CHANGED
|
@@ -6,6 +6,7 @@ import os
|
|
| 6 |
from typing import Dict, Optional, Tuple, List
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
|
|
|
| 9 |
|
| 10 |
class MLPBlock(nn.Module):
|
| 11 |
def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
|
|
@@ -250,8 +251,8 @@ class UnifiedCellClassifier(nn.Module):
|
|
| 250 |
model.main_classifier = model._create_classifier_from_config(main_config)
|
| 251 |
|
| 252 |
# Carica i pesi del main classifier
|
| 253 |
-
main_weights_path =
|
| 254 |
-
main_state_dict =
|
| 255 |
model.main_classifier.load_state_dict(main_state_dict, strict=False)
|
| 256 |
|
| 257 |
# Carica le label del main classifier
|
|
@@ -269,9 +270,8 @@ class UnifiedCellClassifier(nn.Module):
|
|
| 269 |
sub_config = config['sub_classifiers_config'][sub_name]
|
| 270 |
model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)
|
| 271 |
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
sub_state_dict = torch.load(sub_weights_path)
|
| 275 |
model.sub_classifiers[sub_name].load_state_dict(sub_state_dict, strict=False)
|
| 276 |
|
| 277 |
# Carica le label del sub-classificatore
|
|
|
|
| 6 |
from typing import Dict, Optional, Tuple, List
|
| 7 |
from huggingface_hub import hf_hub_download
|
| 8 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 9 |
+
from safetensors.torch import load_file
|
| 10 |
|
| 11 |
class MLPBlock(nn.Module):
|
| 12 |
def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
|
|
|
|
| 251 |
model.main_classifier = model._create_classifier_from_config(main_config)
|
| 252 |
|
| 253 |
# Carica i pesi del main classifier
|
| 254 |
+
main_weights_path = os.path.join(repo_id_or_path, "main_classifier/main_classifier.safetensors")
|
| 255 |
+
main_state_dict = load_file(main_weights_path)
|
| 256 |
model.main_classifier.load_state_dict(main_state_dict, strict=False)
|
| 257 |
|
| 258 |
# Carica le label del main classifier
|
|
|
|
| 270 |
sub_config = config['sub_classifiers_config'][sub_name]
|
| 271 |
model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)
|
| 272 |
|
| 273 |
+
sub_weights_path = os.path.join(repo_id_or_path, f"sub_classifiers/{sub_name}/{sub_name}.safetensors")
|
| 274 |
+
sub_state_dict = load_file(sub_weights_path)
|
|
|
|
| 275 |
model.sub_classifiers[sub_name].load_state_dict(sub_state_dict, strict=False)
|
| 276 |
|
| 277 |
# Carica le label del sub-classificatore
|