Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| from transformers import BertTokenizer, BertModel | |
| from nn_factory import nn_factory | |
| from huggingface_hub import hf_hub_download | |
class BERT_classifier(nn.Module):
    """Linear regression head on a BERT encoder, squashed into the range (0, 1000).

    Args:
        bertmodel: encoder module. It must expose ``config.hidden_dropout_prob``
            and ``config.hidden_size``. Its output is indexed at position 1,
            which is treated as the pooled sentence representation.
        num_score: number of output units of the linear head.
    """

    def __init__(self, bertmodel, num_score):
        super().__init__()
        cfg = bertmodel.config
        # Registration order (encoder, dropout, head) is kept so state_dict
        # keys line up with saved checkpoints.
        self.bertmodel = bertmodel
        # NOTE(review): this layer is constructed but never applied in
        # forward(). It is a no-op under eval(); confirm whether training
        # intended to use it.
        self.dropout = nn.Dropout(p=cfg.hidden_dropout_prob)
        self.linear = nn.Linear(cfg.hidden_size, num_score)

    def forward(self, wrapped_input):
        """Score a batch of encoder inputs.

        Args:
            wrapped_input: mapping of keyword arguments passed straight to
                ``self.bertmodel`` (presumably tokenizer output).

        Returns:
            ``sigmoid(linear(pooled)) * 1000`` after ``squeeze()``. The
            squeeze drops every size-1 dimension, so a batch of one yields a
            0-d tensor.
        """
        encoded = self.bertmodel(**wrapped_input)
        pooled = encoded[1]
        logits = self.linear(pooled).squeeze()
        return 1000 * torch.sigmoid(logits)
def _load_checkpoint(module, path):
    """Load a CPU state_dict from ``path`` into ``module``.

    The only tolerated key drift is ``embeddings.position_ids``. Newer
    transformers releases no longer persist that buffer in BertModel's
    state_dict, so a checkpoint saved under a different version has an extra
    (or missing) copy of it. Plain strict loading raises on that; here any
    other missing or unexpected key still raises ``RuntimeError``, as
    ``strict=True`` would.
    """
    # NOTE(security): torch.load unpickles the file. Only load checkpoints
    # from a trusted repo (consider weights_only=True on torch >= 1.13).
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    result = module.load_state_dict(state_dict, strict=False)
    mismatched = [
        key
        for key in (*result.missing_keys, *result.unexpected_keys)
        if not key.endswith("embeddings.position_ids")
    ]
    if mismatched:
        raise RuntimeError(
            f"Checkpoint {path} does not match {type(module).__name__}: {mismatched}"
        )


# Build the tokenizer and encoder, attach the fine-tuned regression head, and
# wrap everything in the project's nn_factory helper used by predict_score.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert = BertModel.from_pretrained("bert-base-uncased")
# hf_hub_download returns the local path of the cached checkpoint *file*
# (despite the name, this is not a directory).
model_dir = hf_hub_download(
    repo_id="ID2223/hackernews_upvotes_predictor_model",
    filename="model_1.pt",
    repo_type="model"
)
model = BERT_classifier(bert, 1)
_load_checkpoint(model, model_dir)
model.eval()
nn_obj = nn_factory(model, 'cpu', tokenizer)
def predict_score(title: str) -> int:
    """Return the model's predicted score for a post title.

    Delegates to the module-level ``nn_obj`` wrapper. The result is coerced
    with ``int()``, which truncates any fractional part toward zero rather
    than rounding.
    """
    return int(nn_obj.predict(title))
# Minimal UI: a title textbox, Submit/Clear buttons, and a read-out slider
# spanning the model's 0-1000 output range.
with gr.Blocks() as iface:
    with gr.Column():
        with gr.Column():
            title = gr.Textbox(label="Title")
            with gr.Row():
                button = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear")
        with gr.Column():
            output = gr.Slider(label="Possible score", minimum=0, maximum=1000, step=1)
    button.click(predict_score, [title], output)
    # The Clear button previously had no handler and did nothing. Reset the
    # title to empty and the slider to its minimum (0).
    clear.click(lambda: ("", 0), None, [title, output])
iface.launch()