farrell236 commited on
Commit
65a5d08
·
verified ·
1 Parent(s): 01a5dcf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -15
app.py CHANGED
@@ -1,25 +1,67 @@
1
  import os
 
 
2
  import torch
3
  from huggingface_hub import hf_hub_download
4
 
5
- MODEL_REPO_ID = "farrell236/CephVIT"
 
6
  MODEL_FILENAME = "best.pt"
7
 
8
- hf_token = os.getenv("HF_TOKEN")
9
- assert hf_token, "HF_TOKEN missing"
10
 
11
- local_path = hf_hub_download(
12
- repo_id=MODEL_REPO_ID,
13
- filename=MODEL_FILENAME,
14
- token=hf_token,
15
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- print("Downloaded to:", local_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- ckpt = torch.load(local_path, map_location="cpu")
20
- print("Type:", type(ckpt))
21
 
22
- if isinstance(ckpt, dict):
23
- print("Top-level keys:", list(ckpt.keys())[:20])
24
- if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
25
- print("state_dict keys:", list(ckpt["state_dict"].keys())[:20])
 
1
  import os
2
+ import traceback
3
+ import gradio as gr
4
  import torch
5
  from huggingface_hub import hf_hub_download
6
 
7
+ # Change this
8
+ MODEL_REPO_ID = "your-username/your-private-model-repo"
9
  MODEL_FILENAME = "best.pt"
10
 
 
 
11
 
12
+ def inspect_checkpoint():
13
+ try:
14
+ hf_token = os.getenv("HF_TOKEN")
15
+ if not hf_token:
16
+ return "ERROR: HF_TOKEN is missing. Add it in Space Settings -> Secrets."
17
+
18
+ local_path = hf_hub_download(
19
+ repo_id=MODEL_REPO_ID,
20
+ filename=MODEL_FILENAME,
21
+ token=hf_token,
22
+ )
23
+
24
+ lines = []
25
+ lines.append("Download successful.")
26
+ lines.append(f"Local path: {local_path}")
27
+
28
+ ckpt = torch.load(local_path, map_location="cpu")
29
+
30
+ lines.append("")
31
+ lines.append(f"Top-level object type: {type(ckpt).__name__}")
32
+
33
+ if isinstance(ckpt, dict):
34
+ top_keys = list(ckpt.keys())
35
+ lines.append(f"Top-level key count: {len(top_keys)}")
36
+ lines.append("Top-level keys:")
37
+ for k in top_keys[:50]:
38
+ lines.append(f" - {k}")
39
 
40
+ if "state_dict" in ckpt and isinstance(ckpt["state_dict"], dict):
41
+ sd_keys = list(ckpt["state_dict"].keys())
42
+ lines.append("")
43
+ lines.append(f"state_dict key count: {len(sd_keys)}")
44
+ lines.append("First 20 state_dict keys:")
45
+ for k in sd_keys[:20]:
46
+ lines.append(f" - {k}")
47
+
48
+ else:
49
+ lines.append("Checkpoint is not a dict, so no keys to print.")
50
+
51
+ return "\n".join(lines)
52
+
53
+ except Exception as e:
54
+ return f"ERROR:\n{type(e).__name__}: {e}\n\n{traceback.format_exc()}"
55
+
56
+
57
+ demo = gr.Interface(
58
+ fn=inspect_checkpoint,
59
+ inputs=None,
60
+ outputs=gr.Textbox(label="Checkpoint inspection", lines=30),
61
+ title="Private checkpoint test",
62
+ description="Checks whether best.pt can be downloaded from a private Hugging Face repo and inspected.",
63
+ )
64
 
 
 
65
 
66
+ if __name__ == "__main__":
67
+ demo.launch()