"""Run one inference pass of the pretrained i6modelecomm CommerceIntent model.

Builds a single toy batch (batch_size = 1, seq_len = 3) of item, brand,
category, price and event features, runs the model without gradient
tracking, and prints the shape of the returned logits.
"""
import torch

from i6modelecomm.model import i6modelecomm

model = i6modelecomm.CommerceIntent.from_pretrained(
    "infinity6/ecomm_shop_intent_pretrained"
)
# TODO: map items and remap categories.
# TODO: freeze layers and train with your data.
model.eval()  # switch to evaluation mode before running inference

D = 'cpu'
# Put the model on the same device as the inputs. Without this, changing D
# (e.g. to 'cuda') would move only the tensors below and the forward pass
# would fail with a device mismatch.
model = model.to(D)

# One sequence (batch_size = 1) of three steps (seq_len = 3).
# Tensors are created directly on D instead of being built on the CPU
# and copied over with .to(D).
itms = torch.tensor([[12, 45, 78]], dtype=torch.long, device=D)  # items
brds = torch.tensor([[3, 7, 2]], dtype=torch.long, device=D)  # brands
cats = torch.tensor([[8, 8, 15]], dtype=torch.long, device=D)  # categories
prcs = torch.tensor([[29.9, 35.0, 15.5]], dtype=torch.float, device=D)  # prices
evts = torch.tensor([[1, 1, 2]], dtype=torch.long, device=D)  # events

# Attention mask with every position set to True, i.e. no padding here.
# NOTE(review): assumes True marks a position the model should attend to;
# confirm against the model's documentation.
mask = torch.tensor([[1, 1, 1]], dtype=torch.bool, device=D)

with torch.no_grad():
    outputs = model(
        itms=itms,  # items
        brds=brds,  # brands
        cats=cats,  # categories
        prcs=prcs,  # prices
        evts=evts,  # events
        attention_mask=mask,
        labels=None,  # inference only -- no loss computation
    )

# Expected logits shape: (B, L-1, num_itm).
# TODO: confirm against the model's documentation.
logits = outputs.logits
print("Logits shape:", logits.shape)