| from flask import Flask, render_template, request, jsonify |
| from transformers import BertForSequenceClassification, BertTokenizer |
| import torch |
|
|
app = Flask(__name__)

# Fine-tuned classifier weights (a state dict).
# map_location="cpu" lets a checkpoint that was saved from a GPU tensor load on
# a CPU-only host. weights_only=True restricts unpickling to tensors and plain
# containers, which is all a state dict needs, instead of arbitrary objects.
model_state_dict = torch.load(
    "bert_classifier_three_labeled.pth", map_location="cpu", weights_only=True
)

# Size the classification head from the checkpoint itself. The stock
# 'bert-base-uncased' config defaults to 2 labels, so a checkpoint trained with
# a different label count (the filename suggests 3) would otherwise fail in
# load_state_dict with a size-mismatch error. The classifier is a Linear layer
# whose weight has shape (num_labels, hidden_size).
num_labels = model_state_dict["classifier.weight"].shape[0]
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=num_labels
)
model.load_state_dict(model_state_dict)
# This process only serves predictions: make sure dropout is disabled.
model.eval()

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
|
def predict(prompt):
    """Classify *prompt* with the loaded BERT model.

    Args:
        prompt: Raw text to classify.

    Returns:
        A list of floats, one per label of the loaded classifier, produced by a
        softmax over the logits (so they sum to 1).
    """
    # truncation=True clips the input to the tokenizer's model_max_length
    # (512 tokens for bert-base-uncased). BERT has no position embeddings
    # beyond that, so an overly long prompt would otherwise raise.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Batch of one: return the probabilities for the single input.
    return probs[0].tolist()
|
|
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the form page and, on POST, classify the submitted prompt.

    The template receives ``result=None`` on GET. On POST it receives the list
    of label probabilities returned by ``predict`` for the form's ``prompt``
    field.
    """
    if request.method != 'POST':
        return render_template('index.html', result=None)

    submitted_text = request.form['prompt']
    probabilities = predict(submitted_text)
    return render_template('index.html', result=probabilities)
|
|
if __name__ == '__main__':
    # Do not hard-code debug=True. The Werkzeug interactive debugger it enables
    # lets anyone who can reach the server execute arbitrary Python code.
    # app.run() still honours the FLASK_DEBUG environment variable, so set
    # FLASK_DEBUG=1 for local development only.
    app.run()
|
|