Spaces:
Runtime error
Runtime error
| import time | |
| import streamlit as st | |
| from annotated_text import annotated_text | |
| from flair.data import Sentence | |
| from flair.models import SequenceTagger | |
# Model checkpoints offered in the "Choose model" selectbox.
checkpoints = ["flair/pos-english"]

# Highlight color for each POS tag, used when rendering annotated output.
colors = {
    "ADD": "#b9d9a6",
    "AFX": "#eddc92",
    "CC": "#95e9d7",
    "CD": "#e797db",
    "DT": "#9ff48b",
    "EX": "#ed92b4",
    "FW": "#decfa1",
    "HYPH": "#ada7d7",
    "IN": "#85fad8",
    "JJ": "#8ba4f4",
    "JJR": "#e7a498",
    "JJS": "#e5c79a",
    "LS": "#eb94b6",
    "MD": "#e698ae",
    "NFP": "#d9d1a6",
    "NN": "#96e89f",
    "NNP": "#e698c6",
    "NNPS": "#ddbfa2",
    "NNS": "#f788cd",
    "PDT": "#f19c8d",
    "POS": "#8ed5f0",
    "PRP": "#c4d8a6",
    "PRP$": "#e39bdc",
    "RB": "#8df1e2",
    "RBR": "#d7f787",
    "RBS": "#f986f0",
    "RP": "#878df8",
    "SYM": "#83fe80",
    "TO": "#a6d8c9",
    "UH": "#d9a6cc",
    "VB": "#a1deda",
    "VBD": "#8fefe1",
    "VBG": "#e3c79b",
    "VBN": "#fb81fe",
    "VBP": "#d5fe81",
    "VBZ": "#8084ff",
    "WDT": "#dd80fe",
    "WP": "#9ce3e3",
    "WP$": "#9fbddf",
    "WRB": "#dea1b5",
    "XX": "#93b8ec",
}
@st.cache_resource
def get_model(model_name):
    """Load and return the flair ``SequenceTagger`` identified by *model_name*.

    Streamlit re-executes the whole script on every widget interaction, and
    ``main`` calls this function on each of those reruns. Without caching, the
    tagger would be deserialized again every time. ``st.cache_resource`` keeps
    one shared instance per ``model_name`` for the life of the server process.
    NOTE(review): ``st.cache_resource`` requires Streamlit >= 1.18; confirm the
    deployed version.
    """
    return SequenceTagger.load(model_name)
def getPos(s: Sentence):
    """Flatten a tagged flair ``Sentence`` into parallel token/tag lists.

    For every token, one entry is produced per annotation layer the token
    carries. The entry pairs the token's text with the ``value`` of the first
    label in that layer. Tokens without any annotation layer contribute nothing.

    Returns:
        ``(texts, labels)``: two lists of equal length.
    """
    pairs = [
        (token.text, token.get_labels(layer)[0].value)
        for token in s.tokens
        for layer in token.annotation_layers.keys()
    ]
    texts = [word for word, _ in pairs]
    labels = [tag for _, tag in pairs]
    return texts, labels
def getDictFromPOS(texts, labels):
    """Pair tokens with their tags as JSON-friendly records.

    Args:
        texts: token strings.
        labels: tag strings aligned with *texts*. Surplus items in the longer
            sequence are ignored (``zip`` semantics).

    Returns:
        A list of ``{"text": ..., "label": ...}`` dicts, one per pair.
    """
    return [dict(text=word, label=tag) for word, tag in zip(texts, labels)]
def getAnnotatedFromPOS(texts, labels, *, palette=None, default_color="#dddddd"):
    """Build ``(text, label, color)`` tuples for ``annotated_text``.

    Args:
        texts: token strings.
        labels: tag strings aligned with *texts*. Surplus items in the longer
            sequence are ignored (``zip`` semantics).
        palette: mapping from tag to CSS color. Defaults to the module-level
            ``colors`` table.
        default_color: color used for any tag that has no entry in *palette*.

    Returns:
        One ``(text, label, color)`` tuple per token/tag pair.
    """
    if palette is None:
        palette = colors
    # The palette does not list every tag the tagger can emit (it has no entry
    # for punctuation tags such as "."). Indexing with palette[tag] therefore
    # raised KeyError; fall back to a neutral color instead.
    return [
        (word, tag, palette.get(tag, default_color))
        for word, tag in zip(texts, labels)
    ]
def main():
    """Render the part-of-speech tagging page.

    Streamlit re-executes this function on every widget interaction. Tagging
    only runs on the rerun triggered by the "Submit" button. Any failure while
    building the sentence, predicting, or rendering is shown as an error
    message, and the rest of the page is stopped.
    """
    st.title("Part of Speech Categorizer")
    st.write("Paste or type text, submit and the machine will attempt to identify parts of speech.")
    checkpoint = st.selectbox("Choose model", checkpoints)
    model = get_model(checkpoint)
    default_text = "This is an example sentence."
    input_text = st.text_area(
        label="Original text",
        value=default_text,
    )
    elapsed = None
    if st.button("Submit"):
        if not input_text.strip():
            # Nothing to tag; skip the model call entirely.
            st.warning("Please enter some text to analyze.")
        else:
            # Prediction is inside the try so model/tokenizer failures reach
            # the same friendly error path as rendering failures.
            try:
                start = time.perf_counter()
                with st.spinner("Computing"):
                    s = Sentence(input_text)
                    model.predict(s)
                # Time only the prediction, matching the "prediction took" label.
                elapsed = time.perf_counter() - start
                texts, labels = getPos(s)
                st.header("Labels:")
                annotated_text(*getAnnotatedFromPOS(texts, labels))
                st.header("JSON:")
                st.json(getDictFromPOS(texts, labels))
            except Exception as e:  # top-level UI boundary: report, then halt the page
                st.error(f"Some error occurred: {e}")
                st.stop()
    st.write("---")
    if elapsed is not None:
        st.text(f"prediction took {elapsed:.2f}s")
# Script entry point: render the page when this file is executed directly.
if __name__ == '__main__':
    main()