File size: 3,476 Bytes
92f4bb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""
Model configuration and loading for time series forecasting.
Supports multiple Chronos model variants with different architectures.
"""

import torch
from chronos import Chronos2Pipeline, ChronosPipeline


class ModelConfig:
    """Registry of the forecasting models this module knows about.

    Every registry maps a human-readable display name to a dict with
    three keys: ``model_id``, ``pipeline_class`` and ``description``.
    """

    # Models loaded through Chronos2Pipeline.
    CHRONOS_2_MODELS = {
        "Chronos-2 (Latest, 120M params)": {
            "model_id": "amazon/chronos-2",
            "pipeline_class": Chronos2Pipeline,
            "description": "Latest Chronos-2 model with 120M parameters",
        },
    }

    # Models loaded through ChronosPipeline, listed from smallest to largest.
    CHRONOS_T5_MODELS = {
        "Chronos-T5 Tiny (8M params)": {
            "model_id": "amazon/chronos-t5-tiny",
            "pipeline_class": ChronosPipeline,
            "description": "Smallest Chronos-T5 model, fastest inference",
        },
        "Chronos-T5 Mini (20M params)": {
            "model_id": "amazon/chronos-t5-mini",
            "pipeline_class": ChronosPipeline,
            "description": "Mini Chronos-T5 model",
        },
        "Chronos-T5 Small (46M params)": {
            "model_id": "amazon/chronos-t5-small",
            "pipeline_class": ChronosPipeline,
            "description": "Small Chronos-T5 model",
        },
        "Chronos-T5 Base (200M params)": {
            "model_id": "amazon/chronos-t5-base",
            "pipeline_class": ChronosPipeline,
            "description": "Base Chronos-T5 model",
        },
        "Chronos-T5 Large (710M params)": {
            "model_id": "amazon/chronos-t5-large",
            "pipeline_class": ChronosPipeline,
            "description": "Largest Chronos-T5 model, best accuracy",
        },
    }

    @classmethod
    def get_all_models(cls):
        """Return a new dict combining every registry.

        Chronos-2 entries come first, followed by the Chronos-T5 entries.
        The outer dict is freshly built on each call; the per-model
        config dicts inside it are the same objects held by the class.
        """
        return {**cls.CHRONOS_2_MODELS, **cls.CHRONOS_T5_MODELS}

    @classmethod
    def get_model_names(cls):
        """Return the display names of all models as a list."""
        return list(cls.get_all_models())

    @classmethod
    def get_model_config(cls, model_name):
        """Return the config dict for ``model_name``, or None if unknown."""
        registry = cls.get_all_models()
        return registry.get(model_name)


def load_model_pipeline(model_name, device_map="cpu", dtype=torch.float32):
    """
    Instantiate the pipeline registered under a display name.

    Args:
        model_name: Display name of the model, as listed in ModelConfig
        device_map: Forwarded to ``from_pretrained`` (default: "cpu")
        dtype: Forwarded to ``from_pretrained`` (default: torch.float32)

    Returns:
        The object returned by the pipeline class's ``from_pretrained``

    Raises:
        ValueError: If ``model_name`` is not a registered model
    """
    entry = ModelConfig.get_model_config(model_name)
    if entry is None:
        raise ValueError(f"Unknown model: {model_name}")

    # The registry entry decides both which pipeline class does the
    # loading and which model identifier it is asked to load.
    pipeline_cls, checkpoint = entry["pipeline_class"], entry["model_id"]
    return pipeline_cls.from_pretrained(
        checkpoint,
        device_map=device_map,
        dtype=dtype,
    )


def get_model_info(model_name):
    """
    Summarize a registered model as a plain dictionary.

    Args:
        model_name: Display name of the model

    Returns:
        Dictionary with the keys "name", "model_id", "description" and
        "pipeline" (the pipeline class's ``__name__``), or None when
        ``model_name`` is not registered
    """
    entry = ModelConfig.get_model_config(model_name)
    if entry is None:
        return None

    info = {"name": model_name}
    info["model_id"] = entry["model_id"]
    info["description"] = entry["description"]
    # Expose the class by name so the summary holds only strings.
    info["pipeline"] = entry["pipeline_class"].__name__
    return info