import subprocess
import sys
import tempfile
from collections import Counter
from typing import List

import srsly
import streamlit as st
import textract

import spacy
import scispacy  # noqa: F401 -- kept so the scispacy dependency stays explicit
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc, Span
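
# Run with `streamlit run app.py` (filename assumed). The app expects three
# asset files next to the script: style.css (custom styles), models.json
# (language -> model names), and scispacy.json (model name -> install URL).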

# Inject the custom stylesheet used by the result cards below.
with open("style.css") as f:
    st.markdown("<style>" + f.read() + "</style>", unsafe_allow_html=True)

st.title('Index and Search a Collection of Documents')

# Keep the query in session state so the sidebar buttons can update it later.
if 'query' not in st.session_state:
    st.session_state['query'] = ''


@st.cache
def download_model(language: str, select_model: str) -> bool:
    """Install the selected model; cached so the download runs once per model."""
    if language == 'Science':
        # scispacy.json maps scispacy model names to pip-installable archive URLs.
        urls = srsly.read_json('scispacy.json')
        # Install into the same environment that runs the app.
        result = subprocess.run([sys.executable, '-m', 'pip', 'install', urls[select_model]])
        return result.returncode == 0
    try:
        spacy.cli.download(select_model)
        return True
    except Exception:
        return False


def search_docs(query: str, documents: List[Doc], nlp) -> List[Span]:
    """Return a Span for every occurrence of the query terms in the documents."""
    # The query string uses '|' to separate alternative search terms.
    terms = query.split('|')
    patterns = [nlp.make_doc(text) for text in terms]
    matcher = PhraseMatcher(nlp.vocab)
    matcher.add(query, patterns)

    results = []
    for doc in documents:
        for _, start, end in matcher(doc):
            results.append(doc[start:end])

    return results
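
# A minimal usage sketch (hypothetical terms):
#   search_docs("malaria|dengue", documents, nlp)
# yields one Span per occurrence of either term.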


def update_query(arg: str):
    """Button callback: set the shared query to the clicked entity text."""
    st.session_state.query = arg
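

# models.json is assumed to map each language to its model names, e.g.
# {"English": ["en_core_web_sm"], "Science": ["en_core_sci_sm"]}.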
models = srsly.read_json('models.json')
models[''] = []  # blank entry so the selectbox can start with no language chosen
languages = models.keys()
language = st.selectbox("Language", languages, index=len(models) - 1,
                        help="Select the language of your materials.")
model_downloaded = False  # avoids a NameError below when no language is chosen
if language:
    select_model = st.selectbox("Model", models[language], help="spaCy model")
    if select_model:
        model_downloaded = download_model(language, select_model)


if model_downloaded:
    nlp = spacy.load(select_model)
    # Allow long documents (spaCy's default max_length is 1,000,000 characters).
    nlp.max_length = 1_200_000

    uploaded_files = st.file_uploader("Select files to process", accept_multiple_files=True)
    # Bind the text input to st.session_state['query'] via key= so the sidebar
    # buttons (update_query) and manual edits write to the same value, instead
    # of the widget's return value overwriting the callback's change each rerun.
    st.sidebar.text_input(label="Enter your query (use | to separate search terms)",
                          key="query")
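
    # Note: textract shells out to system tools for some formats (e.g. pdftotext
    # for PDFs, antiword for .doc), so those must be available on the host.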
    documents = []
    all_ents = []
    for uploaded_file in uploaded_files:
        # Write the upload to a temporary file so textract can read it from
        # disk; keeping the suffix lets textract pick the right parser.
        file_suffix = '.' + uploaded_file.name.split('.')[-1]
        temp = tempfile.NamedTemporaryFile(suffix=file_suffix)
        temp.write(uploaded_file.getvalue())
        temp.flush()  # make sure the bytes are on disk before textract reads them
        try:
            text = textract.process(temp.name).decode('utf-8')
            doc = nlp(text)
            doc.user_data['filename'] = uploaded_file.name
            documents.append(doc)
            all_ents.extend(doc.ents)
        except Exception as e:
            st.error(e)

    # Sidebar facets: one button per entity label; clicking it lists that
    # label's entity texts as buttons that set the query via update_query.
    label_freq = Counter([ent.label_ for ent in all_ents])
    for key, value in label_freq.items():
        if st.sidebar.button(key, key=key):
            st.sidebar.write(value)
            text_freq = Counter([ent.text for ent in all_ents if ent.label_ == key])
            for text in text_freq:
                st.sidebar.button(f'{text} ({text_freq[text]})',
                                  on_click=update_query, args=(text,))

    results_container = st.container()
    results = search_docs(st.session_state.query, documents, nlp)
    for result in results:
        # Show each hit inside its sentence, with the matched span highlighted.
        doc = result.doc
        sent_before = doc[result.sent.start:result.start]
        sent_after = doc[result.end:result.sent.end]
        results_container.markdown(f"""
<div style="border: 2px solid #202d89;border-radius: 15px;"><p>{doc.user_data['filename']}</p>
<div class='text'>{sent_before.text} <span class="text_mark">{result.text}</span> {sent_after.text}</div>
</div>
""", unsafe_allow_html=True)