File size: 1,076 Bytes
f0ab69f
 
 
 
 
ca057d3
f0ab69f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch

from i6modelecomm.model import i6modelecomm

# Load the pretrained commerce-intent checkpoint by its hub identifier.
model = i6modelecomm.CommerceIntent.from_pretrained(
    "infinity6/ecomm_shop_intent_pretrained"
)

# TODO: map items and remap categories.

# TODO: freeze layers and train with your data.

# Single place to choose the device used by the model and every input tensor.
D = 'cpu'

# Keep the weights on the same device as the inputs; otherwise switching D to
# an accelerator would fail with a device mismatch in the forward pass.
# NOTE(review): assumes the model is a torch.nn.Module (it already exposes
# .eval()) -- confirm.
model.to(D)
model.eval()

# One example sequence: batch_size = 1, seq_len = 3.
# Tensors are created directly on D instead of being built on the CPU and
# copied afterwards.
itms = torch.tensor([[12, 45, 78]], dtype=torch.long, device=D)
brds = torch.tensor([[3, 7, 2]], dtype=torch.long, device=D)
cats = torch.tensor([[8, 8, 15]], dtype=torch.long, device=D)
prcs = torch.tensor([[29.9, 35.0, 15.5]], dtype=torch.float, device=D)
evts = torch.tensor([[1, 1, 2]], dtype=torch.long, device=D)

# All-True boolean mask with the same shape and device as the inputs.
# NOTE(review): presumably True marks a real (non-padded) position -- confirm
# against the model's forward() signature.
mask = torch.ones_like(itms, dtype=torch.bool)

# Gradients are not needed for inference.
with torch.no_grad():
    outputs = model(
        itms=itms,      # items
        brds=brds,      # brands
        cats=cats,      # categories
        prcs=prcs,      # prices
        evts=evts,      # events
        attention_mask=mask,
        labels=None     # inference only -- no loss computation
    )

# Per the original author's note, logits has shape (B, L-1, num_itm).
logits = outputs.logits

print("Logits shape:", logits.shape)