File size: 4,185 Bytes
c40c447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Implementaci贸n concreta del modelo Chronos-2.

Este m贸dulo implementa la interfaz IForecastModel usando Chronos2Pipeline,
aplicando el principio DIP (Dependency Inversion Principle).
"""

from typing import List, Dict, Any
import pandas as pd
from chronos import Chronos2Pipeline

from app.domain.interfaces.forecast_model import IForecastModel
from app.utils.logger import setup_logger

logger = setup_logger(__name__)


class ChronosModel(IForecastModel):
    """
    Concrete implementation of IForecastModel backed by Chronos-2.

    Because callers depend on the IForecastModel interface rather than on
    this class, it can be swapped for another implementation (Prophet,
    ARIMA, etc.) without modifying the rest of the code (Dependency
    Inversion Principle).

    Attributes:
        model_id: ID of the model on HuggingFace.
        device_map: Device used for inference (cpu/cuda).
        pipeline: The loaded Chronos2Pipeline instance.
    """

    def __init__(self, model_id: str = "amazon/chronos-2", device_map: str = "cpu"):
        """
        Load the Chronos-2 pipeline and remember how it was configured.

        Args:
            model_id: ID of the model on HuggingFace.
            device_map: Device used for inference (cpu/cuda).

        Raises:
            Exception: Whatever ``Chronos2Pipeline.from_pretrained`` raises
                is logged and re-raised unchanged.
        """
        self.model_id = model_id
        self.device_map = device_map

        logger.info(f"Loading Chronos model: {model_id} on {device_map}")

        try:
            self.pipeline = Chronos2Pipeline.from_pretrained(
                model_id,
                device_map=device_map
            )
            logger.info("Chronos model loaded successfully")
        except Exception as e:
            # Log for operators, then propagate the original exception so the
            # caller sees the real cause of the load failure.
            logger.error(f"Failed to load Chronos model: {e}")
            raise

    def predict(
        self,
        context_df: pd.DataFrame,
        prediction_length: int,
        quantile_levels: List[float],
        **kwargs
    ) -> pd.DataFrame:
        """
        Generate probabilistic forecasts with Chronos-2.

        Args:
            context_df: DataFrame that must contain the columns
                ``id``, ``timestamp`` and ``target``.
            prediction_length: Forecast horizon passed to the pipeline.
            quantile_levels: Quantiles to compute (e.g. [0.1, 0.5, 0.9]).
            **kwargs: Extra keyword arguments forwarded to
                ``pipeline.predict_df``.

        Returns:
            The DataFrame returned by the pipeline, sorted by ``id`` and
            then ``timestamp``.

        Raises:
            ValueError: If ``context_df`` lacks any of the required columns.
            RuntimeError: If the pipeline call or the sorting step raises;
                the original exception is chained as ``__cause__``.
        """
        logger.debug(
            f"Predicting {prediction_length} steps with "
            f"{len(quantile_levels)} quantiles"
        )

        # Validate the DataFrame layout before handing it to the pipeline.
        required_cols = {"id", "timestamp", "target"}
        if not required_cols.issubset(context_df.columns):
            raise ValueError(
                f"context_df debe tener columnas: {required_cols}. "
                f"Encontradas: {set(context_df.columns)}"
            )

        try:
            # Run inference; column names are fixed to the validated layout.
            pred_df = self.pipeline.predict_df(
                context_df,
                prediction_length=prediction_length,
                quantile_levels=quantile_levels,
                id_column="id",
                timestamp_column="timestamp",
                target="target",
                **kwargs
            )

            # Return rows in a stable (id, timestamp) order.
            result = pred_df.sort_values(["id", "timestamp"])

            logger.debug(f"Prediction completed: {len(result)} rows")
            return result

        except Exception as e:
            logger.error(f"Prediction failed: {e}")
            # Wrap in RuntimeError (documented contract) while keeping the
            # original exception chained for debugging. The message text was
            # previously mis-encoded ("predicci贸n"); it is now valid Spanish.
            raise RuntimeError(f"Error en predicción: {e}") from e

    def get_model_info(self) -> Dict[str, Any]:
        """
        Return descriptive metadata about this model.

        Returns:
            Dictionary with the keys ``type``, ``model_id``, ``device``,
            ``provider`` and ``version``. ``model_id`` and ``device`` reflect
            the constructor arguments; the other values are fixed constants.
        """
        return {
            "type": "Chronos2",
            "model_id": self.model_id,
            "device": self.device_map,
            "provider": "Amazon",
            "version": "2.0"
        }

    def __repr__(self) -> str:
        """Return a short debug representation with model ID and device."""
        return f"ChronosModel(model_id='{self.model_id}', device='{self.device_map}')"