""" Example script to load and use the exported ChronosVolatility model. """ import sys from pathlib import Path # Add parent directory to path if running as script sys.path.insert(0, str(Path(__file__).parent.parent.parent)) import torch import json from src.models.chronos import ChronosVolatility try: from peft import PeftModel from transformers import AutoModelForSeq2SeqLM _DEPS_AVAILABLE = True except ImportError: _DEPS_AVAILABLE = False print("Warning: peft and transformers required for loading") def load_exported_model(model_dir): """ Load exported model from directory. Args: model_dir: Path to exported model directory """ if not _DEPS_AVAILABLE: raise ImportError("peft and transformers required. Install with: pip install peft transformers") model_dir = Path(model_dir) # Load config with open(model_dir / "config.json") as f: config = json.load(f) # Initialize base model print(f"Loading base model: {config['base_model']}") base_model = AutoModelForSeq2SeqLM.from_pretrained(config['base_model']) # Load LoRA adapters adapter_path = model_dir / "adapter" if adapter_path.exists() and (adapter_path / "adapter_config.json").exists(): print("Loading LoRA adapters...") model_wrapper = PeftModel.from_pretrained(base_model, str(adapter_path)) else: print("No adapter found, using base model") model_wrapper = base_model # Create ChronosVolatility wrapper (don't initialize LoRA, we'll set base manually) chronos_model = ChronosVolatility(use_lora=False) chronos_model.base = model_wrapper chronos_model.hidden_dim = config['hidden_dim'] chronos_model.model_id = config['base_model'] # Load custom heads print("Loading custom heads...") heads = torch.load(model_dir / "heads.pt", map_location='cpu') chronos_model.quantile_head.load_state_dict(heads['quantile_head.state_dict']) chronos_model.value_embedding.load_state_dict(heads['value_embedding.state_dict']) chronos_model.eval() print("✓ Model loaded successfully") return chronos_model # Usage example: # model = load_exported_model("path/to/exported/model") # input_seq = torch.FloatTensor(squared_returns).unsqueeze(0) # (1, 60) # with torch.no_grad(): # quantiles = model(input_seq)