Anirudha Soni committed
Commit · 597fb81
Parent(s): fbc59bf
Basic changes

Files changed:
- app.py +36 -22
- preprocess.py +19 -36
- retriever.py +13 -50
app.py
CHANGED

@@ -11,24 +11,28 @@ app.layout = dbc.Container(
         html.H1("Toolkit Document Search", className="mb-4"),
         dbc.Row([
             dbc.Col([
-                dcc.Input(
-
-
+                dcc.Input(
+                    id="query-input",
+                    type="text",
+                    placeholder="Type your question...",
+                    debounce=True,
+                    style={"width": "100%", "padding": "10px"},
+                ),
                 html.Br(), html.Br(),
-                dbc.Button("Search", id="search-btn", color="primary")
-            ], md=8)
+                dbc.Button("Search", id="search-btn", color="primary"),
+            ], md=8),
         ]),
         html.Hr(),
-        html.Div(id="results-area")
+        html.Div(id="results-area"),
     ],
-    fluid=True
+    fluid=True,
 )

 @app.callback(
     Output("results-area", "children"),
     Input("search-btn", "n_clicks"),
     State("query-input", "value"),
-    prevent_initial_call=True
+    prevent_initial_call=True,
 )
 def search_callback(_, query):
     if not query:

@@ -42,7 +46,11 @@ def search_callback(_, query):
     for idx, r in enumerate(results):
         meta = r["metadata"]
         title = meta.get("name") or f"Document {meta.get('document_id')}"
-        subtitle =
+        subtitle = (
+            f"Created: {meta.get('create_date') or 'N/A'} | "
+            f"Published: {meta.get('publish_date') or 'N/A'} | "
+            f"Categories: {', '.join(meta.get('categories') or [])}"
+        )
         cards.append(
             dbc.Card(
                 [

@@ -55,26 +63,32 @@ def search_callback(_, query):
                             id={"type": "collapse-btn", "index": idx},
                             color="link",
                             n_clicks=0,
-                            style={"float": "right"}
+                            style={"float": "right"},
+                        ),
+                        html.Div(
+                            f"Score: {r['score']:.3f}",
+                            style={"float": "right", "marginRight": "1em"},
                         ),
-                        html.Div(f"Score: {r['score']:.3f}", style={"float": "right", "marginRight": "1em"})
                     ],
-                    style={"display": "flex", "flexDirection": "column"}
-                ),
-                dbc.CardBody(
-                    html.P(r["snippet"], style={"fontStyle": "italic"})
-                ),
+                    style={"display": "flex", "flexDirection": "column"},
+                ),
+                dbc.CardBody(html.P(r["snippet"], style={"fontStyle": "italic"})),
                 dbc.Collapse(
                     dbc.CardBody(
-                        html.Pre(
-
-
+                        html.Pre(
+                            r["full_text"],
+                            style={
+                                "whiteSpace": "pre-wrap",
+                                "maxHeight": "300px",
+                                "overflowY": "auto",
+                            },
+                        )
                     ),
                     id={"type": "collapse", "index": idx},
-                    is_open=False
+                    is_open=False,
                 ),
             ],
-            className="mb-3"
+            className="mb-3",
         )
     )
     return cards

@@ -82,7 +96,7 @@ def search_callback(_, query):
 @app.callback(
     Output({"type": "collapse", "index": MATCH}, "is_open"),
     Input({"type": "collapse-btn", "index": MATCH}, "n_clicks"),
-    State({"type": "collapse", "index": MATCH}, "is_open")
+    State({"type": "collapse", "index": MATCH}, "is_open"),
 )
 def toggle_collapse(n, is_open):
     if n:

@@ -92,4 +106,4 @@ def toggle_collapse(n, is_open):
 if __name__ == "__main__":
     import os
     port = int(os.environ.get("PORT", 7860))
-    app.run_server(host="0.0.0.0", port=port, debug=False)
+    app.run_server(host="0.0.0.0", port=port, debug=False)
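Note: the diff starts at line 11 of app.py, so the file's imports and app construction are not shown. A minimal sketch of the header these hunks appear to assume (the dbc.* and dcc.* calls imply dash and dash-bootstrap-components; the original lines may differ):

import dash
import dash_bootstrap_components as dbc
from dash import dcc, html, Input, Output, State, MATCH

# Assumed setup, not part of this commit's diff.
app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])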
preprocess.py
CHANGED

@@ -1,60 +1,43 @@
-import json
-import
-from sentence_transformers import SentenceTransformer
+import json
+import re
 from pathlib import Path

 DATA_DIR = Path("data")

 def load_json(filename):
-
+    """Load a JSON file and return list of records."""
+    with open(DATA_DIR / filename, "r", encoding="utf-8") as f:
         data = json.load(f)
     if isinstance(data, dict) and "results" in data:
         return data["results"]
     return data if isinstance(data, list) else []

 def extract_text(item):
+    """Extract textual fields from a JSON record."""
     texts = []
-
-
-
-
-
-
-
-        texts.append(v)
+    for k in ("text", "description", "body", "content", "name"):
+        if k in item and item[k]:
+            texts.append(str(item[k]))
+    if "content_json" in item and isinstance(item["content_json"], dict):
+        for v in item["content_json"].values():
+            if isinstance(v, str) and v.strip():
+                texts.append(v)
     return texts

 def chunk_text(text, max_words=80):
+    """Split long text into smaller chunks."""
     sentences = re.split(r'(?<=[.!?]) +', text)
     chunks, cur, count = [], [], 0
     for s in sentences:
         words = s.split()
-        if len(words) < 5:
+        if len(words) < 5:
+            continue
         if count + len(words) > max_words and cur:
             chunks.append(" ".join(cur))
             cur, count = [s], len(words)
         else:
-            cur.append(s)
-
+            cur.append(s)
+            count += len(words)
+    if cur:
+        chunks.append(" ".join(cur))
     return chunks
-
-print("🔄 Loading JSON...")
-content = load_json("Toolkit_Content_results.json")
-resources = load_json("Toolkit_Resources_results.json")
-
-docs = []
-for item in content + resources:
-    for t in extract_text(item):
-        docs.extend(chunk_text(t))
-print(f"✅ Loaded {len(docs)} chunks")
-
-print("🔄 Encoding with SentenceTransformer...")
-model = SentenceTransformer("all-MiniLM-L6-v2")
-embeddings = model.encode(docs, convert_to_numpy=True, normalize_embeddings=True)
-
-# Save
-print("💾 Saving artifacts...")
-np.save(DATA_DIR/"embeddings.npy", embeddings)
-with open(DATA_DIR/"docs.pkl", "wb") as f:
-    pickle.dump(docs, f)
-print("✅ Done!")
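Note: a quick way to see what the reworked chunk_text does. Sentences are split on terminal punctuation, sentences under five words are dropped, the running word count now advances on the else branch, and any partial chunk is flushed at the end. The sample text below is made up for illustration, not taken from the toolkit JSON:

from preprocess import chunk_text

# Made-up input; four sentences of 10, 14, 2, and 11 words.
text = (
    "Retrieval works better when documents are split into small chunks. "
    "Each chunk should stay within the word budget so one embedding covers one idea. "
    "Too short. "
    "Sentences under five words, like the one above, are dropped entirely."
)
for i, chunk in enumerate(chunk_text(text, max_words=20)):
    print(i, len(chunk.split()), chunk)
# Prints three chunks of 10, 14, and 11 words; "Too short." never appears.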
retriever.py
CHANGED

@@ -4,6 +4,9 @@ import numpy as np
 from sentence_transformers import SentenceTransformer
 from collections import defaultdict

+from preprocess import load_json, extract_text, chunk_text
+
+# --------- Paths ---------
 DATA_DIR = Path("data")
 DATA_DIR.mkdir(exist_ok=True)
 EMB_FILE = DATA_DIR / "embeddings.npy"

@@ -13,75 +16,35 @@ META_FILE = DATA_DIR / "metadata.pkl"  # metadata for each doc
 CONTENT_FILE = DATA_DIR / "Toolkit_Content_results.json"
 RESOURCES_FILE = DATA_DIR / "Toolkit_Resources_results.json"

+# Embedding model
 model = SentenceTransformer("all-MiniLM-L6-v2")

-# ---------
-def _load_json(filename):
-    with open(filename, "r", encoding="utf-8") as f:
-        data = json.load(f)
-    if isinstance(data, dict) and "results" in data:
-        return data["results"]
-    return data if isinstance(data, list) else []
-
-def _extract_text(item):
-    texts = []
-    for k in ("text", "description", "body", "content", "name"):
-        if k in item and item[k]:
-            texts.append(str(item[k]))
-    if "content_json" in item and isinstance(item["content_json"], dict):
-        for v in item["content_json"].values():
-            if isinstance(v, str) and v.strip():
-                texts.append(v)
-    return texts
-
-def _chunk_text(text, max_words=80):
-    sentences = re.split(r'(?<=[.!?]) +', text)
-    chunks, cur, count = [], [], 0
-    for s in sentences:
-        words = s.split()
-        if len(words) < 5:
-            continue
-        if count + len(words) > max_words and cur:
-            chunks.append(" ".join(cur))
-            cur, count = [s], len(words)
-        else:
-            cur.append(s)
-            count += len(words)
-    if cur:
-        chunks.append(" ".join(cur))
-    return chunks
-
+# --------- Build Index ---------
 def _build_index():
     print("🔄 Building index...")
-    content =
-    resources =
+    content = load_json("Toolkit_Content_results.json")
+    resources = load_json("Toolkit_Resources_results.json")

-    chunks = []
-    chunk_to_doc_idx = []
-    documents = []
-    metadata = []  # will store dict with name, id, dates
+    chunks, chunk_to_doc_idx, documents, metadata = [], [], [], []

     for item in content + resources:
-
-        full_text = "\n".join(_extract_text(item))
+        full_text = "\n".join(extract_text(item))
         if not full_text.strip():
             continue

         doc_idx = len(documents)
         documents.append(full_text)

-        # --- Metadata ---
         meta = {
             "document_id": item.get("document_id"),
             "name": item.get("name"),
             "create_date": item.get("create_date"),
             "publish_date": item.get("publish_date"),
-            "categories": item.get("categories")
+            "categories": item.get("categories"),
         }
         metadata.append(meta)

-
-        for ch in _chunk_text(full_text):
+        for ch in chunk_text(full_text):
             chunks.append(ch)
             chunk_to_doc_idx.append(doc_idx)

@@ -110,14 +73,14 @@ def _load_or_build():
         metadata = pickle.load(f)
     return chunks, chunk_to_doc_idx, documents, metadata, embeddings

+# Load on import
 chunks, chunk_to_doc_idx, documents, metadata, embeddings = _load_or_build()

-# ---------
+# --------- Retrieval ---------
 def retrieve(query, top_k=5):
     q_emb = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
     scores = (embeddings @ q_emb.T).squeeze()

-    # Aggregate: pick the max scoring chunk per document
     doc_best = defaultdict(lambda: (-np.inf, None))  # (score, best_snippet)
     for idx, sc in enumerate(scores):
         doc_id = chunk_to_doc_idx[idx]