Spaces:
Runtime error
Runtime error
ugmSorcero
commited on
Commit
·
6c3736e
1
Parent(s):
e4aa90a
final touches to draw pipelines & manual cache
Browse files- core/pipelines.py +0 -10
- interface/components.py +10 -4
- interface/config.py +3 -1
- interface/draw_pipelines.py +31 -9
- interface/pages.py +4 -4
core/pipelines.py
CHANGED
|
@@ -2,15 +2,12 @@
|
|
| 2 |
Haystack Pipelines
|
| 3 |
"""
|
| 4 |
|
| 5 |
-
import tokenizers
|
| 6 |
from haystack import Pipeline
|
| 7 |
from haystack.document_stores import InMemoryDocumentStore
|
| 8 |
from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
| 9 |
from haystack.nodes.preprocessor import PreProcessor
|
| 10 |
-
import streamlit as st
|
| 11 |
|
| 12 |
|
| 13 |
-
@st.cache(allow_output_mutation=True)
|
| 14 |
def keyword_search(
|
| 15 |
index="documents",
|
| 16 |
):
|
|
@@ -42,13 +39,6 @@ def keyword_search(
|
|
| 42 |
return search_pipeline, index_pipeline
|
| 43 |
|
| 44 |
|
| 45 |
-
@st.cache(
|
| 46 |
-
hash_funcs={
|
| 47 |
-
tokenizers.Tokenizer: lambda _: None,
|
| 48 |
-
tokenizers.AddedToken: lambda _: None,
|
| 49 |
-
},
|
| 50 |
-
allow_output_mutation=True,
|
| 51 |
-
)
|
| 52 |
def dense_passage_retrieval(
|
| 53 |
index="documents",
|
| 54 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
|
|
|
| 2 |
Haystack Pipelines
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
from haystack import Pipeline
|
| 6 |
from haystack.document_stores import InMemoryDocumentStore
|
| 7 |
from haystack.nodes.retriever import DensePassageRetriever, TfidfRetriever
|
| 8 |
from haystack.nodes.preprocessor import PreProcessor
|
|
|
|
| 9 |
|
| 10 |
|
|
|
|
| 11 |
def keyword_search(
|
| 12 |
index="documents",
|
| 13 |
):
|
|
|
|
| 39 |
return search_pipeline, index_pipeline
|
| 40 |
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
def dense_passage_retrieval(
|
| 43 |
index="documents",
|
| 44 |
query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
|
interface/components.py
CHANGED
|
@@ -13,10 +13,16 @@ def component_select_pipeline(container):
|
|
| 13 |
if "Keyword Search" in pipeline_names
|
| 14 |
else 0,
|
| 15 |
)
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
def component_show_pipeline(pipeline):
|
|
|
|
| 13 |
if "Keyword Search" in pipeline_names
|
| 14 |
else 0,
|
| 15 |
)
|
| 16 |
+
if st.session_state["pipeline"] is None or st.session_state["pipeline"]["name"] != selected_pipeline:
|
| 17 |
+
(
|
| 18 |
+
search_pipeline,
|
| 19 |
+
index_pipeline,
|
| 20 |
+
) = pipeline_funcs[pipeline_names.index(selected_pipeline)]()
|
| 21 |
+
st.session_state["pipeline"] = {
|
| 22 |
+
'name': selected_pipeline,
|
| 23 |
+
'search_pipeline': search_pipeline,
|
| 24 |
+
'index_pipeline': index_pipeline,
|
| 25 |
+
}
|
| 26 |
|
| 27 |
|
| 28 |
def component_show_pipeline(pipeline):
|
interface/config.py
CHANGED
|
@@ -1,7 +1,9 @@
|
|
| 1 |
from interface.pages import page_landing_page, page_search, page_index
|
| 2 |
|
| 3 |
# Define default Session Variables over the whole session.
|
| 4 |
-
session_state_variables = {
|
|
|
|
|
|
|
| 5 |
|
| 6 |
# Define Pages for the demo
|
| 7 |
pages = {
|
|
|
|
| 1 |
from interface.pages import page_landing_page, page_search, page_index
|
| 2 |
|
| 3 |
# Define default Session Variables over the whole session.
|
| 4 |
+
session_state_variables = {
|
| 5 |
+
"pipeline": None
|
| 6 |
+
}
|
| 7 |
|
| 8 |
# Define Pages for the demo
|
| 9 |
pages = {
|
interface/draw_pipelines.py
CHANGED
|
@@ -3,11 +3,9 @@ from typing import List
|
|
| 3 |
from itertools import chain
|
| 4 |
import networkx as nx
|
| 5 |
import plotly.graph_objs as go
|
| 6 |
-
import streamlit as st
|
| 7 |
import numpy as np
|
| 8 |
|
| 9 |
|
| 10 |
-
@st.cache(allow_output_mutation=True)
|
| 11 |
def get_pipeline_graph(pipeline):
|
| 12 |
# Controls for how the graph is drawn
|
| 13 |
nodeColor = "#ffbf00"
|
|
@@ -16,13 +14,37 @@ def get_pipeline_graph(pipeline):
|
|
| 16 |
lineColor = "#ffffff"
|
| 17 |
|
| 18 |
G = pipeline.graph
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
for node in G.nodes:
|
| 27 |
G.nodes[node]["pos"] = list(pos[node])
|
| 28 |
|
|
|
|
| 3 |
from itertools import chain
|
| 4 |
import networkx as nx
|
| 5 |
import plotly.graph_objs as go
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
|
| 8 |
|
|
|
|
| 9 |
def get_pipeline_graph(pipeline):
|
| 10 |
# Controls for how the graph is drawn
|
| 11 |
nodeColor = "#ffbf00"
|
|
|
|
| 14 |
lineColor = "#ffffff"
|
| 15 |
|
| 16 |
G = pipeline.graph
|
| 17 |
+
current_coordinate = (0, len(set([edge[0] for edge in G.edges()])) + 1)
|
| 18 |
+
# Transform G.edges into {node : all_connected_nodes} format
|
| 19 |
+
node_connections = {}
|
| 20 |
+
for in_node, out_node in G.edges():
|
| 21 |
+
if in_node in node_connections:
|
| 22 |
+
node_connections[in_node].append(out_node)
|
| 23 |
+
else:
|
| 24 |
+
node_connections[in_node] = [out_node]
|
| 25 |
+
# Get node coordinates/pos
|
| 26 |
+
fixed_pos_nodes = {}
|
| 27 |
+
for idx, (in_node, out_nodes) in enumerate(node_connections.items()):
|
| 28 |
+
if in_node not in fixed_pos_nodes:
|
| 29 |
+
fixed_pos_nodes[in_node] = np.array([current_coordinate[0], current_coordinate[1]])
|
| 30 |
+
current_coordinate = (current_coordinate[0], current_coordinate[1] - 1)
|
| 31 |
+
# If more than 1 out node, then branch out in X coordinate
|
| 32 |
+
if len(out_nodes) > 1:
|
| 33 |
+
# if length is odd
|
| 34 |
+
if (len(out_nodes) % 2) != 0:
|
| 35 |
+
middle_node = out_nodes[round(len(out_nodes)/2, 0) - 1]
|
| 36 |
+
fixed_pos_nodes[middle_node] = np.array([current_coordinate[0], current_coordinate[1]])
|
| 37 |
+
out_nodes = [n for n in out_nodes if n != middle_node]
|
| 38 |
+
correction_coordinate = - len(out_nodes) / 2
|
| 39 |
+
for out_node in out_nodes:
|
| 40 |
+
fixed_pos_nodes[out_node] = np.array([int(current_coordinate[0] + correction_coordinate), int(current_coordinate[1])])
|
| 41 |
+
if correction_coordinate == -1:
|
| 42 |
+
correction_coordinate += 1
|
| 43 |
+
correction_coordinate += 1
|
| 44 |
+
current_coordinate = (current_coordinate[0], current_coordinate[1] - 1)
|
| 45 |
+
elif len(node_connections) - 1 == idx:
|
| 46 |
+
fixed_pos_nodes[out_nodes[0]] = np.array([current_coordinate[0], current_coordinate[1]])
|
| 47 |
+
pos = nx.spring_layout(G, pos=fixed_pos_nodes, fixed=G.nodes(), seed=42)
|
| 48 |
for node in G.nodes:
|
| 49 |
G.nodes[node]["pos"] = list(pos[node])
|
| 50 |
|
interface/pages.py
CHANGED
|
@@ -36,12 +36,12 @@ def page_search(container):
|
|
| 36 |
## SEARCH ##
|
| 37 |
query = st.text_input("Query")
|
| 38 |
|
| 39 |
-
component_show_pipeline(st.session_state["search_pipeline"])
|
| 40 |
|
| 41 |
if st.button("Search"):
|
| 42 |
st.session_state["search_results"] = search(
|
| 43 |
queries=[query],
|
| 44 |
-
pipeline=st.session_state["search_pipeline"],
|
| 45 |
)
|
| 46 |
if "search_results" in st.session_state:
|
| 47 |
component_show_search_result(
|
|
@@ -53,7 +53,7 @@ def page_index(container):
|
|
| 53 |
with container:
|
| 54 |
st.title("Index time!")
|
| 55 |
|
| 56 |
-
component_show_pipeline(st.session_state["index_pipeline"])
|
| 57 |
|
| 58 |
input_funcs = {
|
| 59 |
"Raw Text": (component_text_input, "card-text"),
|
|
@@ -74,7 +74,7 @@ def page_index(container):
|
|
| 74 |
if st.button("Index"):
|
| 75 |
index_results = index(
|
| 76 |
corpus,
|
| 77 |
-
st.session_state["index_pipeline"],
|
| 78 |
)
|
| 79 |
if index_results:
|
| 80 |
st.write(index_results)
|
|
|
|
| 36 |
## SEARCH ##
|
| 37 |
query = st.text_input("Query")
|
| 38 |
|
| 39 |
+
component_show_pipeline(st.session_state["pipeline"]["search_pipeline"])
|
| 40 |
|
| 41 |
if st.button("Search"):
|
| 42 |
st.session_state["search_results"] = search(
|
| 43 |
queries=[query],
|
| 44 |
+
pipeline=st.session_state["pipeline"]["search_pipeline"],
|
| 45 |
)
|
| 46 |
if "search_results" in st.session_state:
|
| 47 |
component_show_search_result(
|
|
|
|
| 53 |
with container:
|
| 54 |
st.title("Index time!")
|
| 55 |
|
| 56 |
+
component_show_pipeline(st.session_state["pipeline"]["index_pipeline"])
|
| 57 |
|
| 58 |
input_funcs = {
|
| 59 |
"Raw Text": (component_text_input, "card-text"),
|
|
|
|
| 74 |
if st.button("Index"):
|
| 75 |
index_results = index(
|
| 76 |
corpus,
|
| 77 |
+
st.session_state["pipeline"]["index_pipeline"],
|
| 78 |
)
|
| 79 |
if index_results:
|
| 80 |
st.write(index_results)
|