muruga778 commited on
Commit
6f85d0e
·
verified ·
1 Parent(s): 33402b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -1,10 +1,22 @@
1
  import json, os
2
  import numpy as np
3
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
 
7
- import torch
8
  import torch.nn as nn
9
  import timm
10
  from timm.data import resolve_model_data_config, create_transform
 
1
  import json, os
2
  import numpy as np
3
  from PIL import Image
4
+ import torch
5
+
6
+ def peek_out_dim(pt_path):
7
+ sd = clean_state_dict(torch.load(pt_path, map_location="cpu"))
8
+ # timm classifier keys are usually these:
9
+ for k in ["classifier.weight", "head.weight", "fc.weight"]:
10
+ if k in sd:
11
+ return sd[k].shape[0]
12
+ return None
13
+
14
+ print("✅ image checkpoint out dim:", peek_out_dim("best_scin_image.pt"))
15
+ print("✅ text checkpoint out dim :", peek_out_dim("best_scin_text.pt"))
16
+
17
 
18
 
19
 
 
20
  import torch.nn as nn
21
  import timm
22
  from timm.data import resolve_model_data_config, create_transform