File size: 7,499 Bytes
6785c47 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 | import os, pickle, numpy as np, pandas as pd, torch, torch.nn as nn
from datasets import load_dataset
from transformers import AutoModelForSeq2SeqLM, AutoConfig, TrainingArguments, Trainer, EarlyStoppingCallback, set_seed
from peft import LoraConfig, get_peft_model, TaskType
import trackio
from sklearn.preprocessing import StandardScaler, LabelEncoder
# Run identity and artefact locations.
SEED = 42
MODEL_NAME = 'amazon/chronos-bolt-small'
OUTPUT_DIR = '/tmp/outputs'
HUB_MODEL_ID = 'superdkj/retail-world-model-v1'
DATASET_NAME = 't4tiana/store-sales-time-series-forecasting'

# Window geometry and head sizes shared by the dataset and the model.
CONTEXT_LENGTH = 60     # history rows fed to the model per window
PREDICTION_LENGTH = 14  # future rows predicted per window
NUM_VARIATES = 5        # features per row (input width of RetailWorldModel.input_proj)
EMBED_DIM = 64          # hidden width of the mean/variance heads

set_seed(SEED)
# NOTE(review): confirm trackio.init accepts `run_name` (its run-name parameter may be `name`).
trackio.init(project='retail-world-model', run_name='retail-world-model-v1')
class RetailWorldModel(nn.Module):
    """Probabilistic multi-step forecaster: pretrained encoder plus an LSTM rollout.

    A linear layer lifts each context step's ``num_variates`` features to the
    encoder width ``d_model``. The encoder stack of ``base_model_name`` then
    contextualises the sequence. A stacked LSTM, seeded with the last encoder
    state, is rolled forward ``pred_len`` steps, feeding each step's output back
    in as the next input. Two MLP heads map every rolled-out state to a mean and
    a strictly positive (Softplus) variance.

    Args:
        base_model_name: Hub id or path loadable by ``AutoModelForSeq2SeqLM``.
        context_len: Expected context length (stored only; not enforced).
        pred_len: Number of future steps to roll out.
        num_variates: Number of input features per context step.
        embed_dim: Hidden width of the mean/variance heads.
        num_layers: LSTM depth (default 2, the previously hard-coded value).
        dropout: Dropout between LSTM layers (default 0.1).
    """

    def __init__(self, base_model_name, context_len, pred_len, num_variates, embed_dim,
                 num_layers=2, dropout=0.1):
        super().__init__()
        # NOTE(review): only the encoder stack of this seq2seq model is used in
        # forward(); confirm the checkpoint loads cleanly via AutoModelForSeq2SeqLM.
        self.encoder = AutoModelForSeq2SeqLM.from_pretrained(base_model_name)
        # Reuse the config the model was built from instead of fetching it again.
        self.config = self.encoder.config
        self.context_len = context_len
        self.pred_len = pred_len
        self.num_variates = num_variates
        self.embed_dim = embed_dim
        d_model = self.config.d_model
        self.input_proj = nn.Linear(num_variates, d_model)
        self.latent_dynamics = nn.LSTM(d_model, d_model, num_layers, batch_first=True, dropout=dropout)
        self.mean_head = nn.Sequential(nn.Linear(d_model, embed_dim), nn.GELU(), nn.Linear(embed_dim, 1))
        self.var_head = nn.Sequential(
            nn.Linear(d_model, embed_dim), nn.GELU(), nn.Linear(embed_dim, 1), nn.Softplus()
        )

    def forward(self, context, target=None, return_loss=True):
        """Roll the latent dynamics forward and predict per-step mean and variance.

        Args:
            context: Float tensor indexed as (batch, time, num_variates).
            target: Optional tensor broadcastable to the (batch, pred_len) mean.
            return_loss: Compute the loss when ``target`` is also given.

        Returns:
            dict with ``mean`` and ``var``, each of shape (batch, pred_len). When
            requested it also contains ``loss``: the Gaussian negative
            log-likelihood (without its constant term) averaged over all elements.
        """
        x = self.input_proj(context)
        enc_out = self.encoder.encoder(inputs_embeds=x, return_dict=True).last_hidden_state
        last = enc_out[:, -1:, :]  # final encoder state, kept as a length-1 sequence
        # Seed every LSTM layer's hidden state with the final encoder state; the
        # depth comes from the LSTM itself so the two can never disagree.
        depth = self.latent_dynamics.num_layers
        h = last.transpose(0, 1).repeat(depth, 1, 1)
        c = torch.zeros_like(h)
        steps = []
        curr = last
        for _ in range(self.pred_len):
            # Autoregressive rollout: each step's output is the next step's input.
            curr, (h, c) = self.latent_dynamics(curr, (h, c))
            steps.append(curr)
        states = torch.cat(steps, dim=1)
        mean = self.mean_head(states).squeeze(-1)
        var = self.var_head(states).squeeze(-1)
        if return_loss and target is not None:
            eps = 1e-6  # keeps log() and the division finite when Softplus underflows
            loss = torch.mean(0.5 * torch.log(var + eps) + 0.5 * (target - mean) ** 2 / (var + eps))
            return {'loss': loss, 'mean': mean, 'var': var}
        return {'mean': mean, 'var': var}
