| from user_model.configuration import UserModelConfig | |
| from transformers import PreTrainedModel | |
| import tensorflow as tf | |
class UserModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around a TensorFlow SavedModel.

    The TensorFlow model is loaded eagerly when the wrapper is constructed.
    Every ``forward`` call is delegated to it unchanged.

    NOTE(review): the loaded TF object is stored as a plain attribute, not as a
    torch parameter/submodule. It is therefore presumably not included in
    ``state_dict()`` / ``save_pretrained`` output. Confirm how this model is
    meant to be persisted and reloaded.
    """

    config_class = UserModelConfig

    # SavedModel directory used when the config does not name one. A relative
    # path is resolved against the process's current working directory.
    _DEFAULT_SAVED_MODEL_PATH = "tf_retrieval_user_model"

    def __init__(self, config):
        """Build the wrapper and load the TensorFlow SavedModel.

        Args:
            config: A ``UserModelConfig``. If it has a non-empty
                ``saved_model_path`` attribute, that directory is loaded.
                Otherwise the default ``"tf_retrieval_user_model"`` is used,
                which matches the previous hard-coded behavior.
        """
        super().__init__(config)
        # getattr keeps configs that predate this option working unchanged.
        saved_model_path = (
            getattr(config, "saved_model_path", None)
            or self._DEFAULT_SAVED_MODEL_PATH
        )
        self.model = tf.saved_model.load(saved_model_path)

    def forward(self, user_id):
        """Run the wrapped TensorFlow model on ``user_id``.

        Args:
            user_id: Passed straight through to the SavedModel's ``__call__``.
                NOTE(review): presumably a TF tensor (or TF-convertible value)
                of user ids. Confirm against the exported signature.

        Returns:
            Whatever the SavedModel returns, presumably TensorFlow tensors
            rather than torch tensors.
        """
        return self.model(user_id)