"""Retail World Model Inference Script

Predicts future retail sales given historical context using the trained
world model.
"""
import os
import pickle
import warnings

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, AutoConfig
from sklearn.preprocessing import StandardScaler, LabelEncoder

# Multiplier for a two-sided 90% normal interval: mean +/- _Z_90 * std.
_Z_90 = 1.645


# Same architecture as training
class RetailWorldModel(nn.Module):
    """Pretrained seq2seq encoder followed by an LSTM rollout with Gaussian heads.

    The attribute names and the layer order below determine the state-dict
    keys, so they must stay identical to the training script or checkpoint
    weights will not line up.
    """

    def __init__(self, base_model_name, context_len, pred_len, num_variates, embed_dim):
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        self.encoder = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        # context_len is stored but not read by forward().
        self.context_len = context_len
        self.pred_len = pred_len
        self.num_variates = num_variates
        self.embed_dim = embed_dim
        # Maps each time step's num_variates features to the encoder width.
        self.input_proj = nn.Linear(num_variates, self.config.d_model)
        # 2-layer LSTM that advances the latent state one step per call.
        self.latent_dynamics = nn.LSTM(
            self.config.d_model, self.config.d_model, 2,
            batch_first=True, dropout=0.1,
        )
        self.mean_head = nn.Sequential(
            nn.Linear(self.config.d_model, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, 1),
        )
        # Softplus keeps the predicted variance non-negative.
        self.var_head = nn.Sequential(
            nn.Linear(self.config.d_model, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, 1),
            nn.Softplus(),
        )

    def forward(self, context):
        """Roll the model forward ``pred_len`` steps from a context window.

        Args:
            context: float tensor whose last dimension is ``num_variates``;
                presumably (batch, context_len, num_variates) - predict()
                passes a batch of one.

        Returns:
            dict with 'mean' and 'var' tensors, each (batch, pred_len).
        """
        x = self.input_proj(context)
        enc_out = self.encoder.encoder(inputs_embeds=x, return_dict=True).last_hidden_state
        # Seed both LSTM layers with the last encoder step, assuming enc_out
        # is (batch, seq, d_model): (batch, 1, d) -> (1, batch, d) -> (2, batch, d).
        h0 = enc_out[:, -1:, :].transpose(0, 1).repeat(2, 1, 1)
        c0 = torch.zeros_like(h0)
        states = []
        curr = enc_out[:, -1:, :]
        for _ in range(self.pred_len):
            # Autoregressive in latent space: each step's output is the next input.
            out, (h0, c0) = self.latent_dynamics(curr, (h0, c0))
            states.append(out)
            curr = out
        states = torch.cat(states, dim=1)
        mean = self.mean_head(states).squeeze(-1)
        var = self.var_head(states).squeeze(-1)
        return {'mean': mean, 'var': var}


def load_model_and_scaler(checkpoint_dir, base_model_name='amazon/chronos-bolt-small',
                          context_len=60, pred_len=14, num_variates=5, embed_dim=64):
    """Build the model, load trained weights and unpickle the fitted scaler.

    Args:
        checkpoint_dir: directory containing 'pytorch_model.bin' and 'scaler.pkl'.
        base_model_name, context_len, pred_len, num_variates, embed_dim:
            architecture settings; must match the values used in training.

    Returns:
        (model, scaler): the model on CPU in eval mode, and the unpickled scaler.

    Warns:
        UserWarning: if the checkpoint's keys do not match the model's.
    """
    model = RetailWorldModel(base_model_name, context_len, pred_len, num_variates, embed_dim)
    # NOTE(review): torch.load and pickle.load can execute arbitrary code
    # embedded in the files; only load checkpoints from a trusted source.
    ckpt = torch.load(os.path.join(checkpoint_dir, 'pytorch_model.bin'), map_location='cpu')
    # strict=False is kept for compatibility. Without this check, a mismatched
    # checkpoint would silently leave layers at their random initialisation.
    incompatible = model.load_state_dict(ckpt, strict=False)
    missing = list(incompatible.missing_keys)
    unexpected = list(incompatible.unexpected_keys)
    if missing or unexpected:
        warnings.warn(
            f"Checkpoint in {checkpoint_dir!r} does not match the model: "
            f"{len(missing)} missing keys (e.g. {missing[:3]}), "
            f"{len(unexpected)} unexpected keys (e.g. {unexpected[:3]}).",
            stacklevel=2,
        )
    model.eval()
    with open(os.path.join(checkpoint_dir, 'scaler.pkl'), 'rb') as f:
        scaler = pickle.load(f)
    return model, scaler


