ifieryarrows committed on
Commit
db6c149
·
verified ·
1 Parent(s): e57e9d1

Sync from GitHub (tests passed)

Browse files
Dockerfile CHANGED
@@ -22,6 +22,8 @@ COPY ./adapters /code/adapters
22
  COPY ./worker /code/worker
23
  COPY ./pipelines /code/pipelines
24
  COPY ./migrations /code/migrations
 
 
25
 
26
  # Copy pre-trained model files (from Kaggle)
27
  COPY ./data/models /data/models
 
22
  COPY ./worker /code/worker
23
  COPY ./pipelines /code/pipelines
24
  COPY ./migrations /code/migrations
25
+ COPY ./deep_learning /code/deep_learning
26
+ COPY ./backtest /code/backtest
27
 
28
  # Copy pre-trained model files (from Kaggle)
29
  COPY ./data/models /data/models
deep_learning/config.py CHANGED
@@ -3,15 +3,25 @@ Central configuration for the TFT-ASRO deep learning pipeline.
3
 
4
  All hyperparameters, feature dimensions, and training settings live here
5
  so every module draws from a single source of truth.
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
9
 
 
10
  from dataclasses import dataclass, field
11
  from pathlib import Path
12
  from typing import Optional
13
 
14
 
 
 
 
 
 
15
  @dataclass(frozen=True)
16
  class EmbeddingConfig:
17
  model_name: str = "ProsusAI/finbert"
@@ -19,7 +29,7 @@ class EmbeddingConfig:
19
  pca_dim: int = 32
20
  max_token_length: int = 512
21
  batch_size: int = 64
22
- pca_model_path: str = "models/tft/pca_finbert.joblib"
23
 
24
 
25
  @dataclass(frozen=True)
@@ -86,8 +96,9 @@ class TrainingConfig:
86
  seed: int = 42
87
  num_workers: int = 0
88
  optuna_n_trials: int = 50
89
- checkpoint_dir: str = "models/tft/checkpoints"
90
- best_model_path: str = "models/tft/best_tft_asro.ckpt"
 
91
 
92
 
93
  @dataclass(frozen=True)
@@ -116,5 +127,17 @@ class TFTASROConfig:
116
 
117
 
118
  def get_tft_config() -> TFTASROConfig:
119
- """Return the default TFT-ASRO configuration."""
120
- return TFTASROConfig()
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  All hyperparameters, feature dimensions, and training settings live here
5
  so every module draws from a single source of truth.
