File size: 2,396 Bytes
013e105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
"""
Example script to load and use the exported ChronosVolatility model.
"""

import sys
from pathlib import Path

# Add parent directory to path if running as script
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import torch
import json
from src.models.chronos import ChronosVolatility

try:
    from peft import PeftModel
    from transformers import AutoModelForSeq2SeqLM
    _DEPS_AVAILABLE = True
except ImportError:
    _DEPS_AVAILABLE = False
    print("Warning: peft and transformers required for loading")

def load_exported_model(model_dir):
    """
    Load an exported ChronosVolatility model from a directory.

    Expected directory layout:
      - ``config.json``: must contain the keys ``base_model`` (a Hugging Face
        model id / path passed to ``AutoModelForSeq2SeqLM.from_pretrained``)
        and ``hidden_dim``.
      - ``heads.pt``: a mapping with the entries ``quantile_head.state_dict``
        and ``value_embedding.state_dict``.
      - ``adapter/`` (optional): PEFT LoRA adapters; used only when
        ``adapter/adapter_config.json`` exists.

    Args:
        model_dir: Path (``str`` or ``pathlib.Path``) to the exported model
            directory.

    Returns:
        A ``ChronosVolatility`` instance in eval mode whose ``base`` is the
        pretrained seq2seq model, wrapped in a ``PeftModel`` when adapters
        were found.

    Raises:
        ImportError: If ``peft``/``transformers`` could not be imported.
        FileNotFoundError: If ``config.json`` or ``heads.pt`` is missing.
        KeyError: If ``config.json`` or ``heads.pt`` lacks a required entry.
    """
    if not _DEPS_AVAILABLE:
        raise ImportError("peft and transformers required. Install with: pip install peft transformers")

    model_dir = Path(model_dir)

    # Load config
    config_path = model_dir / "config.json"
    with open(config_path, encoding="utf-8") as f:
        config = json.load(f)

    # Validate before downloading/loading the (potentially large) base model,
    # so a malformed export fails fast with a clear message.
    missing_cfg = [k for k in ("base_model", "hidden_dim") if k not in config]
    if missing_cfg:
        raise KeyError(f"{config_path} is missing required keys: {missing_cfg}")

    # Initialize base model
    print(f"Loading base model: {config['base_model']}")
    base_model = AutoModelForSeq2SeqLM.from_pretrained(config['base_model'])

    # Load LoRA adapters. Only a directory containing adapter_config.json is a
    # loadable PEFT adapter; otherwise fall back to the plain base model.
    adapter_path = model_dir / "adapter"
    if (adapter_path / "adapter_config.json").is_file():
        print("Loading LoRA adapters...")
        model_wrapper = PeftModel.from_pretrained(base_model, str(adapter_path))
    else:
        print("No adapter found, using base model")
        model_wrapper = base_model

    # Create ChronosVolatility wrapper (don't initialize LoRA, we'll set base manually).
    # NOTE(review): if ChronosVolatility.__init__ loads its own default base model,
    # that model is loaded and then discarded here; also, hidden_dim is assigned
    # after construction, so heads sized in __init__ must already match it — confirm.
    chronos_model = ChronosVolatility(use_lora=False)
    chronos_model.base = model_wrapper
    chronos_model.hidden_dim = config['hidden_dim']
    chronos_model.model_id = config['base_model']

    # Load custom heads
    print("Loading custom heads...")
    heads_path = model_dir / "heads.pt"
    # heads.pt only needs to hold state dicts, so restrict unpickling to
    # tensors and primitive containers. Full pickle (the default before
    # torch 2.6) can execute arbitrary code from a tampered export file.
    # Passing the flag explicitly gives the same safe behavior on every
    # torch version this file supports.
    heads = torch.load(heads_path, map_location='cpu', weights_only=True)
    required_heads = ('quantile_head.state_dict', 'value_embedding.state_dict')
    missing_heads = [k for k in required_heads if k not in heads]
    if missing_heads:
        raise KeyError(f"{heads_path} is missing required entries: {missing_heads}")
    chronos_model.quantile_head.load_state_dict(heads['quantile_head.state_dict'])
    chronos_model.value_embedding.load_state_dict(heads['value_embedding.state_dict'])

    # Switch to inference mode (disables dropout etc. in all submodules).
    chronos_model.eval()
    print("✓ Model loaded successfully")
    return chronos_model

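# Usage example (``squared_returns`` is a 1-D sequence of squared returns;
# the shape comment below assumes a 60-step window):
# model = load_exported_model("path/to/exported/model")
# input_seq = torch.tensor(squared_returns, dtype=torch.float32).unsqueeze(0)  # (1, 60)
# with torch.no_grad():
#     quantiles = model(input_seq)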