File size: 8,133 Bytes
ce390aa
 
 
a60db42
ce390aa
 
 
 
 
 
 
494c66d
 
 
 
 
 
 
 
 
 
fe8e1e4
ce390aa
 
 
 
475ca00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ce390aa
 
 
9217007
 
 
 
 
fe8e1e4
9217007
475ca00
ce390aa
 
 
 
 
a384043
ce390aa
 
 
 
 
 
 
fe8e1e4
ce390aa
 
 
 
6df7924
fe8e1e4
 
 
 
ce390aa
494c66d
ce390aa
 
 
fe8e1e4
ce390aa
475ca00
 
 
 
fe8e1e4
ce390aa
 
fe8e1e4
 
 
 
 
 
 
 
 
ce390aa
 
 
 
494c66d
d8e2c52
494c66d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475ca00
 
494c66d
 
ce390aa
494c66d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8e2c52
494c66d
 
 
 
 
 
 
 
 
 
 
 
 
 
ce390aa
494c66d
 
 
ce390aa
494c66d
 
 
 
 
 
 
 
 
 
0ed8010
44d845f
8d88da3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import os
os.environ['HF_HOME'] = '/tmp'
import time
import streamlit as st
import pandas as pd
import io
import plotly.express as px
import hashlib
from gliner import GLiNER
from streamlit_extras.stylable_container import stylable_container
from comet_ml import Experiment
# Deterministic label -> colour mapping used when colouring the treemap.
def get_stable_color(s):
    """
    Derive a hex colour string from the SHA-256 digest of a string.

    The same input always produces the same colour, so a given label keeps
    its colour in the treemap from one run to the next.
    """
    encoded = s.encode('utf-8')
    digest = hashlib.sha256(encoded).hexdigest()
    # The leading six hex digits are used as the RRGGBB part of the colour.
    rgb_part = digest[:6]
    return '#' + rgb_part
# --- Page Configuration and UI Elements
# Page-wide settings (wide layout, page title), followed by the app heading
# and a link button pointing at the author's site.
st.set_page_config(layout="wide", page_title="Named Entity Recognition App")
st.subheader("InfoFinder", divider="violet")
st.link_button("by nlpblogs", "https://nlpblogs.com", type="tertiary")
# Expander that holds the usage instructions, limits, supported language and
# contact details written below.
expander = st.expander("**Important notes**")
expander.write("""**How to Use:** 

1. Type or paste your text into the text area below, then press Ctrl + Enter. 
2. Click the 'Add Question' button to add your question to the Record of Questions. You can manage your questions by deleting them one by one.
3. Click the 'Extract Answers' button to extract the answer to your question. 

Results are presented in an easy-to-read table, visualized in an interactive tree map and are available for download.

**Usage Limits:** You can request results unlimited times for one (1) month.

**Supported Languages:** English 

**Technical issues:** If your connection times out, please refresh the page or reopen the app's URL. 

For any errors or inquiries, please contact us at info@nlpblogs.com""")

# Sidebar: an embeddable iframe snippet for the app, then a promotional link.
with st.sidebar:
    st.write("Use the following code to embed the InfoFinder web app on your website. Feel free to adjust the width and height values to fit your page.")
    # The snippet is only displayed (as HTML source) via st.code below.
    code = '''
    <iframe
	src="https://aiecosystem-infofinder.hf.space"
	frameborder="0"
	width="850"
	height="450"
    ></iframe>

    '''
    st.code(code, language="html")
    # Two empty text elements, then a divider, before the promo section.
    st.text("")
    st.text("")
    st.divider()
    st.subheader("🚀 Ready to build your own AI Web App?", divider="violet")
    st.link_button("AI Web App Builder", "https://nlpblogs.com/build-your-named-entity-recognition-app/", type="primary")
# --- Comet ML Setup ---
# Experiment logging is switched on only when all three settings are present
# (and non-empty) in the environment.
COMET_API_KEY = os.getenv("COMET_API_KEY")
COMET_WORKSPACE = os.getenv("COMET_WORKSPACE")
COMET_PROJECT_NAME = os.getenv("COMET_PROJECT_NAME")
comet_initialized = all((COMET_API_KEY, COMET_WORKSPACE, COMET_PROJECT_NAME))
if not comet_initialized:
    st.warning("Comet ML not initialized. Check environment variables.")
# --- Session state: the list of questions the user has added so far
if 'user_labels' not in st.session_state:
    st.session_state['user_labels'] = []
# --- Model Loading and Caching ---
@st.cache_resource
def load_gliner_model():
    """
    Load the GLiNER model on CPU and return it.

    The function is wrapped in ``st.cache_resource``. If loading fails, the
    error is shown in the app and the script run is stopped.
    """
    model_name = "knowledgator/gliner-multitask-large-v0.5"
    try:
        gliner = GLiNER.from_pretrained(model_name, device="cpu")
    except Exception as load_error:
        st.error(f"Error loading the GLiNER model: {load_error}")
        st.stop()
    else:
        return gliner
