Ni-os commited on
Commit
5db33a8
·
verified ·
1 Parent(s): 074701b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +60 -1
README.md CHANGED
@@ -19,4 +19,63 @@ from model_loader import load_cell_type_model
19
  model = load_cell_type_model("hepg2")
20
 
21
  # Load model for K562
22
- model = load_cell_type_model("k562")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  model = load_cell_type_model("hepg2")
20
 
21
  # Load model for K562
22
+ model = load_cell_type_model("k562")
23
+ ```
24
+ ### If you want to download weights
25
+
26
+ ```python
27
+ def get_device():
28
+ """Automatically detects available device"""
29
+ if torch.cuda.is_available():
30
+ return torch.device("cuda")
31
+ else:
32
+ return torch.device("cpu")
33
+
34
+ # Load Pre-Trained Model Weights for Human Legnet
35
+ def download_and_load_model(cell_type="k562", repo_id="Ni-os/Human_Legnet", device=None):
36
+ # Download main config
37
+ config_path = hf_hub_download(
38
+ repo_id=repo_id,
39
+ filename="config.json"
40
+ )
41
+
42
+ # Load config
43
+ with open(config_path, 'r') as f:
44
+ config = json.load(f)
45
+
46
+ # Create model
47
+ model = LegNet(
48
+ in_ch=config["in_ch"],
49
+ stem_ch=config["stem_ch"],
50
+ stem_ks=config["stem_ks"],
51
+ ef_ks=config["ef_ks"],
52
+ ef_block_sizes=config["ef_block_sizes"],
53
+ pool_sizes=config["pool_sizes"],
54
+ resize_factor=config["resize_factor"],
55
+ activation=torch.nn.SiLU
56
+ ).to(device)
57
+
58
+ # Determine which weight file to download
59
+ weight_files = {
60
+ "hepg2": "weights/hepg2_best_model_test1_val2.safetensors",
61
+ "k562": "weights/k562_best_model_test1_val2.safetensors",
62
+ "wtc11": "weights/wtc11_best_model_test1_val2.safetensors"
63
+ }
64
+
65
+ # Download weights
66
+ weights_path = hf_hub_download(
67
+ repo_id=repo_id,
68
+ filename=weight_files[cell_type.lower()]
69
+ )
70
+
71
+ # Load weights into model
72
+ state_dict = load_file(weights_path)
73
+ model.load_state_dict(state_dict)
74
+ model.eval()
75
+ print(f"✅ Model for {cell_type} loaded!")
76
+ return model
77
+
78
+ device = get_device()
79
+
80
+ print("Loading pre-trained model weights for Human Legnet")
81
+ model_human_legnet = download_and_load_model("hepg2", device = device)