Spaces:
Running
Running
File size: 2,511 Bytes
"""
Upload trained model artifacts to the Hugging Face Hub.
Run: python -m scripts.upload_models
"""
import os
import sys
# Put the directory above this script at the front of sys.path so the local
# `app` package (imported below) can be found when the file is run directly.
_PARENT_DIR = os.path.join(os.path.dirname(__file__), "..")
sys.path.insert(0, _PARENT_DIR)

# python-dotenv is optional: when it is installed, load variables from a .env
# file into the environment; when it is not, rely on variables that were
# already exported in the shell.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv()
from huggingface_hub import HfApi
from app.config import HF_TOKEN, MODEL_REPO, MODEL_DIR
# Artifacts to publish: filename inside MODEL_DIR -> path inside the Hub repo.
# Exposed as a list of (local_name, repo_name) pairs, which upload() iterates.
ARTIFACTS = list(
    {
        "tft_stock.ckpt": "tft_stock.ckpt",
        "dataset_params.pt": "dataset_params.pt",
        "ddg_da.pt": "ddg_da.pt",
    }.items()
)
def upload():
    """Upload the trained model artifacts to the Hugging Face Hub.

    Uses ``HF_TOKEN``, ``MODEL_REPO`` and ``MODEL_DIR`` from ``app.config``.
    The steps are:

    1. Check the token with ``whoami``.
    2. Ensure ``MODEL_REPO`` exists as a private model repo
       (``create_repo(..., exist_ok=True, private=True)``).
    3. Upload each ``ARTIFACTS`` entry found under ``MODEL_DIR``, each as its
       own commit.

    Artifacts missing from ``MODEL_DIR`` are skipped with a notice. Every
    present artifact is attempted even if an earlier upload fails, so a single
    run reports all problems. Progress is printed to stdout and errors to
    stderr. Messages are kept ASCII-only so they print on any console encoding.

    Exits with status 1 (via ``sys.exit``) when:

    - the token or repo id is unset;
    - the token check fails;
    - the repo cannot be created;
    - no file was uploaded;
    - any attempted upload failed.
    """
    if not HF_TOKEN:
        print("Error: HF_TOKEN not set. Add it to ml/.env or export it.", file=sys.stderr)
        sys.exit(1)
    if not MODEL_REPO:
        print(
            "Error: HF_MODEL_REPO not set. "
            "Add it to ml/.env (e.g. yourname/stockpro-lstm).",
            file=sys.stderr,
        )
        sys.exit(1)

    api = HfApi(token=HF_TOKEN)

    # Verify the token up front so an auth problem is reported as such rather
    # than surfacing later as a repo-creation or upload error.
    try:
        whoami = api.whoami()
        print(f"Authenticated as: {whoami['name']}")
    except Exception as e:
        print(f"Error: HF_TOKEN is invalid or expired - {e}", file=sys.stderr)
        sys.exit(1)

    # Create the repo if needed; exist_ok=True avoids an error when it exists.
    try:
        repo_url = api.create_repo(
            repo_id=MODEL_REPO,
            repo_type="model",
            exist_ok=True,
            private=True,
        )
        print(f"Repo ready: {repo_url}")
    except Exception as e:
        print(f"Error creating repo '{MODEL_REPO}': {e}", file=sys.stderr)
        print("Make sure HF_MODEL_REPO is in the format 'username/repo-name'.", file=sys.stderr)
        sys.exit(1)

    print(f"Uploading to {MODEL_REPO}...")
    uploaded = 0
    failed = []
    for local_name, repo_name in ARTIFACTS:
        local_path = os.path.join(MODEL_DIR, local_name)
        # isfile rather than exists: a directory with an artifact's name is not
        # an uploadable file and should be skipped, not attempted.
        if not os.path.isfile(local_path):
            print(f"  Skipping {local_name} - file not found at {local_path}")
            continue
        try:
            size_mb = os.path.getsize(local_path) / 1024 / 1024
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=repo_name,
                repo_id=MODEL_REPO,
                repo_type="model",
                commit_message=f"Upload {repo_name}",
            )
        except Exception as e:
            print(f"  FAILED to upload {local_name}: {e}", file=sys.stderr)
            failed.append(local_name)
            continue
        # Success reporting stays outside the try so an error while printing
        # cannot be misreported as a failed upload.
        print(f"  OK {local_name} ({size_mb:.1f} MB) -> {MODEL_REPO}/{repo_name}")
        uploaded += 1

    if uploaded == 0:
        print("No files uploaded.", file=sys.stderr)
        sys.exit(1)
    print(f"Done. {uploaded} file(s) uploaded.")
    if failed:
        # A partial success must still fail the run so callers notice that not
        # every artifact was uploaded.
        print(
            f"Error: {len(failed)} file(s) failed to upload: {', '.join(failed)}",
            file=sys.stderr,
        )
        sys.exit(1)
# Run the upload only when executed as a script (e.g. `python -m scripts.upload_models`),
# not when this module is imported.
if __name__ == "__main__":
    upload()
|