|
|
""" |
|
|
Example script to load and use the exported ChronosVolatility model. |
|
|
""" |
|
|
|
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent)) |
|
|
|
|
|
import torch |
|
|
import json |
|
|
from src.models.chronos import ChronosVolatility |
|
|
|
|
|
try: |
|
|
from peft import PeftModel |
|
|
from transformers import AutoModelForSeq2SeqLM |
|
|
_DEPS_AVAILABLE = True |
|
|
except ImportError: |
|
|
_DEPS_AVAILABLE = False |
|
|
print("Warning: peft and transformers required for loading") |
|
|
|
|
|
def load_exported_model(model_dir):
    """
    Load an exported ChronosVolatility model from a directory.

    The directory is expected to contain:
      * ``config.json`` with at least the keys ``base_model`` (passed to
        ``AutoModelForSeq2SeqLM.from_pretrained``) and ``hidden_dim``;
      * optionally ``adapter/`` holding a PEFT/LoRA adapter, detected via
        ``adapter/adapter_config.json``;
      * ``heads.pt`` containing the custom-head state dicts under the keys
        ``'quantile_head.state_dict'`` and ``'value_embedding.state_dict'``.

    Args:
        model_dir: Path (str or ``pathlib.Path``) to the exported model directory.

    Returns:
        ChronosVolatility: the assembled model, with ``.eval()`` already called.

    Raises:
        ImportError: if ``peft``/``transformers`` could not be imported.
        FileNotFoundError: if ``config.json`` or ``heads.pt`` is missing.
        KeyError: if an expected key is missing from ``config.json`` or
            from the dict stored in ``heads.pt``.
    """
    if not _DEPS_AVAILABLE:
        raise ImportError("peft and transformers required. Install with: pip install peft transformers")

    model_dir = Path(model_dir)

    # JSON is UTF-8 by spec; be explicit so loading does not depend on the
    # platform's default locale encoding (e.g. cp1252 on Windows).
    with open(model_dir / "config.json", encoding="utf-8") as f:
        config = json.load(f)

    base_model_id = config['base_model']
    print(f"Loading base model: {base_model_id}")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_id)

    # adapter_config.json can only exist inside an existing adapter
    # directory, so this single check also covers the directory itself.
    adapter_path = model_dir / "adapter"
    if (adapter_path / "adapter_config.json").exists():
        print("Loading LoRA adapters...")
        model_wrapper = PeftModel.from_pretrained(base_model, str(adapter_path))
    else:
        print("No adapter found, using base model")
        model_wrapper = base_model

    # Build the wrapper, then swap in the backbone loaded above.
    # NOTE(review): ChronosVolatility's constructor is defined elsewhere; it
    # may build its own backbone that is immediately replaced here — confirm.
    chronos_model = ChronosVolatility(use_lora=False)
    chronos_model.base = model_wrapper
    chronos_model.hidden_dim = config['hidden_dim']
    chronos_model.model_id = base_model_id

    print("Loading custom heads...")
    # SECURITY: torch.load unpickles the file. Only load exported models from
    # trusted sources (consider weights_only=True on torch versions that
    # support it).
    heads = torch.load(model_dir / "heads.pt", map_location='cpu')
    chronos_model.quantile_head.load_state_dict(heads['quantile_head.state_dict'])
    chronos_model.value_embedding.load_state_dict(heads['value_embedding.state_dict'])

    chronos_model.eval()
    print("✓ Model loaded successfully")
    return chronos_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|