File size: 3,885 Bytes
96b112e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Retail World Model Inference Script
Predicts future retail sales given historical context using the trained world model.
"""
import os, pickle, numpy as np, pandas as pd, torch, torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, AutoConfig
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Same architecture as training
class RetailWorldModel(nn.Module):
    """Probabilistic sales forecaster: pretrained encoder + LSTM rollout.

    The multivariate context is linearly projected to the encoder width,
    passed through the encoder stack of a pretrained seq2seq model, and the
    final encoder state seeds a 2-layer LSTM that is unrolled for
    ``pred_len`` steps. Two small MLP heads turn every rolled-out state into
    a predicted mean and a (Softplus-positive) variance.

    The submodule attribute names (``encoder``, ``input_proj``,
    ``latent_dynamics``, ``mean_head``, ``var_head``) are the keys used in the
    module's state dict, so they must stay as they are for checkpoints to load.
    """

    def __init__(self, base_model_name, context_len, pred_len, num_variates, embed_dim):
        """Build the network.

        Args:
            base_model_name: Identifier passed to ``from_pretrained`` for both
                the config and the seq2seq backbone.
            context_len: Length of the history window (stored, not used in
                ``forward``).
            pred_len: Number of future steps rolled out by ``forward``.
            num_variates: Size of the last dimension of the context input.
            embed_dim: Hidden width of the mean and variance heads.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        self.encoder = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)

        self.context_len = context_len
        self.pred_len = pred_len
        self.num_variates = num_variates
        self.embed_dim = embed_dim

        width = self.config.d_model
        self.input_proj = nn.Linear(num_variates, width)
        self.latent_dynamics = nn.LSTM(width, width, 2, batch_first=True, dropout=0.1)
        self.mean_head = nn.Sequential(
            nn.Linear(width, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, 1),
        )
        self.var_head = nn.Sequential(
            nn.Linear(width, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, 1),
            nn.Softplus(),
        )

    def forward(self, context):
        """Roll the latent dynamics forward ``self.pred_len`` steps.

        Args:
            context: Tensor of shape ``(batch, time, num_variates)``.

        Returns:
            Dict with ``'mean'`` and ``'var'``, each of shape
            ``(batch, pred_len)``.
        """
        embedded = self.input_proj(context)
        encoded = self.encoder.encoder(inputs_embeds=embedded, return_dict=True).last_hidden_state

        # The last encoder position is both the first LSTM input and, copied
        # once per LSTM layer (2 layers), the initial hidden state.
        step_input = encoded[:, -1:, :]
        hidden = step_input.transpose(0, 1).repeat(2, 1, 1)
        cell = torch.zeros_like(hidden)

        rollout = []
        for _ in range(self.pred_len):
            # Autoregressive: each step's output becomes the next step's input.
            step_input, (hidden, cell) = self.latent_dynamics(step_input, (hidden, cell))
            rollout.append(step_input)
        latent = torch.cat(rollout, dim=1)

        return {
            'mean': self.mean_head(latent).squeeze(-1),
            'var': self.var_head(latent).squeeze(-1),
        }

def load_model_and_scaler(checkpoint_dir, base_model_name='amazon/chronos-bolt-small',
                          context_len=60, pred_len=14, num_variates=5, embed_dim=64):
    """Rebuild the world model and restore its weights and fitted scaler.

    Args:
        checkpoint_dir: Directory containing ``pytorch_model.bin`` (a state
            dict) and ``scaler.pkl`` (a pickled scaler object).
        base_model_name: Pretrained backbone identifier forwarded to
            ``RetailWorldModel``.
        context_len: History window length forwarded to ``RetailWorldModel``.
        pred_len: Forecast horizon forwarded to ``RetailWorldModel``.
        num_variates: Number of input variates forwarded to
            ``RetailWorldModel``.
        embed_dim: Head hidden width forwarded to ``RetailWorldModel``.

    Returns:
        Tuple ``(model, scaler)``; the model lives on the CPU.

    Warns:
        UserWarning: If the checkpoint's keys do not line up with the model's
            parameters. Loading stays non-strict (best effort), but any layer
            listed as missing keeps its freshly initialised weights.
    """
    import warnings

    model = RetailWorldModel(base_model_name, context_len, pred_len, num_variates, embed_dim)

    # SECURITY: both the checkpoint and the scaler are deserialised with
    # pickle-based loaders, which can execute arbitrary code. Only point this
    # function at checkpoint directories you trust.
    ckpt = torch.load(os.path.join(checkpoint_dir, 'pytorch_model.bin'), map_location='cpu')

    # strict=False tolerates key mismatches; surface them instead of silently
    # running inference with partially random weights.
    incompatible = model.load_state_dict(ckpt, strict=False)
    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint in {checkpoint_dir!r} does not fully match the model: "
            f"{len(missing)} missing key(s) (e.g. {missing[:5]}), "
            f"{len(unexpected)} unexpected key(s) (e.g. {unexpected[:5]}). "
            "Missing parameters keep their initial values.",
            stacklevel=2,
        )

    with open(os.path.join(checkpoint_dir, 'scaler.pkl'), 'rb') as f:
        scaler = pickle.load(f)
    return model, scaler

