| | import torch |
| | from tokenizers import Tokenizer |
| | from torch.utils.data import DataLoader |
| |
|
| | import streamlit as st |
| | import base64 |
| | from model import CustomDataset, TransformerEncoder |
| |
|
st.set_page_config(layout="wide",page_title="TeknoFest We Bears NLP Competition", page_icon="./media/3bears.ico")

# Label vocabulary used by the token classifier:
#   "O"       = outside / no label
#   "olumsuz" = negative sentiment (Turkish)
#   "nötr"    = neutral sentiment
#   "olumlu"  = positive sentiment
#   "org"     = continuation piece of an entity name (merged onto the
#               preceding sentiment-labelled token in predict_fonk)
tag2id = {"O": 0, "olumsuz": 1, "nötr": 2, "olumlu": 3, "org": 4}
# Inverse mapping: numeric label id -> tag string.
id2tag = {value: key for key, value in tag2id.items()}
# Inference is forced onto CPU (the checkpoint is also loaded with a CPU map_location).
device = torch.device('cpu')
| |
|
@st.cache_resource
def load_model_to_cpu(_model, path="model.pth"):
    """Populate *_model* with the checkpoint stored at *path*, mapped to CPU.

    The leading underscore in ``_model`` tells Streamlit's resource cache
    not to hash the (unhashable) torch module argument. The same model
    instance is returned, now carrying the loaded weights.
    """
    state_dict = torch.load(path, map_location=torch.device('cpu'))
    _model.load_state_dict(state_dict)
    return _model
| | |
def get_base64(bin_file):
    """Return the contents of the binary file *bin_file* as a base64 ASCII string."""
    with open(bin_file, 'rb') as stream:
        encoded = base64.b64encode(stream.read())
    return encoded.decode()
| |
|
def predict_fonk(model, device, example, tokenizer):
    """Predict sentiment-labelled entities in a single raw text.

    Parameters
    ----------
    model : torch.nn.Module
        TransformerEncoder instance. Assumes it returns per-token logits
        shaped (batch, seq_len, num_labels) — TODO confirm in model.py.
    device : torch.device
        Device inference runs on (CPU in this app).
    example : str
        Raw user text to encode and classify.
    tokenizer : tokenizers.Tokenizer
        Tokenizer whose encoding exposes .tokens/.ids/.attention_mask/.type_ids.

    Returns
    -------
    dict
        {"entity_list": [entity strings], "results": [{"entity": ..., "sentiment": ...}]}
    """
    model.to(device)
    model.eval()
    predictions = []

    encodings_prdict = tokenizer.encode(example)

    # Wrap the single encoded example as a batch-of-one for CustomDataset.
    predict_texts = [encodings_prdict.tokens]
    predict_input_ids = [encodings_prdict.ids]
    predict_attention_masks = [encodings_prdict.attention_mask]
    predict_token_type_ids = [encodings_prdict.type_ids]
    # NOTE(review): type_ids double as placeholder labels — CustomDataset
    # requires a labels argument but they are unused at inference time;
    # verify against CustomDataset in model.py.
    prediction_labels = [encodings_prdict.type_ids]

    predict_data = CustomDataset(predict_texts, predict_input_ids, predict_attention_masks, predict_token_type_ids,
                                 prediction_labels)

    predict_loader = DataLoader(predict_data, batch_size=1, shuffle=False)

    with torch.no_grad():
        for dataset in predict_loader:
            batch_input_ids = dataset['input_ids'].to(device)
            batch_att_mask = dataset['attention_mask'].to(device)

            outputs = model(batch_input_ids, batch_att_mask)
            # Collapse to (tokens, num_labels) and take the argmax label per token.
            logits = outputs.view(-1, outputs.size(-1))
            _, predicted = torch.max(logits, 1)

            predictions.append(predicted)

    results_list = []
    entity_list = []
    results_dict = {}
    # Walk each token alongside its predicted label id and attention-mask bit.
    trio = zip(predict_loader.dataset[0]["text"], predictions[0].tolist(), predict_attention_masks[0])

    for i, (token, label, attention) in enumerate(trio):
        # Only real (attended) tokens carrying a sentiment label count as
        # entity heads; label 0 ("O") and 4 ("org" continuation) are skipped.
        if attention != 0 and label != 0 and label !=4:
            # Absorb the run of consecutive "org" (4) tokens that follows,
            # concatenating them into a multi-word entity string. The local
            # rebinding of `i` tracks the absorbed token's position; it does
            # not affect the outer enumerate counter.
            for next_ones in predictions[0].tolist()[i+1:]:
                i+=1
                if next_ones == 4:
                    token = token +" "+ predict_loader.dataset[0]["text"][i]
                else:break
            # De-duplicate: each merged entity string is reported once.
            if token not in entity_list:
                entity_list.append(token)
                results_list.append({"entity":token,"sentiment":id2tag.get(label)})

    results_dict["entity_list"] = entity_list
    results_dict["results"] = results_list

    return results_dict
| |
|
# --- Application start-up: load model weights and tokenizer from disk. ---
model = TransformerEncoder()
model = load_model_to_cpu(model, "model.pth")
tokenizer = Tokenizer.from_file("tokenizer.json")

# Background image is inlined into the page CSS as a base64 data payload.
background = get_base64("./media/background.jpg")

# The stylesheet template contains a "{background}" placeholder that is
# filled with the encoded image, then injected as raw HTML.
with open("./style/style.css", "r") as style:
    css=f"""<style>{style.read().format(background=background)}</style>"""
    st.markdown(css, unsafe_allow_html=True)

# Three-column page layout; only the middle column hosts the tabbed UI.
left, middle, right = st.columns([1,1.5,1])
main, comps , result = middle.tabs([" ", " ", " "])
with main:
    # Turkish UI text: "Metin Kutusu" = "Text box"; the placeholder asks the
    # user to type a complaint/comment and press Predict.
    example = st.text_area(label='Metin Kutusu: ', placeholder="Lütfen Şikayet veya Yorum Metnini Buraya Yazın, daha sonra Predicte tıklayın")

    if st.button("Predict"):
        predict_list = predict_fonk(model=model, device=device, example=example, tokenizer=tokenizer)

        st.write(predict_list)
|