class RetailDataset(torch.utils.data.Dataset):
    """Sliding-window forecasting dataset built from a store-sales DataFrame.

    The frame is split into one series per (store_nbr, family) pair, each
    sorted by date. Series with fewer than ``context_len + pred_len`` rows are
    dropped. Each remaining series yields windows starting every ``stride``
    rows. A window is a float32 ``(context_len, 5)`` feature matrix with
    columns ``[sales_scaled, onpromotion, day_of_week, month, family_enc]``
    and a float32 ``(pred_len,)`` vector of the following ``sales_scaled``
    values. ``onpromotion`` is used as-is (not scaled).

    Args:
        df: Frame with ``date``, ``store_nbr``, ``family``, ``sales`` and
            ``onpromotion`` columns. It is copied, not modified.
        context_len: History rows per window.
        pred_len: Future rows per window.
        scaler: Fitted scaler to reuse, e.g. the training set's; it must offer
            ``transform`` on an ``(n, 1)`` array. When ``None``, a new
            ``StandardScaler`` is fit on the sales of all retained series.
        fit_scaler: Unused; a scaler is fit whenever ``scaler`` is ``None``.
            Kept for backward compatibility.
        family_encoder: Fitted family label encoder to reuse; it must offer
            ``transform`` and ``classes_``. Pass the training set's
            ``family_enc`` when building another split so both splits map
            families to the same codes. When ``None``, a new ``LabelEncoder``
            is fit on ``df['family']``.
        stride: Row step between consecutive window starts (default 7).

    Raises:
        ValueError: If ``scaler`` is ``None`` and no series is long enough to
            fit one.
    """

    def __init__(self, df, context_len=60, pred_len=14, scaler=None, fit_scaler=False,
                 family_encoder=None, stride=7):
        self.context_len = context_len
        self.pred_len = pred_len
        window_len = context_len + pred_len
        feature_cols = ['sales_scaled', 'onpromotion', 'day_of_week', 'month', 'family_enc']

        df = df.copy()
        df['date'] = pd.to_datetime(df['date'])
        # Calendar features, divided by their maximum values.
        df['day_of_week'] = df['date'].dt.dayofweek / 6.0
        df['month'] = df['date'].dt.month / 12.0

        if family_encoder is None:
            self.family_enc = LabelEncoder()
            codes = self.family_enc.fit_transform(df['family'])
        else:
            # Reuse an encoder fitted elsewhere so codes agree across splits.
            self.family_enc = family_encoder
            codes = self.family_enc.transform(df['family'])
        df['family_enc'] = codes / len(self.family_enc.classes_)

        self.groups = []
        for _, g in df.groupby(['store_nbr', 'family']):
            if len(g) >= window_len:
                self.groups.append(g.sort_values('date').reset_index(drop=True))

        if scaler is None:
            if not self.groups:
                raise ValueError(
                    f'No (store_nbr, family) series has at least {window_len} rows; '
                    'cannot fit a sales scaler.'
                )
            all_sales = np.concatenate([g['sales'].values for g in self.groups])
            self.scaler = StandardScaler()
            self.scaler.fit(all_sales.reshape(-1, 1))
        else:
            self.scaler = scaler

        self.groups = [
            g.assign(sales_scaled=self.scaler.transform(g['sales'].values.reshape(-1, 1)).flatten())
            for g in self.groups
        ]

        self.windows = []
        for g in self.groups:
            # Convert each series to numpy once; windows are then cheap slices.
            features = np.ascontiguousarray(g[feature_cols].to_numpy(dtype=np.float32))
            targets = g['sales_scaled'].to_numpy(dtype=np.float32)
            for start in range(0, len(g) - window_len + 1, stride):
                end_ctx = start + context_len
                self.windows.append((features[start:end_ctx], targets[end_ctx:end_ctx + pred_len]))

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        ctx, tgt = self.windows[idx]
        return {'context': torch.tensor(ctx), 'target': torch.tensor(tgt)}
def collate_fn(batch):
    """Stack per-sample ``context`` and ``target`` tensors into batch tensors.

    Each sample is a dict holding equally-shaped tensors under those two keys;
    any other keys are dropped.
    """
    stacked_keys = ('context', 'target')
    return {key: torch.stack([sample[key] for sample in batch]) for key in stacked_keys}
