Spaces:
Sleeping
Sleeping
Add cache-busting to refresh button
Browse files
app.py
CHANGED
|
@@ -5,7 +5,7 @@ import torch.nn.functional as F
|
|
| 5 |
import math
|
| 6 |
import time
|
| 7 |
from transformers import AutoTokenizer
|
| 8 |
-
from huggingface_hub import hf_hub_download, list_repo_files
|
| 9 |
|
| 10 |
class TuneableAttentionMHA(nn.Module):
|
| 11 |
def __init__(self, d: int, h: int, r: int):
|
|
@@ -85,7 +85,8 @@ class ARHead(nn.Module):
|
|
| 85 |
MODEL_REPO = "OpenTransformer/AGILLM-3-large"
|
| 86 |
|
| 87 |
def get_latest_checkpoint():
|
| 88 |
-
files = list_repo_files(MODEL_REPO)
|
|
|
|
| 89 |
ckpts = [f for f in files if f.startswith("checkpoints/") and f.endswith(".pt")]
|
| 90 |
if not ckpts:
|
| 91 |
raise ValueError("No checkpoints found in repo")
|
|
@@ -106,14 +107,14 @@ model_state = {
|
|
| 106 |
"vocab": 0,
|
| 107 |
}
|
| 108 |
|
| 109 |
-
def load_model(ckpt_name=None):
|
| 110 |
global model_state
|
| 111 |
|
| 112 |
if ckpt_name is None:
|
| 113 |
ckpt_name = get_latest_checkpoint()
|
| 114 |
|
| 115 |
print(f"Loading checkpoint: {ckpt_name}")
|
| 116 |
-
ckpt_path = hf_hub_download(MODEL_REPO, ckpt_name)
|
| 117 |
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 118 |
cfg = ckpt["cfg"]
|
| 119 |
|
|
@@ -163,7 +164,8 @@ def check_for_updates():
|
|
| 163 |
current_step = model_state["step"]
|
| 164 |
|
| 165 |
if latest_step > current_step:
|
| 166 |
-
new_step, new_name = load_model(latest)
|
|
|
|
| 167 |
return f"✅ Updated! Step {current_step:,} → {new_step:,}"
|
| 168 |
else:
|
| 169 |
return f"Already on latest (step {current_step:,})"
|
|
|
|
| 5 |
import math
|
| 6 |
import time
|
| 7 |
from transformers import AutoTokenizer
|
| 8 |
+
from huggingface_hub import hf_hub_download, HfApi
|
| 9 |
|
| 10 |
class TuneableAttentionMHA(nn.Module):
|
| 11 |
def __init__(self, d: int, h: int, r: int):
|
|
|
|
| 85 |
MODEL_REPO = "OpenTransformer/AGILLM-3-large"
|
| 86 |
|
| 87 |
def get_latest_checkpoint():
|
| 88 |
+
api = HfApi()
|
| 89 |
+
files = api.list_repo_files(MODEL_REPO, revision="main")
|
| 90 |
ckpts = [f for f in files if f.startswith("checkpoints/") and f.endswith(".pt")]
|
| 91 |
if not ckpts:
|
| 92 |
raise ValueError("No checkpoints found in repo")
|
|
|
|
| 107 |
"vocab": 0,
|
| 108 |
}
|
| 109 |
|
| 110 |
+
def load_model(ckpt_name=None, force_download=False):
|
| 111 |
global model_state
|
| 112 |
|
| 113 |
if ckpt_name is None:
|
| 114 |
ckpt_name = get_latest_checkpoint()
|
| 115 |
|
| 116 |
print(f"Loading checkpoint: {ckpt_name}")
|
| 117 |
+
ckpt_path = hf_hub_download(MODEL_REPO, ckpt_name, force_download=force_download)
|
| 118 |
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 119 |
cfg = ckpt["cfg"]
|
| 120 |
|
|
|
|
| 164 |
current_step = model_state["step"]
|
| 165 |
|
| 166 |
if latest_step > current_step:
|
| 167 |
+
# Force fresh download, bypass cache
|
| 168 |
+
new_step, new_name = load_model(latest, force_download=True)
|
| 169 |
return f"✅ Updated! Step {current_step:,} → {new_step:,}"
|
| 170 |
else:
|
| 171 |
return f"Already on latest (step {current_step:,})"
|