vincenzocivale commited on
Commit
649a4ee
·
1 Parent(s): 690e94b

update Model definition

Browse files
__pycache__/unified_cell_classifier.cpython-39.pyc ADDED
Binary file (9.8 kB). View file
 
main_classifier/{main_classifier.bin → main_classifier.safetensors} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:79cb951a02169f43bca67e28214d08bd05b19ef7a364303a85d46937dcce4b58
3
- size 6897188
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb582f198c991edc4e69bf06fb9e4290f632b560035ca5676de2097908a03935
3
+ size 6891444
sub_classifiers/B_cells_classifier/{B_cells_classifier.bin → B_cells_classifier.safetensors} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:ae7e1c2b1e5b16348e4c2a30f21308a0008c7f1cc7133ab77380c951a0a9b462
3
- size 61535778
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcbdf02bca494f0d37ec11e5479a26425318bd126d57adb7e097531dc1749e15
3
+ size 61528716
sub_classifiers/CD4plus_T_cells_classifier/{CD4plus_T_cells_classifier.bin → CD4plus_T_cells_classifier.safetensors} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:50d707e52f7ec3e6c88162cbd2bdbdb7edebd480118efea621c7e1ca598d1dc9
3
- size 61557666
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0490f29c9b3afd6c75e3fc351e79c0ed763ab69aba69bb28d506ce16f38391df
3
+ size 61550248
sub_classifiers/Myeloid_cells_classifier/{Myeloid_cells_classifier.bin → Myeloid_cells_classifier.safetensors} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:511f81b745539fc0451c95d5f2e4676daacf787e61cca9bd516a324afdf66d0b
3
- size 61520674
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e22fcbad636628b6ea69bd09a8732e92913379e4dd65f19e799f8d09963892b2
3
+ size 61513336
sub_classifiers/NK_cells_classifier/NK_cells_classifier.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:69cfb480ec85a7aa202eb76b543594f9d606e4d7450d18907d1656950455df10
3
- size 61526594
 
 
 
 
sub_classifiers/NK_cells_classifier/NK_cells_classifier.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7d24b289cf592458e92e27b56c9c7e48f369bb75fa7fe99d1c7f35ba2f823b2
3
+ size 61519488
sub_classifiers/TRAV1_2_CD8plus_T_cells_classifier/TRAV1_2_CD8plus_T_cells_classifier.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:bcc845ffadc317a0910451a0f8170882d02b181bd87ba0a3d0a28b93b090aacd
3
- size 61545634
 
 
 
 
sub_classifiers/TRAV1_2_CD8plus_T_cells_classifier/TRAV1_2_CD8plus_T_cells_classifier.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3fb30d1c53a4c15ba723bac351c35b242528807d10fcf32e91316d5d51e9178
3
+ size 61537944
sub_classifiers/gd_T_cells_classfier/gd_T_cells_classfier.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ab79c1c4bbce277559974482e6b6911af8869d33cc3e3557ff81b5dc1393954
3
- size 61523554
 
 
 
 
sub_classifiers/gd_T_cells_classfier/gd_T_cells_classfier.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1ce3165452a3bba01d8981f2c5ae4ba0bda8a41f7f6b34682e10dabf6d9ea7a9
3
+ size 61516412
unified_cell_classifier.py CHANGED
@@ -6,6 +6,7 @@ import os
6
  from typing import Dict, Optional, Tuple, List
7
  from huggingface_hub import hf_hub_download
8
  from transformers.modeling_outputs import SequenceClassifierOutput
 
9
 
10
  class MLPBlock(nn.Module):
11
  def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
@@ -250,8 +251,8 @@ class UnifiedCellClassifier(nn.Module):
250
  model.main_classifier = model._create_classifier_from_config(main_config)
251
 
252
  # Carica i pesi del main classifier
253
- main_weights_path = get_file_path("main_classifier/main_classifier.bin")
254
- main_state_dict = torch.load(main_weights_path)
255
  model.main_classifier.load_state_dict(main_state_dict, strict=False)
256
 
257
  # Carica le label del main classifier
@@ -269,9 +270,8 @@ class UnifiedCellClassifier(nn.Module):
269
  sub_config = config['sub_classifiers_config'][sub_name]
270
  model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)
271
 
272
- # Carica i pesi del sub-classificatore
273
- sub_weights_path = get_file_path(f"sub_classifiers/{sub_name}/{sub_name}.bin")
274
- sub_state_dict = torch.load(sub_weights_path)
275
  model.sub_classifiers[sub_name].load_state_dict(sub_state_dict, strict=False)
276
 
277
  # Carica le label del sub-classificatore
 
6
  from typing import Dict, Optional, Tuple, List
7
  from huggingface_hub import hf_hub_download
8
  from transformers.modeling_outputs import SequenceClassifierOutput
9
+ from safetensors.torch import load_file
10
 
11
  class MLPBlock(nn.Module):
12
  def __init__(self, input_dim: int, output_dim: int, dropout_rate: float = 0.2, use_residual: bool = False):
 
251
  model.main_classifier = model._create_classifier_from_config(main_config)
252
 
253
  # Carica i pesi del main classifier
254
+ main_weights_path = os.path.join(repo_id_or_path, "main_classifier/main_classifier.safetensors")
255
+ main_state_dict = load_file(main_weights_path)
256
  model.main_classifier.load_state_dict(main_state_dict, strict=False)
257
 
258
  # Carica le label del main classifier
 
270
  sub_config = config['sub_classifiers_config'][sub_name]
271
  model.sub_classifiers[sub_name] = model._create_classifier_from_config(sub_config)
272
 
273
+ sub_weights_path = os.path.join(repo_id_or_path, f"sub_classifiers/{sub_name}/{sub_name}.safetensors")
274
+ sub_state_dict = load_file(sub_weights_path)
 
275
  model.sub_classifiers[sub_name].load_state_dict(sub_state_dict, strict=False)
276
 
277
  # Carica le label del sub-classificatore