# demand_forecast/informer_model.py
# Informer model script (upload by yousaf1, commit 7a08941 verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
class ProbAttention(nn.Module):
    """Plain scaled dot-product attention.

    NOTE(review): despite the name, this is full (dense) attention, not the
    ProbSparse attention from the Informer paper, and it is not referenced by
    the Informer class below (which uses nn.MultiheadAttention instead).
    """

    def __init__(self):
        super(ProbAttention, self).__init__()

    def forward(self, queries, keys, values):
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V
        raw_scores = torch.matmul(queries, keys.transpose(-2, -1))
        scaled = raw_scores / queries.size(-1) ** 0.5
        weights = F.softmax(scaled, dim=-1)
        return torch.matmul(weights, values)
# Transformer encoder layer (post-norm): self-attention then feed-forward,
# each wrapped in residual + LayerNorm. Input/output layout is seq-first,
# i.e. [seq_len, batch, d_model], matching nn.MultiheadAttention's default.
class EncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src):
        # Self-attention sublayer: residual connection, then normalize.
        attn_out, _ = self.self_attn(src, src, src)
        src = self.norm1(src + self.dropout(attn_out))
        # Position-wise feed-forward sublayer.
        ff_out = self.linear2(F.relu(self.linear1(src)))
        return self.norm2(src + self.dropout(ff_out))
# Transformer decoder layer (post-norm): self-attention, cross-attention over
# the encoder memory, then feed-forward — each with residual + LayerNorm.
# NOTE(review): no causal mask is applied to self-attention here; each target
# position can attend to all others.
class DecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff=2048, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
        self.linear1 = nn.Linear(d_model, d_ff)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, tgt, memory):
        # Self-attention over the target sequence.
        sa_out, _ = self.self_attn(tgt, tgt, tgt)
        tgt = self.norm1(tgt + self.dropout(sa_out))
        # Cross-attention: queries from target, keys/values from encoder memory.
        ca_out, _ = self.multihead_attn(tgt, memory, memory)
        tgt = self.norm2(tgt + self.dropout(ca_out))
        # Position-wise feed-forward.
        ff_out = self.linear2(F.relu(self.linear1(tgt)))
        return self.norm3(tgt + self.dropout(ff_out))
# Full encoder-decoder forecasting model.
class Informer(nn.Module):
    """Sequence-to-sequence forecaster built from the layers above.

    Encoder and decoder inputs are linearly embedded to d_model, passed through
    stacks of EncoderLayer / DecoderLayer, and projected to c_out features.
    Only the final ``out_len`` time steps of the decoder output are returned.

    NOTE(review): ``seq_len`` and ``label_len`` are stored but never used by
    this implementation, and no positional encoding is added to the embeddings.
    """

    def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len,
                 d_model=512, n_heads=8, e_layers=2, d_layers=1, dropout=0.1):
        super(Informer, self).__init__()
        self.seq_len = seq_len
        self.label_len = label_len
        self.out_len = out_len
        # Input embeddings: per-timestep linear projections to model width.
        self.enc_embedding = nn.Linear(enc_in, d_model)
        self.dec_embedding = nn.Linear(dec_in, d_model)
        # Encoder / decoder stacks.
        self.encoder = nn.ModuleList(
            [EncoderLayer(d_model, n_heads, dropout=dropout) for _ in range(e_layers)]
        )
        self.decoder = nn.ModuleList(
            [DecoderLayer(d_model, n_heads, dropout=dropout) for _ in range(d_layers)]
        )
        # Output head mapping model width back to c_out features.
        self.projection = nn.Linear(d_model, c_out)

    def forward(self, enc_inp, dec_inp):
        """Run enc_inp/dec_inp ([batch, time, features]) through the model and
        return [batch, out_len, c_out]."""
        # Embed, then switch to [seq, batch, d_model] — the layout the
        # attention layers expect (nn.MultiheadAttention default).
        memory = self.enc_embedding(enc_inp).permute(1, 0, 2)
        for enc_layer in self.encoder:
            memory = enc_layer(memory)
        out = self.dec_embedding(dec_inp).permute(1, 0, 2)
        for dec_layer in self.decoder:
            out = dec_layer(out, memory)
        # Back to batch-first, project, and keep the last out_len steps.
        out = self.projection(out.permute(1, 0, 2))
        return out[:, -self.out_len:, :]