Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import torch | |
| import numpy as np | |
| from transformers import AutoTokenizer | |
| from transformers import BertForSequenceClassification | |
# Configure the page: wide layout, sidebar expanded on first load.
st.set_page_config(layout='wide', initial_sidebar_state='expanded')

# Split the page into two equal columns; col1 holds the project description
# and sample data, col2 is used further down for the prediction form.
col1, col2 = st.columns(2)

with col1:
    st.title("FireWatch")
    st.markdown("PREDICT WHETHER HEAT SIGNATURES AROUND THE GLOBE ARE LIKELY TO BE FIRES!")
    # Fixed user-facing typo ("Traing" -> "Training").
    st.markdown("Training Code at:")
    st.markdown("https://colab.research.google.com/drive/1-IfOMJ-X8MKzwm3UjbJbK6RmhT7tk_ye?usp=sharing")
    st.markdown("Try the Model Yourself at:")
    st.markdown("https://colab.research.google.com/drive/1GmweeQrkzs0OXQ_KNZsWd1PQVRLCWDKi?usp=sharing")
    st.markdown("## Sample Table")
# Sample inputs shown to the user as (category, "latitude, longitude,
# brightness, FRP") pairs. Kept as data so a row can be added or changed
# without touching any markup.
_SAMPLE_ROWS = [
    ("Likely", "-26.76123, 147.15512, 393.02, 203.63"),
    ("Likely", "-26.7598, 147.14514, 361.54, 79.4"),
    ("Unlikely", "-25.70059, 149.48932, 313.9, 5.15"),
    ("Unlikely", "-24.4318, 151.83102, 307.98, 8.79"),
    ("Unlikely", "-23.21878, 148.91298, 314.08, 7.4"),
    ("Likely", "7.87518, 19.9241, 316.32, 39.63"),
    ("Unlikely", "-20.10942, 148.14326, 314.39, 8.8"),
    ("Unlikely", "7.87772, 19.9048, 304.14, 13.43"),
    ("Likely", "-20.79866, 124.46834, 366.74, 89.06"),
]

# Inline style shared by every row and cell of the sample table.
_CELL_BORDER = "border: 1px solid orange;"


def _build_sample_table(rows):
    """Return the sample-table HTML for ``rows``.

    Args:
        rows: Iterable of ``(category, values)`` string pairs; each pair
            becomes one ``<tr>`` with two ``<td>`` cells.

    Returns:
        A string that starts and ends with a newline and has one tag per
        line, i.e. the same layout as the hand-written table it replaces.
    """
    lines = [
        "",
        '<table style="border-collapse: collapse; width: 100%;">',
        f'<tr style="{_CELL_BORDER}">',
        f'<th style="{_CELL_BORDER} font-weight: bold;">Category</th>',
        f'<th style="{_CELL_BORDER} font-weight: bold;">Latitude, Longitude, Brightness, FRP</th>',
        "</tr>",
    ]
    for category, values in rows:
        lines += [
            f'<tr style="{_CELL_BORDER}">',
            f'<td style="{_CELL_BORDER}">{category}</td>',
            f'<td style="{_CELL_BORDER}">{values}</td>',
            "</tr>",
        ]
    # Trailing empty entry yields the final newline after </table>.
    lines += ["</table>", ""]
    return "\n".join(lines)


table_html = _build_sample_table(_SAMPLE_ROWS)

with col1:
    # Raw HTML is required here because the table relies on inline styles.
    st.markdown(table_html, unsafe_allow_html=True)
