"""Run one inference pass of the pretrained i6modelecomm CommerceIntent model.

Builds a single toy session of three interactions (item, brand, category,
price and event-type per step), feeds it through the model with gradient
tracking disabled, and prints the shape of the returned logits.
"""
import torch
from i6modelecomm.model import i6modelecomm

model = i6modelecomm.CommerceIntent.from_pretrained(
    "infinity6/ecomm_shop_intent_pretrained"
)
# TODO: map items and remap categories.
# TODO: freeze layers and train with your data.
model.eval()  # inference mode (for an nn.Module: disables dropout etc.)

# One device for both the model and its inputs; switch to torch.device("cuda")
# to run on a GPU. The model must live on the same device as the tensors,
# otherwise the forward pass fails with a device-mismatch error.
D = torch.device("cpu")
model.to(D)  # assumes CommerceIntent is a torch.nn.Module (it exposes .eval())

# One session: batch_size = 1, seq_len = 3 (every tensor below is shape (1, 3)).
# Tensors are created directly on D instead of being built and then copied.
itms = torch.tensor([[12, 45, 78]], dtype=torch.long, device=D)  # item ids
brds = torch.tensor([[3, 7, 2]], dtype=torch.long, device=D)  # brand ids
cats = torch.tensor([[8, 8, 15]], dtype=torch.long, device=D)  # category ids
prcs = torch.tensor([[29.9, 35.0, 15.5]], dtype=torch.float, device=D)  # prices
evts = torch.tensor([[1, 1, 2]], dtype=torch.long, device=D)  # event-type ids

# Attention mask: all positions set to True, i.e. presumably every step is a
# real (non-padding) interaction -- confirm the True/False convention against
# the model docs before padding variable-length sessions.
mask = torch.tensor([[1, 1, 1]], dtype=torch.bool, device=D)

with torch.no_grad():  # no autograd graph is needed for inference
    outputs = model(
        itms=itms,  # items
        brds=brds,  # brands
        cats=cats,  # categories
        prcs=prcs,  # prices
        evts=evts,  # events
        attention_mask=mask,
        labels=None,  # inference only -- no loss computation
    )

# Per the original author, logits has shape (B, L-1, num_itm); presumably one
# score vector over the item vocabulary for each step except the last.
# NOTE(review): not verifiable from this file -- confirm against the model docs.
logits = outputs.logits
print("Logits shape:", logits.shape)