class RetailTrainer(Trainer):
    """Trainer for models called as ``model(context, target, return_loss=True)``.

    The model must return a dict with ``loss`` and ``mean`` entries.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the model's loss, plus its output dict when ``return_outputs``.

        Extra keyword arguments passed by newer Trainer versions (for example
        ``num_items_in_batch``) are accepted and ignored.
        """
        outputs = model(inputs['context'], inputs['target'], return_loss=True)
        loss = outputs['loss']
        return (loss, outputs) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Evaluate one batch and return ``(loss, predicted_mean, target)``.

        When ``prediction_loss_only`` is set, returns ``(loss, None, None)``.
        ``ignore_keys`` is accepted for interface compatibility and unused.
        """
        # Move the batch to the model's device, as the base Trainer.prediction_step does.
        inputs = self._prepare_inputs(inputs)
        with torch.no_grad():
            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
        loss = loss.detach()
        if prediction_loss_only:
            return (loss, None, None)
        return (loss, outputs['mean'].detach(), inputs['target'])
print('Loading dataset...')
ds = load_dataset(DATASET_NAME, split='train')
df = ds.to_pandas()
# Single-quoted keys inside a double-quoted f-string: backslash-escaped quotes
# inside a replacement field are a SyntaxError.
print(f"Rows: {len(df)}, Stores: {df['store_nbr'].nunique()}, Families: {df['family'].nunique()}")
df['date'] = pd.to_datetime(df['date'])
# Hold out the final 90 calendar days for validation.
split_date = df['date'].max() - pd.Timedelta(days=90)
train_df = df[df['date'] <= split_date]
val_df = df[df['date'] > split_date]
print(f'Train: {len(train_df)}, Val: {len(val_df)}')
print('Building datasets...')
train_ds = RetailDataset(train_df, CONTEXT_LENGTH, PREDICTION_LENGTH, fit_scaler=True)
# Validation reuses the training scaler so both splits share one sales scale.
val_ds = RetailDataset(val_df, CONTEXT_LENGTH, PREDICTION_LENGTH, scaler=train_ds.scaler, fit_scaler=False)
print(f'Train windows: {len(train_ds)}, Val windows: {len(val_ds)}')
if len(val_ds) == 0:
    # Evaluation (and early stopping on eval_loss) needs at least one window.
    raise ValueError(
        f'No validation windows: series need at least {CONTEXT_LENGTH + PREDICTION_LENGTH} '
        f'rows after {split_date.date()}.'
    )
os.makedirs(OUTPUT_DIR, exist_ok=True)
with open(os.path.join(OUTPUT_DIR, 'scaler.pkl'), 'wb') as f:
    pickle.dump(train_ds.scaler, f)
print('Initializing model...')
model = RetailWorldModel(
    MODEL_NAME,
    CONTEXT_LENGTH,
    PREDICTION_LENGTH,
    NUM_VARIATES,
    EMBED_DIM,
)
# LoRA adapters on the modules named q/k/v/o (presumably the attention projections).
# NOTE(review): RetailWorldModel.forward only runs the encoder stack, so adapters
# PEFT places in the decoder likely never receive gradients; consider restricting
# target_modules to encoder layers.
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=['q', 'v', 'k', 'o'],
    lora_dropout=0.05,
    bias='none',
    task_type=TaskType.SEQ_2_SEQ_LM,
)
peft_encoder = get_peft_model(model.encoder, lora_cfg)
model.encoder = peft_encoder
peft_encoder.print_trainable_parameters()
# fp16 AMP is rejected by TrainingArguments on hosts without a GPU; enable it only with CUDA.
use_fp16 = torch.cuda.is_available()
args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=10,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    gradient_accumulation_steps=2,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    # `evaluation_strategy` was renamed to `eval_strategy` and later removed.
    eval_strategy='epoch',
    save_strategy='epoch',
    # RetailWorldModel is a plain nn.Module, so Trainer writes its raw state_dict.
    # The wrapped seq2seq model shares its token-embedding tensor between modules,
    # and safetensors refuses to serialise tensors that share memory.
    save_safetensors=False,
    save_total_limit=2,
    logging_strategy='steps',
    logging_steps=50,
    logging_first_step=True,
    disable_tqdm=True,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    greater_is_better=False,
    push_to_hub=True,
    hub_model_id=HUB_MODEL_ID,
    hub_strategy='every_save',
    report_to='trackio',
    run_name='retail-world-model-v1',
    seed=SEED,
    dataloader_num_workers=4,
    fp16=use_fp16,
)
trainer = RetailTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collate_fn,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)
print('Training...')
trainer.train()
trainer.save_model(os.path.join(OUTPUT_DIR, 'final'))
eval_results = trainer.evaluate()
final_loss = eval_results['eval_loss']
print(f'Final eval_loss: {final_loss:.4f}')
# NOTE(review): confirm the installed trackio version provides `alert` with these arguments.
trackio.alert(title='Training Complete', text=f'Final eval_loss={final_loss:.4f}', level='INFO')
trainer.push_to_hub()
print('Done!')
|