# Source: Hugging Face upload by superdkj
# Commit message: "Upload inference.py"
# Revision: 96b112e (verified)
"""Retail World Model Inference Script
Predicts future retail sales given historical context using the trained world model.
"""
import os, pickle, numpy as np, pandas as pd, torch, torch.nn as nn
from transformers import AutoModelForSeq2SeqLM, AutoConfig
from sklearn.preprocessing import StandardScaler, LabelEncoder
# Same architecture as training
class RetailWorldModel(nn.Module):
    """Probabilistic multi-step forecaster on top of a pretrained seq2seq encoder.

    The context window is projected to the pretrained model's hidden size and run
    through the encoder stack of the loaded seq2seq model (its decoder is not used
    by ``forward``). The hidden state of the last context step seeds both layers of
    a 2-layer LSTM, which is then unrolled ``pred_len`` times in latent space,
    feeding each output back in as the next input. Two small MLP heads turn every
    latent step into a predicted mean and a Softplus-constrained (non-negative)
    variance.

    Sub-module attribute names (``encoder``, ``input_proj``, ``latent_dynamics``,
    ``mean_head``, ``var_head``) are the ``state_dict`` keys checkpoints are
    loaded against, so they must not be renamed. ``context_len`` and
    ``embed_dim`` are stored for reference; ``forward`` does not read
    ``context_len``.
    """

    def __init__(self, base_model_name, context_len, pred_len, num_variates, embed_dim):
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        self.encoder = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        self.context_len = context_len
        self.pred_len = pred_len
        self.num_variates = num_variates
        self.embed_dim = embed_dim

        hidden = self.config.d_model
        # Creation order below matches the training script (parameter init order).
        self.input_proj = nn.Linear(num_variates, hidden)
        self.latent_dynamics = nn.LSTM(hidden, hidden, 2, batch_first=True, dropout=0.1)
        self.mean_head = nn.Sequential(
            nn.Linear(hidden, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, 1),
        )
        self.var_head = nn.Sequential(
            nn.Linear(hidden, embed_dim),
            nn.GELU(),
            nn.Linear(embed_dim, 1),
            nn.Softplus(),
        )

    def forward(self, context):
        """Roll the latent dynamics forward ``pred_len`` steps from ``context``.

        Args:
            context: float tensor of shape ``(batch, time, num_variates)``.

        Returns:
            dict with ``'mean'`` and ``'var'`` tensors, each ``(batch, pred_len)``.
        """
        embedded = self.input_proj(context)
        encoded = self.encoder.encoder(inputs_embeds=embedded, return_dict=True).last_hidden_state

        # Last context step: (batch, 1, d_model). Reused as the LSTM's first input
        # and, as (1, batch, d_model) copied for each of the 2 layers, its h0.
        last_step = encoded[:, -1:, :]
        hidden = last_step.transpose(0, 1).repeat(2, 1, 1)
        cell = torch.zeros_like(hidden)

        rollout = []
        step_input = last_step
        for _ in range(self.pred_len):
            # Autoregressive: each latent output becomes the next step's input.
            step_input, (hidden, cell) = self.latent_dynamics(step_input, (hidden, cell))
            rollout.append(step_input)

        trajectory = torch.cat(rollout, dim=1)  # (batch, pred_len, d_model)
        return {
            'mean': self.mean_head(trajectory).squeeze(-1),
            'var': self.var_head(trajectory).squeeze(-1),
        }
def load_model_and_scaler(checkpoint_dir, base_model_name='amazon/chronos-bolt-small',
                          context_len=60, pred_len=14, num_variates=5, embed_dim=64):
    """Build a ``RetailWorldModel``, load trained weights into it, and load the scaler.

    Args:
        checkpoint_dir: Directory containing ``pytorch_model.bin`` (a state dict)
            and ``scaler.pkl`` (the pickled scaler used in training).
        base_model_name: Pretrained model the architecture is built on.
        context_len, pred_len, num_variates, embed_dim: Architecture settings;
            they must match the values used when the checkpoint was trained.

    Returns:
        ``(model, scaler)``.

    Warns:
        UserWarning: if the checkpoint's keys do not match the model's. Loading is
        non-strict, so such a mismatch would otherwise leave the affected layers at
        their random initialisation without any signal.
    """
    import warnings

    model = RetailWorldModel(base_model_name, context_len, pred_len, num_variates, embed_dim)
    state_dict = torch.load(os.path.join(checkpoint_dir, 'pytorch_model.bin'), map_location='cpu')
    load_report = model.load_state_dict(state_dict, strict=False)
    if load_report.missing_keys:
        shown = ', '.join(load_report.missing_keys[:10])
        warnings.warn(
            f"{len(load_report.missing_keys)} model parameter(s) were not found in the "
            f"checkpoint and keep their initial values: {shown}",
            stacklevel=2,
        )
    if load_report.unexpected_keys:
        shown = ', '.join(load_report.unexpected_keys[:10])
        warnings.warn(
            f"{len(load_report.unexpected_keys)} checkpoint key(s) do not match any "
            f"model parameter and were ignored: {shown}",
            stacklevel=2,
        )

    # SECURITY: pickle.load can execute arbitrary code embedded in the file; only
    # load scaler.pkl from checkpoint directories you trust.
    with open(os.path.join(checkpoint_dir, 'scaler.pkl'), 'rb') as f:
        scaler = pickle.load(f)
    return model, scaler