def predict(model, scaler, context_history, steps=14):
    """Forecast future sales with a 90% interval from a history window.

    Args:
        model: Trained world model. Calling it on a ``(1, time, variates)``
            tensor must return a dict with ``'mean'`` and ``'var'`` tensors
            of shape ``(1, horizon)``. If it exposes a ``pred_len``
            attribute, that attribute is set to ``steps`` for the duration of
            the call and restored afterwards.
        scaler: Fitted scaler providing ``inverse_transform`` and ``scale_``.
        context_history: 2-D numeric array-like ``(context_len, num_variates)``
            holding the most recent history (e.g. the last 60 days).
        steps: Number of future steps to return. A horizon longer than the
            one used in training is rolled out autoregressively.

    Returns:
        Dict of 1-D numpy arrays: ``'mean'`` (sales in original units) and
        ``'lower'`` / ``'upper'`` (mean -/+ 1.645 standard deviations, i.e. a
        90% interval under a normal assumption).

    Raises:
        ValueError: If ``steps`` is smaller than 1, or ``context_history`` is
            not a 2-D numeric array.
    """
    if steps < 1:
        raise ValueError(f"steps must be a positive integer, got {steps!r}")

    # Converting up front gives a clear error for non-numeric columns and
    # accepts object-dtype arrays that only contain numbers.
    history = np.asarray(context_history, dtype=np.float32)
    if history.ndim != 2:
        raise ValueError(
            "context_history must be 2-D (context_len, num_variates), "
            f"got shape {history.shape}"
        )
    # NOTE(review): the history is fed to the model exactly as given; if the
    # model was trained on scaled inputs the caller has to scale first - confirm.

    model.eval()

    # Run on whatever device the model's weights live on.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # The rollout length is read from model.pred_len on every forward call,
    # so overriding it temporarily is what makes `steps` take effect.
    has_horizon = hasattr(model, 'pred_len')
    trained_horizon = getattr(model, 'pred_len', None)

    with torch.no_grad():
        ctx = torch.tensor(history).unsqueeze(0).to(device)
        if has_horizon:
            model.pred_len = steps
        try:
            out = model(ctx)
        finally:
            if has_horizon:
                model.pred_len = trained_horizon

        # Slice as a final guard for models that ignore `pred_len`.
        mean = out['mean'].squeeze(0).cpu().numpy()[:steps]
        std = np.sqrt(out['var'].squeeze(0).cpu().numpy())[:steps]

    # NOTE(review): a single column is inverse-transformed and scale_[0] is
    # used for the spread, which assumes the scaler's first (or only) feature
    # is sales - confirm against the training script.
    mean_sales = scaler.inverse_transform(mean.reshape(-1, 1)).flatten()
    std_sales = std * scaler.scale_[0]
    return {
        'mean': mean_sales,
        'lower': mean_sales - 1.645 * std_sales,
        'upper': mean_sales + 1.645 * std_sales,
    }

if __name__ == '__main__':
    # Command-line entry point: load a checkpoint, forecast from a history
    # CSV and write the point forecast plus its 90% band to another CSV.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--checkpoint', required=True)
    cli.add_argument('--history', required=True, help='CSV with 60 rows of history')
    cli.add_argument('--output', default='predictions.csv')
    opts = cli.parse_args()

    model, scaler = load_model_and_scaler(opts.checkpoint)
    history_frame = pd.read_csv(opts.history)
    forecast = predict(model, scaler, history_frame.values)

    # One output row per forecast day, numbered from 1.
    horizon = len(forecast['mean'])
    report = pd.DataFrame({
        'day': range(1, horizon + 1),
        'predicted_sales': forecast['mean'],
        'lower_90': forecast['lower'],
        'upper_90': forecast['upper'],
    })
    report.to_csv(opts.output, index=False)
    print(f"Saved predictions to {opts.output}")