# chronos-volatility / example_usage.py
# Provenance: Hugging Face upload by karkar69, "Upload ChronosVolatility model"
# (commit 013e105, verified)
"""
Example script to load and use the exported ChronosVolatility model.
"""
import sys
from pathlib import Path
# Add parent directory to path if running as script
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
import torch
import json
from src.models.chronos import ChronosVolatility
# peft/transformers are optional heavy dependencies: record whether they are
# importable here and defer the hard failure to load_exported_model(), so this
# module can still be imported in environments without them.
try:
    from peft import PeftModel
    from transformers import AutoModelForSeq2SeqLM
    _DEPS_AVAILABLE = True
except ImportError:
    _DEPS_AVAILABLE = False
    print("Warning: peft and transformers required for loading")
def load_exported_model(model_dir):
    """
    Load an exported ChronosVolatility model from a directory.

    Expects the layout produced at export time:
      - config.json   (keys: 'base_model', 'hidden_dim')
      - adapter/      (optional LoRA adapter; detected via adapter_config.json)
      - heads.pt      (state dicts for the custom quantile/value-embedding heads)

    Args:
        model_dir: Path to exported model directory.

    Returns:
        A ChronosVolatility instance in eval mode with the base model,
        optional LoRA adapters, and custom heads restored.

    Raises:
        ImportError: if peft/transformers are not installed.
        FileNotFoundError: if config.json or heads.pt is missing.
    """
    if not _DEPS_AVAILABLE:
        raise ImportError("peft and transformers required. Install with: pip install peft transformers")
    model_dir = Path(model_dir)

    # Load export-time configuration (base model id, head dimensions).
    with open(model_dir / "config.json", encoding="utf-8") as f:
        config = json.load(f)

    # Initialize base model
    print(f"Loading base model: {config['base_model']}")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(config['base_model'])

    # Load LoRA adapters only if they were actually exported alongside the base.
    adapter_path = model_dir / "adapter"
    if adapter_path.exists() and (adapter_path / "adapter_config.json").exists():
        print("Loading LoRA adapters...")
        model_wrapper = PeftModel.from_pretrained(base_model, str(adapter_path))
    else:
        print("No adapter found, using base model")
        model_wrapper = base_model

    # Create ChronosVolatility wrapper (don't initialize LoRA, we'll set base manually)
    chronos_model = ChronosVolatility(use_lora=False)
    chronos_model.base = model_wrapper
    chronos_model.hidden_dim = config['hidden_dim']
    chronos_model.model_id = config['base_model']

    # Load custom heads. weights_only=True restricts unpickling to tensors and
    # plain containers, guarding against arbitrary code execution from an
    # untrusted checkpoint file; plain state dicts load fine under it.
    print("Loading custom heads...")
    heads = torch.load(model_dir / "heads.pt", map_location='cpu', weights_only=True)
    chronos_model.quantile_head.load_state_dict(heads['quantile_head.state_dict'])
    chronos_model.value_embedding.load_state_dict(heads['value_embedding.state_dict'])

    chronos_model.eval()
    print("✓ Model loaded successfully")
    return chronos_model
# Usage example:
# model = load_exported_model("path/to/exported/model")
# input_seq = torch.FloatTensor(squared_returns).unsqueeze(0) # (1, 60)
# with torch.no_grad():
# quantiles = model(input_seq)