| | import os.path
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | from transformers import RobertaTokenizerFast, RobertaForMaskedLM
|
| | import streamlit as st
|
| |
|
| |
|
class SimpleClassifier(nn.Module):
    """Two-hidden-layer MLP head with input BatchNorm and sigmoid output.

    Maps a feature vector of size ``in_features`` through two hidden
    Linear layers (each followed by ``activation``) to ``out_features``
    sigmoid scores in [0, 1].
    """

    def __init__(self, in_features: int, hidden_features: int,
                 out_features: int, activation=nn.ReLU()):
        super().__init__()
        self.bn = nn.BatchNorm1d(in_features)
        self.in2hid = nn.Linear(in_features, hidden_features)
        self.activation = activation
        self.hid2hid = nn.Linear(hidden_features, hidden_features)
        self.hid2out = nn.Linear(hidden_features, out_features)
        # NOTE(review): bn2 is never used in forward(), but it must stay
        # registered so checkpoints saved with this architecture (e.g.
        # twitter_model_91_5-.pth) still load with strict=True — the
        # state_dict keys have to match.
        self.bn2 = nn.BatchNorm1d(hidden_features)

    def forward(self, X):
        X = self.bn(X)
        X = self.in2hid(X)
        X = self.activation(X)
        # The original wrapped X in torch.concat((X,), 1) here and below;
        # concatenating a single tensor is a no-op, so it was dropped.
        X = self.hid2hid(X)
        X = self.activation(X)
        X = self.hid2out(X)
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid.
        return torch.sigmoid(X)
|
| |
|
| |
|
@st.cache_resource
def load_models():
    """Load and cache the RoBERTa encoder, tokenizer, and classifier head.

    Returns:
        dict with keys ``"tokenizer"`` (RobertaTokenizerFast),
        ``"model"`` (headless RobertaForMaskedLM), and
        ``"classifier"`` (SimpleClassifier with trained weights).

    Cached with ``st.cache_resource`` rather than ``st.cache_data``:
    torch modules are unserializable global resources that should be
    shared across reruns, not pickled/copied per call.
    """
    model = RobertaForMaskedLM.from_pretrained("roberta-base")
    # Strip the LM head so the model's last output is the raw hidden states.
    model.lm_head = nn.Identity()
    model.eval()  # inference only; presumably already eval after from_pretrained
    tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
    my_classifier = SimpleClassifier(768, 768, 1)
    # BUG FIX: os.path.join(__file__, "..", name) built a path that routes
    # *through* the script file ("/path/app.py/../model.pth"), which open()
    # rejects on POSIX (ENOTDIR). Resolve against the script's directory.
    weights_path = os.path.join(os.path.dirname(__file__),
                                "twitter_model_91_5-.pth")
    my_classifier.load_state_dict(torch.load(weights_path, map_location=device))
    my_classifier.eval()
    return {
        "tokenizer": tokenizer,
        "model": model,
        "classifier": my_classifier,
    }
|
| |
|
| |
|
def classify_text(text: str) -> float:
    """Return the classifier's sigmoid score for *text* as a float in [0, 1].

    The text is tokenized (truncated to 128 tokens), encoded by the
    headless RoBERTa model, and the per-token embeddings are summed into
    a single feature vector fed to the classifier head.
    """
    models = load_models()
    tokenizer = models["tokenizer"]
    model = models["model"]
    classifier = models["classifier"]

    input_ids = tokenizer(
        text,
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )["input_ids"]

    # Pure inference: disable autograd bookkeeping for speed and memory.
    with torch.no_grad():
        # Call the module (runs hooks) instead of model.forward() directly.
        # [-1] takes the last element of the model output (the Identity
        # "logits", i.e. hidden states); [0] the single batch item; the
        # sum over the sequence axis pools tokens into one 768-d vector.
        features = model(input_ids)[-1][0].sum(axis=0)[None, :]
        score = classifier(features)
    # FIX: return a plain float to match the annotated return type
    # (the original returned a (1, 1) tensor).
    return score.item()
|
| |
|
| |
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| |
|