stockpro-ml / scripts /upload_models.py
will702's picture
StockPro ML backend with pytorch-forecasting TFT
9334ec6
"""
Upload trained model artifacts to HF Hub.
Run: python -m scripts.upload_models
"""
import os
import sys
# Put the directory above this script's folder on sys.path so the project's
# `app` package can be imported even when the file is run directly instead of
# via `python -m scripts.upload_models`.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
# python-dotenv is optional: if it is installed, load variables from a .env file;
# if not, fall back to the process environment only. This must stay ABOVE the
# `app.config` import — presumably app.config reads env vars at import time
# (TODO confirm).
try:
    from dotenv import load_dotenv
    load_dotenv()
except ImportError:
    pass
from huggingface_hub import HfApi
from app.config import HF_TOKEN, MODEL_REPO, MODEL_DIR
# Files to publish, as (file name under MODEL_DIR, destination path in the Hub
# repo) pairs. Every artifact currently keeps its local file name in the repo.
ARTIFACTS = [
    (file_name, file_name)
    for file_name in ("tft_stock.ckpt", "dataset_params.pt", "ddg_da.pt")
]
def _fail(message):
    """Print *message* to stderr and terminate the script with exit status 1."""
    print(message, file=sys.stderr)
    sys.exit(1)


def _require_settings():
    """Exit with an error unless both HF_TOKEN and MODEL_REPO are configured."""
    if not HF_TOKEN:
        _fail("Error: HF_TOKEN not set. Add it to ml/.env or export it.")
    if not MODEL_REPO:
        _fail("Error: HF_MODEL_REPO not set. Add it to ml/.env (e.g. yourname/stockpro-lstm).")


def _authenticate():
    """Build an HfApi client and verify the token with a whoami() call before any writes."""
    api = HfApi(token=HF_TOKEN)
    try:
        whoami = api.whoami()
    except Exception as e:  # any failure here means we cannot act on the user's behalf
        _fail(f"Error: HF_TOKEN is invalid or expired — {e}")
    else:
        print(f"Authenticated as: {whoami['name']}")
    return api


def _ensure_repo(api):
    """Create MODEL_REPO as a private model repo; exist_ok=True tolerates an existing repo."""
    try:
        repo_url = api.create_repo(
            repo_id=MODEL_REPO,
            repo_type="model",
            exist_ok=True,
            private=True,
        )
    except Exception as e:
        print(f"Error creating repo '{MODEL_REPO}': {e}", file=sys.stderr)
        _fail("Make sure HF_MODEL_REPO is in the format 'username/repo-name'.")
    else:
        print(f"Repo ready: {repo_url}")


def _upload_artifacts(api):
    """Upload every ARTIFACTS entry present under MODEL_DIR, one commit per file.

    Missing files are skipped with a notice; a failing upload is reported and
    the remaining artifacts are still attempted.

    Returns:
        tuple[int, list[str]]: number of files uploaded, and the local names
        of artifacts whose upload raised.
    """
    uploaded = 0
    failed = []
    for local_name, repo_name in ARTIFACTS:
        local_path = os.path.join(MODEL_DIR, local_name)
        # isfile (not exists): a directory with this name is not an artifact.
        if not os.path.isfile(local_path):
            print(f" Skipping {local_name} — file not found at {local_path}")
            continue
        try:
            api.upload_file(
                path_or_fileobj=local_path,
                path_in_repo=repo_name,
                repo_id=MODEL_REPO,
                repo_type="model",
                commit_message=f"Upload {repo_name}",
            )
        except Exception as e:
            print(f" ✗ Failed to upload {local_name}: {e}", file=sys.stderr)
            failed.append(local_name)
        else:
            # Only reached when the upload itself succeeded, so bookkeeping
            # errors are never misreported as upload failures.
            size_mb = os.path.getsize(local_path) / 1024 / 1024
            print(f" ✓ {local_name} ({size_mb:.1f} MB) → {MODEL_REPO}/{repo_name}")
            uploaded += 1
    return uploaded, failed


def upload():
    """Upload the trained model artifacts listed in ARTIFACTS to the Hub repo MODEL_REPO.

    Exits with status 1 when configuration is missing, authentication or repo
    creation fails, nothing was uploaded, or any individual upload failed, so
    a partial upload is never reported as success to the calling shell.
    """
    _require_settings()
    api = _authenticate()
    _ensure_repo(api)

    print(f"Uploading to {MODEL_REPO}...")
    uploaded, failed = _upload_artifacts(api)

    if uploaded == 0:
        _fail("No files uploaded.")
    if failed:
        _fail(
            f"Done with errors. {uploaded} file(s) uploaded, "
            f"{len(failed)} failed: {', '.join(failed)}."
        )
    print(f"Done. {uploaded} file(s) uploaded.")


if __name__ == "__main__":
    upload()