# Load the model
model = load_gliner_model()
# Maximum number of whitespace-separated words accepted for analysis; it is
# shown in the text-area label and the counter here, and checked again when
# "Extract Answers" is pressed.
word_limit = 200
user_text = st.text_area(f"Type or paste your text below (max {word_limit} words), and then press Ctrl + Enter", height=250, key='my_text_area')
# Words are counted by splitting on whitespace.
word_count = len(user_text.split())
st.markdown(f"**Word count:** {word_count}/{word_limit}")
def clear_text():
    """Blank the text area by resetting its 'my_text_area' session-state entry."""
    st.session_state.my_text_area = ""
# Button that empties the text area through the clear_text callback.
st.button("Clear text", on_click=clear_text)
st.subheader("Question-Answering", divider="violet")
# Surrounding whitespace is stripped so that blank input is rejected and the
# same question typed with stray spaces is recognised as a duplicate.
question_input = st.text_input("Ask wh-questions. **Wh-questions begin with what, when, where, who, whom, which, whose, why and how. We use them to ask for specific information.**").strip()
if st.button("Add Question"):
    if not question_input:
        st.warning("Please enter a question.")
    elif question_input in st.session_state.user_labels:
        st.warning("This question has already been added.")
    else:
        st.session_state.user_labels.append(question_input)
        st.success(f"Added question: {question_input}")
st.markdown("---")
st.subheader("Record of Questions", divider="violet")
if st.session_state.user_labels:
    # One row per stored question: the text on the left, a delete button on
    # the right.
    for i, label in enumerate(st.session_state.user_labels):
        col_list, col_delete = st.columns([0.9, 0.1])
        with col_list:
            # st.write takes no widget key; only the button below needs one.
            st.write(f"- {label}")
        with col_delete:
            if st.button("Delete", key=f"delete_{i}"):
                st.session_state.user_labels.pop(i)
                # Rerun straight away so the list is redrawn without the
                # removed entry (and the loop does not continue over the
                # mutated list).
                st.rerun()
else:
    st.info("No questions defined yet. Use the input above to add one.")
st.divider()
# Runs the model when "Extract Answers" is pressed: validates the input,
# predicts answers for the stored questions, renders a table, a treemap and a
# CSV download, and (when configured) logs the run to Comet ML.
if st.button("Extract Answers"):
    if not user_text.strip():
        st.warning("Please enter some text to analyze.")
    elif word_count > word_limit:
        st.warning(f"Your text exceeds the {word_limit} word limit. Please shorten it to continue.")
    elif not st.session_state.user_labels:
        st.warning("Please define at least one question.")
    else:
        # Stays None when Comet is not configured, so every logging step
        # below can test a single name.
        experiment = None
        if comet_initialized:
            experiment = Experiment(api_key=COMET_API_KEY, workspace=COMET_WORKSPACE, project_name=COMET_PROJECT_NAME)
            experiment.log_parameter("input_text_length", len(user_text))
            experiment.log_parameter("defined_labels", st.session_state.user_labels)
        start_time = time.time()
        with st.spinner("Analyzing text...", show_time=True):
            try:
                # The stored questions are passed to the model as its labels.
                entities = model.predict_entities(user_text, st.session_state.user_labels)
                elapsed_time = time.time() - start_time
                st.info(f"Processing took **{elapsed_time:.2f} seconds**.")
                if entities:
                    # Keep the label/text/score fields and present them as
                    # question/answer/score.
                    df = pd.DataFrame(entities)[['label', 'text', 'score']].rename(
                        columns={'label': 'question', 'text': 'answer'}
                    )
                    st.subheader("Extracted Answers", divider="violet")
                    st.dataframe(df, use_container_width=True)
                    st.subheader("Tree map", divider="green")
                    # One fixed colour per question so the treemap colours do
                    # not change between runs.
                    label_color_map = {label: get_stable_color(label) for label in df['question'].unique()}
                    fig_treemap = px.treemap(df, path=[px.Constant("all"), 'question', 'answer'], values='score', color='question', color_discrete_map=label_color_map)
                    fig_treemap.update_layout(margin=dict(t=50, l=25, r=25, b=25), paper_bgcolor='#F3E5F5', plot_bgcolor='#F3E5F5')
                    st.plotly_chart(fig_treemap)
                    csv_data = df.to_csv(index=False).encode('utf-8')
                    st.download_button(
                        label="Download CSV",
                        data=csv_data,
                        file_name="nlpblogs_questions_answers.csv",
                        mime="text/csv",
                    )
                    if experiment is not None:
                        experiment.log_metric("processing_time_seconds", elapsed_time)
                        # Comet's log_table documents a filename with a
                        # tabular extension (e.g. ".csv") when data is passed.
                        experiment.log_table("predicted_entities.csv", df)
                        experiment.log_figure(figure=fig_treemap, figure_name="entity_treemap")
                else:
                    st.info("No answers were found in the text with the defined questions.")
            except Exception as e:
                st.error(f"An error occurred during processing: {e}")
                st.write(f"Error details: {e}")
                if experiment is not None:
                    experiment.log_text(f"Error: {e}")
            finally:
                # Close the experiment exactly once on every path, including
                # when the run is interrupted before the branches above finish.
                if experiment is not None:
                    experiment.end()