Spaces:
Build error
Build error
Edit app.py
Browse files
app.py
CHANGED
|
@@ -1,5 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import streamlit as st
|
| 2 |
|
| 3 |
-
x = st.slider('Select a value')
|
| 4 |
-
st.write(x, 'squared is', x * x)
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
from recommenders.models.sasrec.model import SASREC
|
| 8 |
+
from tabulate import tabulate
|
| 9 |
+
|
| 10 |
import streamlit as st
|
| 11 |
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
class SASREC_Vessl(SASREC):
|
| 14 |
+
def __init__(self, **kwargs):
|
| 15 |
+
super().__init__(**kwargs)
|
| 16 |
+
|
| 17 |
+
def predict_next(self, input):
|
| 18 |
+
# seq generation
|
| 19 |
+
training = False
|
| 20 |
+
seq = np.zeros([self.seq_max_len], dtype=np.int32)
|
| 21 |
+
idx = self.seq_max_len - 1
|
| 22 |
+
idx -= 1
|
| 23 |
+
for i in input[::-1]:
|
| 24 |
+
seq[idx] = i
|
| 25 |
+
idx -= 1
|
| 26 |
+
if idx == -1:
|
| 27 |
+
break
|
| 28 |
+
|
| 29 |
+
input_seq = np.array([seq])
|
| 30 |
+
candidate = np.expand_dims(np.arange(1, self.item_num + 1, 1), axis=0)
|
| 31 |
+
|
| 32 |
+
mask = tf.expand_dims(tf.cast(tf.not_equal(input_seq, 0), tf.float32),
|
| 33 |
+
-1)
|
| 34 |
+
seq_embeddings, positional_embeddings = self.embedding(input_seq)
|
| 35 |
+
seq_embeddings += positional_embeddings
|
| 36 |
+
seq_embeddings *= mask
|
| 37 |
+
seq_attention = seq_embeddings
|
| 38 |
+
seq_attention = self.encoder(seq_attention, training, mask)
|
| 39 |
+
seq_attention = self.layer_normalization(seq_attention) # (b, s, d)
|
| 40 |
+
seq_emb = tf.reshape(
|
| 41 |
+
seq_attention,
|
| 42 |
+
[tf.shape(input_seq)[0] * self.seq_max_len, self.embedding_dim],
|
| 43 |
+
) # (b*s, d)
|
| 44 |
+
candidate_emb = self.item_embedding_layer(candidate) # (b, s, d)
|
| 45 |
+
candidate_emb = tf.transpose(candidate_emb, perm=[0, 2, 1]) # (b, d, s)
|
| 46 |
+
|
| 47 |
+
test_logits = tf.matmul(seq_emb, candidate_emb)
|
| 48 |
+
test_logits = tf.reshape(
|
| 49 |
+
test_logits,
|
| 50 |
+
[tf.shape(input_seq)[0], self.seq_max_len, self.item_num],
|
| 51 |
+
)
|
| 52 |
+
test_logits = test_logits[:, -1, :] # (1, 101)
|
| 53 |
+
|
| 54 |
+
predictions = np.array(test_logits)[0]
|
| 55 |
+
|
| 56 |
+
return predictions
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def load_model():
|
| 60 |
+
model_config = {
|
| 61 |
+
"MAXLEN": 50,
|
| 62 |
+
"NUM_BLOCKS": 2, # NUMBER OF TRANSFORMER BLOCKS
|
| 63 |
+
"HIDDEN_UNITS": 100, # NUMBER OF UNITS IN THE ATTENTION CALCULATION
|
| 64 |
+
"NUM_HEADS": 1, # NUMBER OF ATTENTION HEADS
|
| 65 |
+
"DROPOUT_RATE": 0.2, # DROPOUT RATE
|
| 66 |
+
"L2_EMB": 0.0, # L2 REGULARIZATION COEFFICIENT
|
| 67 |
+
"NUM_NEG_TEST": 100,
|
| 68 |
+
# NUMBER OF NEGATIVE EXAMPLES PER POSITIVE EXAMPLE
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
model = SASREC_Vessl(
|
| 72 |
+
item_num=12101, # should be changed according to dataset
|
| 73 |
+
seq_max_len=model_config.get("MAXLEN"),
|
| 74 |
+
num_blocks=model_config.get("NUM_BLOCKS"),
|
| 75 |
+
embedding_dim=model_config.get("HIDDEN_UNITS"),
|
| 76 |
+
attention_dim=model_config.get("HIDDEN_UNITS"),
|
| 77 |
+
attention_num_heads=model_config.get("NUM_HEADS"),
|
| 78 |
+
dropout_rate=model_config.get("DROPOUT_RATE"),
|
| 79 |
+
conv_dims=[100, 100],
|
| 80 |
+
l2_reg=model_config.get("L2_EMB"),
|
| 81 |
+
num_neg_test=model_config.get("NUM_NEG_TEST"),
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
if os.path.isfile('best.index') and os.path.isfile('best.data-00000-of-00001'):
|
| 85 |
+
model.load_weights('best')
|
| 86 |
+
|
| 87 |
+
return model
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def main():
|
| 91 |
+
st.title('Self-Attentive Sequential Recommendation(SASRec)')
|
| 92 |
+
model = load_model()
|
| 93 |
+
st.write(model)
|
| 94 |
+
|
| 95 |
+
numbers = st.text_input
|
| 96 |
+
st.write(numbers)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
if __name__ == '__main__':
|
| 100 |
+
main()
|