"""
Example script to load and use the exported ChronosVolatility model.
"""
import sys
from pathlib import Path
# Add parent directory to path if running as script
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
import torch
import json
from src.models.chronos import ChronosVolatility
try:
from peft import PeftModel
from transformers import AutoModelForSeq2SeqLM
_DEPS_AVAILABLE = True
except ImportError:
_DEPS_AVAILABLE = False
print("Warning: peft and transformers required for loading")
def load_exported_model(model_dir):
    """
    Load an exported ChronosVolatility model from a directory.

    The export directory is expected to contain:
      - ``config.json``: keys ``base_model`` (HF model id) and ``hidden_dim``
      - ``heads.pt``: state dicts for the quantile head and value embedding
      - ``adapter/`` (optional): LoRA adapters saved by peft

    Args:
        model_dir: Path to exported model directory.

    Returns:
        A ``ChronosVolatility`` instance in eval mode, ready for inference.

    Raises:
        ImportError: If peft/transformers are not installed.
        FileNotFoundError: If ``config.json`` or ``heads.pt`` is missing.
    """
    if not _DEPS_AVAILABLE:
        raise ImportError("peft and transformers required. Install with: pip install peft transformers")
    model_dir = Path(model_dir)

    # Load export-time config (base model id and encoder hidden size).
    with open(model_dir / "config.json") as f:
        config = json.load(f)

    # Rebuild the base seq2seq model the adapters were trained on.
    print(f"Loading base model: {config['base_model']}")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(config['base_model'])

    # Attach LoRA adapters if they were exported; otherwise use the base as-is.
    adapter_path = model_dir / "adapter"
    if adapter_path.exists() and (adapter_path / "adapter_config.json").exists():
        print("Loading LoRA adapters...")
        model_wrapper = PeftModel.from_pretrained(base_model, str(adapter_path))
    else:
        print("No adapter found, using base model")
        model_wrapper = base_model

    # Create ChronosVolatility wrapper without re-initializing LoRA; we inject
    # the (possibly adapter-wrapped) base model manually.
    chronos_model = ChronosVolatility(use_lora=False)
    chronos_model.base = model_wrapper
    chronos_model.hidden_dim = config['hidden_dim']
    chronos_model.model_id = config['base_model']

    # Load custom heads. weights_only=True restricts unpickling to plain
    # tensors/containers, closing the arbitrary-code-execution hole of a full
    # pickle load; heads.pt only holds state dicts, so this is sufficient.
    print("Loading custom heads...")
    try:
        heads = torch.load(model_dir / "heads.pt", map_location='cpu',
                           weights_only=True)
    except TypeError:
        # Older torch without the weights_only keyword — fall back to the
        # legacy (less safe) load for compatibility.
        heads = torch.load(model_dir / "heads.pt", map_location='cpu')
    chronos_model.quantile_head.load_state_dict(heads['quantile_head.state_dict'])
    chronos_model.value_embedding.load_state_dict(heads['value_embedding.state_dict'])

    chronos_model.eval()
    print("✓ Model loaded successfully")
    return chronos_model
# Usage example:
# model = load_exported_model("path/to/exported/model")
# input_seq = torch.FloatTensor(squared_returns).unsqueeze(0) # (1, 60)
# with torch.no_grad():
# quantiles = model(input_seq)