6
+
7
+ Model paths honour the MODEL_DIR environment variable so they work both
8
+ locally (``data/models``) and inside the HF Space container
9
+ (``/data/models``).
10
  """
11
 
12
  from __future__ import annotations
13
 
14
+ import os
15
  from dataclasses import dataclass, field
16
  from pathlib import Path
17
  from typing import Optional
18
 
19
 
20
+ def _model_dir() -> str:
21
+ """Resolve the base model directory from env (same as app.settings)."""
22
+ return os.environ.get("MODEL_DIR", "/data/models")
23
+
24
+
25
  @dataclass(frozen=True)
26
  class EmbeddingConfig:
27
  model_name: str = "ProsusAI/finbert"
 
29
  pca_dim: int = 32
30
  max_token_length: int = 512
31
  batch_size: int = 64
32
+ pca_model_path: str = ""
33
 
34
 
35
  @dataclass(frozen=True)
 
96
  seed: int = 42
97
  num_workers: int = 0
98
  optuna_n_trials: int = 50
99
+ checkpoint_dir: str = ""
100
+ best_model_path: str = ""
101
+ hf_model_repo: str = "ifieryarrows/copper-mind-tft"
102
 
103
 
104
  @dataclass(frozen=True)
 
127
 
128
 
129
  def get_tft_config() -> TFTASROConfig:
130
+ """
131
+ Return the default TFT-ASRO configuration with paths resolved from
132
+ MODEL_DIR (``/data/models`` on HF Space, configurable locally).
133
+ """
134
+ base = Path(_model_dir()) / "tft"
135
+ return TFTASROConfig(
136
+ embedding=EmbeddingConfig(
137
+ pca_model_path=str(base / "pca_finbert.joblib"),
138
+ ),
139
+ training=TrainingConfig(
140
+ checkpoint_dir=str(base / "checkpoints"),
141
+ best_model_path=str(base / "best_tft_asro.ckpt"),
142
+ ),
143
+ )
deep_learning/inference/predictor.py CHANGED
@@ -42,10 +42,40 @@ class TFTPredictor:
42
  self._checkpoint_path = checkpoint_path or self.cfg.training.best_model_path
43
  self._model = None
44
  self._pca = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  @property
47
  def model(self):
48
  if self._model is None:
 
 
 
 
 
49
  from deep_learning.models.tft_copper import load_tft_model
50
  self._model = load_tft_model(self._checkpoint_path)
51
  return self._model
@@ -53,6 +83,7 @@ class TFTPredictor:
53
  @property
54
  def pca(self):
55
  if self._pca is None:
 
56
  pca_path = self.cfg.embedding.pca_model_path
57
  if Path(pca_path).exists():
58
  from deep_learning.data.embeddings import load_pca
 
42
  self._checkpoint_path = checkpoint_path or self.cfg.training.best_model_path
43
  self._model = None
44
  self._pca = None
45
+ self._hub_checked = False
46
+
47
+ def _ensure_local_artifacts(self) -> None:
48
+ """Download checkpoint from HF Hub if not present locally."""
49
+ if self._hub_checked:
50
+ return
51
+ self._hub_checked = True
52
+
53
+ if Path(self._checkpoint_path).exists():
54
+ return
55
+
56
+ try:
57
+ from deep_learning.models.hub import download_tft_artifacts
58
+
59
+ tft_dir = Path(self._checkpoint_path).parent
60
+ downloaded = download_tft_artifacts(
61
+ local_dir=tft_dir,
62
+ repo_id=self.cfg.training.hf_model_repo,
63
+ )
64
+ if downloaded:
65
+ logger.info("TFT checkpoint downloaded from HF Hub")
66
+ else:
67
+ logger.warning("TFT checkpoint not available on HF Hub")
68
+ except Exception as exc:
69
+ logger.warning("HF Hub download attempt failed: %s", exc)
70
 
71
  @property
72
  def model(self):
73
  if self._model is None:
74
+ self._ensure_local_artifacts()
75
+ if not Path(self._checkpoint_path).exists():
76
+ raise FileNotFoundError(
77
+ f"TFT checkpoint not found: {self._checkpoint_path}"
78
+ )
79
  from deep_learning.models.tft_copper import load_tft_model
80
  self._model = load_tft_model(self._checkpoint_path)
81
  return self._model
 
83
  @property
84
  def pca(self):
85
  if self._pca is None:
86
+ self._ensure_local_artifacts()
87
  pca_path = self.cfg.embedding.pca_model_path
88
  if Path(pca_path).exists():
89
  from deep_learning.data.embeddings import load_pca
deep_learning/models/hub.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HuggingFace Hub integration for TFT-ASRO model persistence.
3
+
4
+ Solves the ephemeral storage problem on HF Spaces: after training,
5
+ checkpoints are uploaded to a dedicated HF model repo; before inference,
6
+ they are downloaded if not present locally.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+ from typing import Optional
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ _HF_TOKEN_ENV = "HF_TOKEN"
19
+
20
+ _ARTIFACTS = [
21
+ "best_tft_asro.ckpt",
22
+ "pca_finbert.joblib",
23
+ ]
24
+
25
+
26
+ def _get_token() -> Optional[str]:
27
+ return os.environ.get(_HF_TOKEN_ENV)
28
+
29
+
30
+ def upload_tft_artifacts(
31
+ local_dir: str | Path,
32
+ repo_id: str,
33
+ commit_message: str = "Update TFT-ASRO checkpoint",
34
+ ) -> bool:
35
+ """
36
+ Upload all TFT artifacts from *local_dir* to a HuggingFace model repo.
37
+
38
+ Returns True on success, False if upload fails or token is missing.
39
+ """
40
+ token = _get_token()
41
+ if not token:
42
+ logger.warning("HF_TOKEN not set – skipping model upload to Hub")
43
+ return False
44
+
45
+ local_dir = Path(local_dir)
46
+ files_to_upload = [
47
+ local_dir / name
48
+ for name in _ARTIFACTS
49
+ if (local_dir / name).exists()
50
+ ]
51
+
52
+ if not files_to_upload:
53
+ logger.warning("No TFT artifacts found in %s", local_dir)
54
+ return False
55
+
56
+ try:
57
+ from huggingface_hub import HfApi
58
+
59
+ api = HfApi(token=token)
60
+ api.create_repo(repo_id, repo_type="model", exist_ok=True, private=True)
61
+
62
+ for fpath in files_to_upload:
63
+ api.upload_file(
64
+ path_or_fileobj=str(fpath),
65
+ path_in_repo=fpath.name,
66
+ repo_id=repo_id,
67
+ repo_type="model",
68
+ commit_message=commit_message,
69
+ )
70
+ logger.info("Uploaded %s → %s/%s", fpath.name, repo_id, fpath.name)
71
+
72
+ return True
73
+
74
+ except Exception as exc:
75
+ logger.error("HF Hub upload failed: %s", exc)
76
+ return False
77
+
78
+
79
+ def download_tft_artifacts(
80
+ local_dir: str | Path,
81
+ repo_id: str,
82
+ ) -> bool:
83
+ """
84
+ Download TFT artifacts from HuggingFace Hub to *local_dir*.
85
+
86
+ Skips files that already exist locally.
87
+ Returns True if at least the checkpoint was retrieved.
88
+ """
89
+ token = _get_token()
90
+ local_dir = Path(local_dir)
91
+ local_dir.mkdir(parents=True, exist_ok=True)
92
+
93
+ ckpt_path = local_dir / "best_tft_asro.ckpt"
94
+ if ckpt_path.exists():
95
+ logger.debug("TFT checkpoint already present locally: %s", ckpt_path)
96
+ return True
97
+
98
+ try:
99
+ from huggingface_hub import hf_hub_download
100
+
101
+ for name in _ARTIFACTS:
102
+ dest = local_dir / name
103
+ if dest.exists():
104
+ continue
105
+ try:
106
+ hf_hub_download(
107
+ repo_id=repo_id,
108
+ filename=name,
109
+ local_dir=str(local_dir),
110
+ token=token,
111
+ )
112
+ logger.info("Downloaded %s/%s → %s", repo_id, name, dest)
113
+ except Exception:
114
+ logger.debug("Artifact %s not found in %s (may not exist yet)", name, repo_id)
115
+
116
+ return ckpt_path.exists()
117
+
118
+ except ImportError:
119
+ logger.warning("huggingface_hub not installed – cannot download model")
120
+ return False
121
+ except Exception as exc:
122
+ logger.warning("HF Hub download failed: %s", exc)
123
+ return False
deep_learning/training/trainer.py CHANGED
@@ -183,6 +183,23 @@ def train_tft_model(
183
 
184
  _persist_tft_metadata(cfg.feature_store.target_symbol, result)
185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
186
  return result
187
 
188
 
 
183
 
184
  _persist_tft_metadata(cfg.feature_store.target_symbol, result)
185
 
186
+ # ---- 10. Upload to HF Hub (for persistence across HF Space rebuilds) ----
187
+ try:
188
+ from deep_learning.models.hub import upload_tft_artifacts
189
+
190
+ tft_dir = final_path.parent
191
+ uploaded = upload_tft_artifacts(
192
+ local_dir=tft_dir,
193
+ repo_id=cfg.training.hf_model_repo,
194
+ commit_message=f"TFT-ASRO checkpoint (val_loss={trainer.checkpoint_callback.best_model_score:.4f})"
195
+ if trainer.checkpoint_callback.best_model_score
196
+ else "TFT-ASRO checkpoint",
197
+ )
198
+ result["hub_uploaded"] = uploaded
199
+ except Exception as exc:
200
+ logger.warning("HF Hub upload skipped: %s", exc)
201
+ result["hub_uploaded"] = False
202
+
203
  return result
204
 
205
 
worker/tasks.py CHANGED
@@ -575,9 +575,10 @@ async def _execute_pipeline_stages_v2(
575
  logger.info(f"[run_id={run_id}] Stage 5.5: TFT-ASRO snapshot")
576
  try:
577
  from deep_learning.inference.predictor import generate_tft_analysis
 
578
  from pathlib import Path
579
 
580
- ckpt = Path("models/tft/best_tft_asro.ckpt")
581
  if ckpt.exists():
582
  tft_report = generate_tft_analysis(session, "HG=F")
583
 
 
575
  logger.info(f"[run_id={run_id}] Stage 5.5: TFT-ASRO snapshot")
576
  try:
577
  from deep_learning.inference.predictor import generate_tft_analysis
578
+ from deep_learning.config import get_tft_config
579
  from pathlib import Path
580
 
581
+ ckpt = Path(get_tft_config().training.best_model_path)
582
  if ckpt.exists():
583
  tft_report = generate_tft_analysis(session, "HG=F")
584