def _inverse_sales_column(scaler, values):
    """Convert 1-D scaled sales ``values`` (feature column 0) back to sales units.

    A scaler fitted on a single feature is inverted directly. A scaler fitted on
    several features expects that many columns in ``inverse_transform`` (sklearn
    raises on a feature-count mismatch), so the values are placed in column 0 of a
    zero-padded matrix of the fitted width, and column 0 of the result is returned.
    This assumes a per-feature scaler such as ``StandardScaler``.
    """
    column = np.asarray(values).reshape(-1, 1)
    n_features = getattr(scaler, 'n_features_in_', 1)
    if n_features > 1:
        padded = np.zeros((column.shape[0], n_features), dtype=column.dtype)
        padded[:, 0] = column[:, 0]
        column = padded
    return np.asarray(scaler.inverse_transform(column))[:, 0]


def predict(model, scaler, context_history, steps=14):
    """Forecast sales with a 90% interval from a window of history.

    Args:
        model: A ``RetailWorldModel`` or any module with ``eval()`` whose call on a
            ``(1, context_len, num_variates)`` float32 tensor returns a dict with
            ``'mean'`` and ``'var'`` tensors of shape ``(1, horizon)``.
        scaler: Fitted scaler from training. Sales are taken to be feature column 0,
            and ``scaler.scale_[0]`` converts the predicted std back to sales units.
        context_history: Array-like of shape ``(context_len, num_variates)``, e.g.
            the last 60 days. NOTE(review): it is fed to the model exactly as
            given (no scaling is applied here, while the outputs are
            inverse-transformed) — confirm whether callers must pre-scale it.
        steps: Number of forecast days to return (>= 1). Values larger than the
            model's horizon are clamped to the horizon with a warning.

    Returns:
        dict with ``'mean'``, ``'lower'`` and ``'upper'`` numpy arrays of length
        ``min(steps, horizon)`` in sales units. ``lower``/``upper`` are
        ``mean -/+ 1.645 * std``, a two-sided 90% Gaussian interval.

    Raises:
        ValueError: if ``steps < 1`` or ``context_history`` is not a 2-D numeric array.
    """
    import warnings

    z_90 = 1.645  # standard-normal quantile for a two-sided 90% interval

    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    try:
        history = np.asarray(context_history, dtype=np.float32)
    except (TypeError, ValueError) as err:
        raise ValueError("context_history must contain only numeric values") from err
    if history.ndim != 2:
        raise ValueError(
            "context_history must be 2-D (context_len, num_variates), "
            f"got shape {history.shape}"
        )

    model.eval()
    with torch.no_grad():
        out = model(torch.tensor(history).unsqueeze(0))
    mean = out['mean'].squeeze(0).numpy()
    std = np.sqrt(out['var'].squeeze(0).numpy())

    # The model always emits its full horizon; honour the requested number of steps.
    horizon = mean.shape[0]
    if steps > horizon:
        warnings.warn(
            f"steps={steps} exceeds the model horizon of {horizon}; "
            f"returning {horizon} steps",
            stacklevel=2,
        )
    mean, std = mean[:steps], std[:steps]

    mean_sales = _inverse_sales_column(scaler, mean)
    # The std only needs rescaling (the shift cancels out); assumes StandardScaler-style scale_.
    std_sales = std * scaler.scale_[0]
    return {
        'mean': mean_sales,
        'lower': mean_sales - z_90 * std_sales,
        'upper': mean_sales + z_90 * std_sales,
    }
if __name__ == '__main__':
    # Command-line entry point: load a checkpoint, forecast from a CSV of history,
    # and write the forecast with its 90% interval to a CSV file.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--checkpoint', required=True)
    cli.add_argument('--history', required=True, help='CSV with 60 rows of history')
    cli.add_argument('--output', default='predictions.csv')
    args = cli.parse_args()

    net, fitted_scaler = load_model_and_scaler(args.checkpoint)
    # Every column of the CSV is handed to the model, so the file presumably must
    # hold only the feature columns, in training order — confirm.
    history = pd.read_csv(args.history).values
    forecast = predict(net, fitted_scaler, history)

    horizon = len(forecast['mean'])
    table = pd.DataFrame({
        'day': range(1, horizon + 1),
        'predicted_sales': forecast['mean'],
        'lower_90': forecast['lower'],
        'upper_90': forecast['upper'],
    })
    table.to_csv(args.output, index=False)
    print(f"Saved predictions to {args.output}")