# Example: run one inference pass of the pretrained CommerceIntent model on a
# single toy session and print the shape of the returned logits.
import torch
from i6modelecomm.model import i6modelecomm

# Device shared by the model and all input tensors. Change to "cuda" to run on
# a GPU; the model is moved to this device below so the two always match.
D = "cpu"

model = i6modelecomm.CommerceIntent.from_pretrained(
    "infinity6/ecomm_shop_intent_pretrained"
)
# TODO: map your item IDs and remap your category IDs before real use.
# TODO: freeze layers and fine-tune on your own data.

# Put the model on the same device as the inputs; without this, any D other
# than "cpu" fails with a device-mismatch error in the forward pass.
model.to(D)
model.eval()

# One toy session: batch_size = 1, seq_len = 3 (each column is one step).
itms = torch.tensor([[12, 45, 78]], dtype=torch.long, device=D)  # items
brds = torch.tensor([[3, 7, 2]], dtype=torch.long, device=D)  # brands
cats = torch.tensor([[8, 8, 15]], dtype=torch.long, device=D)  # categories
prcs = torch.tensor([[29.9, 35.0, 15.5]], dtype=torch.float, device=D)  # prices
evts = torch.tensor([[1, 1, 2]], dtype=torch.long, device=D)  # events

# Attention mask: True at every position (no padding in this example).
# Built from `itms` so it always has the same shape and device as the inputs.
mask = torch.ones_like(itms, dtype=torch.bool)

# Gradients are not needed for inference.
with torch.no_grad():
    outputs = model(
        itms=itms,  # items
        brds=brds,  # brands
        cats=cats,  # categories
        prcs=prcs,  # prices
        evts=evts,  # events
        attention_mask=mask,
        labels=None,  # inference only -- no loss computation
    )

# Expected logits shape: (B, L-1, num_itm) -- not verifiable from this script;
# TODO confirm against the model's documentation.
logits = outputs.logits
print("Logits shape:", logits.shape)