def _inverse_transform_target(scaler, values, column=0):
    """Undo ``scaler`` for a single column (the sales target) of a 1-D array.

    The values are placed in a zero-padded (n, n_features_in_) array so this
    works whether the scaler was fitted on the target alone or on all
    variates. Scalers without ``n_features_in_`` are treated as 1-feature.
    """
    values = np.asarray(values).reshape(-1)
    n_features = getattr(scaler, 'n_features_in_', 1)
    padded = np.zeros((values.shape[0], n_features), dtype=values.dtype)
    padded[:, column] = values
    return scaler.inverse_transform(padded)[:, column]


def predict(model, scaler, context_history, steps=14):
    """Forecast daily sales with a 90% normal prediction interval.

    Args:
        model: trained ``RetailWorldModel`` (or any module returning a dict
            with 'mean' and 'var' tensors shaped (1, horizon)).
        scaler: fitted scaler; column 0 is treated as the sales target.
        context_history: array-like (context_len, num_variates) - last 60 days.
            NOTE(review): this is fed to the model without applying
            ``scaler``, while the outputs are inverse-transformed; confirm
            callers pass history already in the model's input space.
        steps: number of forecast days to return, between 1 and the model's
            horizon (``pred_len``). ``None`` returns the full horizon.

    Returns:
        dict with 'mean' (actual sales), 'lower', 'upper' (90% CI), each a
        1-D array of length ``steps``.

    Raises:
        ValueError: if ``context_history`` is not 2-D, its column count
            differs from ``model.num_variates``, or ``steps`` is out of range.
    """
    # Contiguous float32 so torch.from_numpy accepts views and DataFrames.
    history = np.ascontiguousarray(context_history, dtype=np.float32)
    n_variates = getattr(model, 'num_variates', None)
    if history.ndim != 2 or (n_variates is not None and history.shape[1] != n_variates):
        raise ValueError(
            f"context_history must have shape (context_len, {n_variates}), "
            f"got {history.shape}"
        )

    model.eval()
    with torch.no_grad():
        out = model(torch.from_numpy(history).unsqueeze(0))
    mean = out['mean'].squeeze(0).numpy()
    var = out['var'].squeeze(0).numpy()

    horizon = mean.shape[0]
    if steps is None:
        steps = horizon
    if not 1 <= steps <= horizon:
        raise ValueError(
            f"steps must be between 1 and {horizon} (the model's pred_len), got {steps}"
        )
    # Truncate to the requested horizon; the model always produces pred_len steps.
    mean = mean[:steps]
    std = np.sqrt(var[:steps])

    mean_sales = _inverse_transform_target(scaler, mean)
    # A StandardScaler's inverse is affine with slope scale_, so a standard
    # deviation in scaled units maps back by multiplying by scale_.
    std_sales = std * scaler.scale_[0]
    half_width = _Z_90 * std_sales
    return {
        'mean': mean_sales,
        'lower': mean_sales - half_width,
        'upper': mean_sales + half_width,
    }


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Forecast retail sales with the trained world model.')
    parser.add_argument('--checkpoint', required=True)
    parser.add_argument('--history', required=True, help='CSV with 60 rows of history')
    parser.add_argument('--output', default='predictions.csv')
    args = parser.parse_args()

    model, scaler = load_model_and_scaler(args.checkpoint)
    hist = pd.read_csv(args.history)
    # predict() raises ValueError if the CSV's column count differs from the
    # model's num_variates, or if any column is non-numeric.
    pred = predict(model, scaler, hist.to_numpy())
    df = pd.DataFrame({
        'day': range(1, len(pred['mean']) + 1),
        'predicted_sales': pred['mean'],
        'lower_90': pred['lower'],
        'upper_90': pred['upper'],
    })
    df.to_csv(args.output, index=False)
    print(f"Saved predictions to {args.output}")