# Decorative pine tree rendered with plain HTML/CSS. The value is assembled
# from two adjacent literals (markup first, stylesheet second) that together
# form one string.
tree = (
    # Markup: two stacked triangles for the foliage and a trunk container.
    """
<div class="pine-tree" style="width: 50%; margin: 0 auto;">
<div class="tree-top"></div>
<div class="tree-top2"></div>
<div class="tree-bottom">
<div class="trunk"></div>
</div>
</div>
"""
    # Stylesheet: sizes are expressed in vw units; the triangles are drawn
    # with transparent side borders and a solid bottom border.
    """<style>
.pine-tree {
width: 15vw;
height: 20vw;
position: relative;
display: flex;
justify-content: center;
align-items: center;
}
.tree-top {
width: 0;
height: 0;
border-left: 8vw solid transparent;
border-right: 8vw solid transparent;
border-bottom: 13vw solid green;
position: absolute;
top: 0;
left: 0;
right: 0;
margin: auto;
}
.tree-top2 {
width: 0;
height: 0;
border-left: 8vw solid transparent;
border-right: 8vw solid transparent;
border-bottom: 13vw solid green;
position: absolute;
top: 3vw;
left: 0;
right: 0;
margin: auto;
}
.tree-bottom {
width: 8vw;
height: 10vw;
background-color: brown;
position: absolute;
bottom: 0;
left: 0;
right: 0;
top: 21vw;
margin: auto;
}
.trunk {
width: 3vw;
height: 10vw;
background-color: brown;
position: absolute;
bottom: 0;
left: 0;
right: 0;
margin: auto;
}
</style>
"""
)
with col2:
    def load_model(show_spinner=True):
        """Load the FireWatch sequence-classification model.

        Args:
            show_spinner: Accepted for compatibility; this function does not
                read it.

        Returns:
            The ``BertForSequenceClassification`` instance created from the
            "NimaKL/FireWatch_tiny_75k" checkpoint.
        """
        # NOTE(review): nothing here caches the result, so each call performs
        # a fresh from_pretrained(); consider a Streamlit resource cache if
        # the deployed Streamlit version offers one.
        return BertForSequenceClassification.from_pretrained("NimaKL/FireWatch_tiny_75k")
# Module-level accumulators; two distinct empty lists. Nothing in the visible
# part of this script appends to or reads from them.
token_id, attention_masks = [], []
def preprocessing(input_text, tokenizer, max_length=16):
    """Encode one input string into fixed-length model inputs.

    Args:
        input_text: The text to encode.
        tokenizer: Object exposing a Hugging Face style ``encode_plus`` method.
        max_length: Length every encoding is padded or truncated to.
            Defaults to 16, the value this script has always used.

    Returns:
        The value returned by ``tokenizer.encode_plus``. With a Hugging Face
        tokenizer this is a ``BatchEncoding`` whose fields include:
          - input_ids: token ids
          - token_type_ids: token type ids
          - attention_mask: 0/1 flags marking which tokens the model should
            attend to (requested via ``return_attention_mask=True``)
    """
    return tokenizer.encode_plus(
        input_text,
        add_special_tokens=True,
        max_length=max_length,
        # Explicit strategies replace the deprecated ``pad_to_max_length=True``
        # and the implicit truncation that ``max_length`` alone used to
        # trigger with a warning.
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
def predict(new_sentence):
    """Classify one input string as a likely or unlikely fire.

    Uses the module-level ``model`` and ``tokenizer`` together with
    ``preprocessing``.

    Args:
        new_sentence: Text to classify. The UI asks for
            "Latitude, Longitude, Brightness, FRP".

    Returns:
        ``'Predicted Class: Likely'`` when the highest-scoring logit is class
        index 1, otherwise ``'Predicted Class: Unlikely'``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # load_model() never moves the model, so it must be placed on the same
    # device as the inputs; otherwise the forward pass fails on CUDA hosts.
    model.to(device)

    # preprocessing() requests PyTorch tensors, so the encoded fields can be
    # sent to the device directly (presumably shaped (1, max_length)).
    encoding = preprocessing(new_sentence, tokenizer)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        output = model(input_ids, token_type_ids=None, attention_mask=attention_mask)

    # Single input, so the argmax over the class axis is one scalar index.
    predicted_label = output.logits.argmax(dim=-1).item()
    prediction = 'Likely' if predicted_label == 1 else 'Unlikely'
    return 'Predicted Class: ' + prediction
# Classifier and tokenizer used by predict().
model = load_model()
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

with col2:
    st.markdown('## Enter Prediction Data in Correct Format "Latitude, Longitude, Brightness, FRP"')
    text = st.text_input('Prediction Data: ', 'Example: 8.81064, -65.07661, 328.04, 18.76')
    aButton = st.button('Predict')

    # The pre-filled value starts with an "Example:" label. Strip it so only
    # the comma-separated numbers are passed to the model.
    query = text.strip()
    if query.startswith('Example:'):
        query = query[len('Example:'):].strip()

    if query:
        with st.spinner('Wait for it...'):
            st.success(predict(query))
    elif aButton:
        # Button pressed with nothing to classify: say so instead of running
        # the model on an empty string.
        st.warning('Enter "Latitude, Longitude, Brightness, FRP" values before predicting.')

    # NOTE(review): the original indentation was lost; the tree is assumed to
    # be drawn inside col2, below the prediction result — confirm.
    st.markdown(tree, unsafe_allow_html=True)