OpenTransformer committed on
Commit
e1f1c27
·
verified ·
1 Parent(s): 984dbfd

Add cache-busting to refresh button

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -5,7 +5,7 @@ import torch.nn.functional as F
5
  import math
6
  import time
7
  from transformers import AutoTokenizer
8
- from huggingface_hub import hf_hub_download, list_repo_files
9
 
10
  class TuneableAttentionMHA(nn.Module):
11
  def __init__(self, d: int, h: int, r: int):
@@ -85,7 +85,8 @@ class ARHead(nn.Module):
85
  MODEL_REPO = "OpenTransformer/AGILLM-3-large"
86
 
87
  def get_latest_checkpoint():
88
- files = list_repo_files(MODEL_REPO)
 
89
  ckpts = [f for f in files if f.startswith("checkpoints/") and f.endswith(".pt")]
90
  if not ckpts:
91
  raise ValueError("No checkpoints found in repo")
@@ -106,14 +107,14 @@ model_state = {
106
  "vocab": 0,
107
  }
108
 
109
- def load_model(ckpt_name=None):
110
  global model_state
111
 
112
  if ckpt_name is None:
113
  ckpt_name = get_latest_checkpoint()
114
 
115
  print(f"Loading checkpoint: {ckpt_name}")
116
- ckpt_path = hf_hub_download(MODEL_REPO, ckpt_name)
117
  ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
118
  cfg = ckpt["cfg"]
119
 
@@ -163,7 +164,8 @@ def check_for_updates():
163
  current_step = model_state["step"]
164
 
165
  if latest_step > current_step:
166
- new_step, new_name = load_model(latest)
 
167
  return f"✅ Updated! Step {current_step:,} → {new_step:,}"
168
  else:
169
  return f"Already on latest (step {current_step:,})"
 
5
  import math
6
  import time
7
  from transformers import AutoTokenizer
8
+ from huggingface_hub import hf_hub_download, HfApi
9
 
10
  class TuneableAttentionMHA(nn.Module):
11
  def __init__(self, d: int, h: int, r: int):
 
85
  MODEL_REPO = "OpenTransformer/AGILLM-3-large"
86
 
87
  def get_latest_checkpoint():
88
+ api = HfApi()
89
+ files = api.list_repo_files(MODEL_REPO, revision="main")
90
  ckpts = [f for f in files if f.startswith("checkpoints/") and f.endswith(".pt")]
91
  if not ckpts:
92
  raise ValueError("No checkpoints found in repo")
 
107
  "vocab": 0,
108
  }
109
 
110
+ def load_model(ckpt_name=None, force_download=False):
111
  global model_state
112
 
113
  if ckpt_name is None:
114
  ckpt_name = get_latest_checkpoint()
115
 
116
  print(f"Loading checkpoint: {ckpt_name}")
117
+ ckpt_path = hf_hub_download(MODEL_REPO, ckpt_name, force_download=force_download)
118
  ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
119
  cfg = ckpt["cfg"]
120
 
 
164
  current_step = model_state["step"]
165
 
166
  if latest_step > current_step:
167
+ # Force fresh download, bypass cache
168
+ new_step, new_name = load_model(latest, force_download=True)
169
  return f"✅ Updated! Step {current_step:,} → {new_step:,}"
170
  else:
171
  return f"Already on latest (step {current_step:,})"