Anirudha Soni committed
Commit 597fb81 · 1 Parent(s): fbc59bf

Basic changes

Files changed (3)
  1. app.py +36 -22
  2. preprocess.py +19 -36
  3. retriever.py +13 -50
app.py CHANGED
@@ -11,24 +11,28 @@ app.layout = dbc.Container(
         html.H1("Toolkit Document Search", className="mb-4"),
         dbc.Row([
             dbc.Col([
-                dcc.Input(id="query-input", type="text",
-                          placeholder="Type your question...", debounce=True,
-                          style={"width": "100%", "padding": "10px"}),
+                dcc.Input(
+                    id="query-input",
+                    type="text",
+                    placeholder="Type your question...",
+                    debounce=True,
+                    style={"width": "100%", "padding": "10px"},
+                ),
                 html.Br(), html.Br(),
-                dbc.Button("Search", id="search-btn", color="primary")
-            ], md=8)
+                dbc.Button("Search", id="search-btn", color="primary"),
+            ], md=8),
         ]),
         html.Hr(),
-        html.Div(id="results-area")
+        html.Div(id="results-area"),
     ],
-    fluid=True
+    fluid=True,
 )
 
 @app.callback(
     Output("results-area", "children"),
     Input("search-btn", "n_clicks"),
     State("query-input", "value"),
-    prevent_initial_call=True
+    prevent_initial_call=True,
 )
 def search_callback(_, query):
     if not query:
@@ -42,7 +46,11 @@ def search_callback(_, query):
     for idx, r in enumerate(results):
         meta = r["metadata"]
         title = meta.get("name") or f"Document {meta.get('document_id')}"
-        subtitle = f"Created: {meta.get('create_date') or 'N/A'} | Published: {meta.get('publish_date') or 'N/A'} | Categories: {', '.join(meta.get('categories') or [])}"
+        subtitle = (
+            f"Created: {meta.get('create_date') or 'N/A'} | "
+            f"Published: {meta.get('publish_date') or 'N/A'} | "
+            f"Categories: {', '.join(meta.get('categories') or [])}"
+        )
         cards.append(
             dbc.Card(
                 [
@@ -55,26 +63,32 @@ def search_callback(_, query):
                                 id={"type": "collapse-btn", "index": idx},
                                 color="link",
                                 n_clicks=0,
-                                style={"float": "right"}
-                            ),
-                            html.Div(f"Score: {r['score']:.3f}", style={"float": "right", "marginRight": "1em"})
+                                style={"float": "right"},
+                            ),
+                            html.Div(
+                                f"Score: {r['score']:.3f}",
+                                style={"float": "right", "marginRight": "1em"},
+                            ),
                         ],
-                        style={"display": "flex", "flexDirection": "column"}
-                    ),
-                    dbc.CardBody(
-                        html.P(r["snippet"], style={"fontStyle": "italic"})
-                    ),
+                        style={"display": "flex", "flexDirection": "column"},
+                    ),
+                    dbc.CardBody(html.P(r["snippet"], style={"fontStyle": "italic"})),
                     dbc.Collapse(
                         dbc.CardBody(
-                            html.Pre(r["full_text"], style={"whiteSpace": "pre-wrap",
-                                                            "maxHeight": "300px",
-                                                            "overflowY": "auto"})
+                            html.Pre(
+                                r["full_text"],
+                                style={
+                                    "whiteSpace": "pre-wrap",
+                                    "maxHeight": "300px",
+                                    "overflowY": "auto",
+                                },
+                            )
                         ),
                         id={"type": "collapse", "index": idx},
-                        is_open=False
+                        is_open=False,
                     ),
                 ],
-                className="mb-3"
+                className="mb-3",
             )
         )
     return cards
@@ -82,7 +96,7 @@ def search_callback(_, query):
 @app.callback(
     Output({"type": "collapse", "index": MATCH}, "is_open"),
     Input({"type": "collapse-btn", "index": MATCH}, "n_clicks"),
-    State({"type": "collapse", "index": MATCH}, "is_open")
+    State({"type": "collapse", "index": MATCH}, "is_open"),
 )
 def toggle_collapse(n, is_open):
     if n:
@@ -92,4 +106,4 @@ def toggle_collapse(n, is_open):
 if __name__ == "__main__":
     import os
     port = int(os.environ.get("PORT", 7860))
-    app.run_server(host="0.0.0.0", port=port, debug=False)
+    app.run_server(host="0.0.0.0", port=port, debug=False)
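For reference, a minimal sketch of the result shape that search_callback consumes. This is not part of the commit: it only illustrates the fields (score, snippet, full_text, metadata) that the callback reads from each retrieve() result, and every value shown is made up.

# Illustrative only: one entry of the list that retrieve(query) is expected to
# return, as read by search_callback above. All values are hypothetical.
example_result = {
    "score": 0.472,                        # similarity of the best-matching chunk
    "snippet": "Best-matching chunk of the document...",
    "full_text": "Full concatenated text of the document...",
    "metadata": {
        "document_id": 123,                # hypothetical id
        "name": "Example toolkit page",
        "create_date": "2024-01-01",
        "publish_date": "2024-02-01",
        "categories": ["guides"],
    },
}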
preprocess.py CHANGED
@@ -1,60 +1,43 @@
-import json, re, pickle
-import numpy as np
-from sentence_transformers import SentenceTransformer
+import json
+import re
 from pathlib import Path
 
 DATA_DIR = Path("data")
 
 def load_json(filename):
-    with open(DATA_DIR/filename, "r", encoding="utf-8") as f:
+    """Load a JSON file and return list of records."""
+    with open(DATA_DIR / filename, "r", encoding="utf-8") as f:
         data = json.load(f)
     if isinstance(data, dict) and "results" in data:
         return data["results"]
     return data if isinstance(data, list) else []
 
 def extract_text(item):
+    """Extract textual fields from a JSON record."""
     texts = []
-    if isinstance(item, dict):
-        for k in ("text", "description", "body", "content", "name"):
-            if k in item and item[k]:
-                texts.append(str(item[k]))
-        if "content_json" in item and isinstance(item["content_json"], dict):
-            for v in item["content_json"].values():
-                if isinstance(v, str) and v.strip():
-                    texts.append(v)
+    for k in ("text", "description", "body", "content", "name"):
+        if k in item and item[k]:
+            texts.append(str(item[k]))
+    if "content_json" in item and isinstance(item["content_json"], dict):
+        for v in item["content_json"].values():
+            if isinstance(v, str) and v.strip():
+                texts.append(v)
     return texts
 
 def chunk_text(text, max_words=80):
+    """Split long text into smaller chunks."""
     sentences = re.split(r'(?<=[.!?]) +', text)
     chunks, cur, count = [], [], 0
     for s in sentences:
         words = s.split()
-        if len(words) < 5: continue
+        if len(words) < 5:
+            continue
         if count + len(words) > max_words and cur:
             chunks.append(" ".join(cur))
             cur, count = [s], len(words)
         else:
-            cur.append(s); count += len(words)
-    if cur: chunks.append(" ".join(cur))
+            cur.append(s)
+            count += len(words)
+    if cur:
+        chunks.append(" ".join(cur))
     return chunks
-
-print("🔄 Loading JSON...")
-content = load_json("Toolkit_Content_results.json")
-resources = load_json("Toolkit_Resources_results.json")
-
-docs = []
-for item in content + resources:
-    for t in extract_text(item):
-        docs.extend(chunk_text(t))
-print(f"✅ Loaded {len(docs)} chunks")
-
-print("🔄 Encoding with SentenceTransformer...")
-model = SentenceTransformer("all-MiniLM-L6-v2")
-embeddings = model.encode(docs, convert_to_numpy=True, normalize_embeddings=True)
-
-# Save
-print("💾 Saving artifacts...")
-np.save(DATA_DIR/"embeddings.npy", embeddings)
-with open(DATA_DIR/"docs.pkl", "wb") as f:
-    pickle.dump(docs, f)
-print("✅ Done!")
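As a quick illustration of the refactored chunk_text (a standalone snippet, not part of the commit): sentences shorter than five words are dropped, and the remaining sentences are packed into chunks of at most max_words words.

from preprocess import chunk_text

sample = (
    "This is the first reasonably long sentence of the sample document. "
    "Short one. "
    "Here is another sentence that is long enough to be kept by the filter."
)
for i, chunk in enumerate(chunk_text(sample, max_words=20)):
    print(i, chunk)
# Expected: two chunks. "Short one." is skipped (fewer than 5 words), and the
# last sentence starts a new chunk because adding it would exceed max_words.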
retriever.py CHANGED
@@ -4,6 +4,9 @@ import numpy as np
 from sentence_transformers import SentenceTransformer
 from collections import defaultdict
 
+from preprocess import load_json, extract_text, chunk_text
+
+# --------- Paths ---------
 DATA_DIR = Path("data")
 DATA_DIR.mkdir(exist_ok=True)
 EMB_FILE = DATA_DIR / "embeddings.npy"
@@ -13,75 +16,35 @@ META_FILE = DATA_DIR / "metadata.pkl"  # metadata for each doc
 CONTENT_FILE = DATA_DIR / "Toolkit_Content_results.json"
 RESOURCES_FILE = DATA_DIR / "Toolkit_Resources_results.json"
 
+# Embedding model
 model = SentenceTransformer("all-MiniLM-L6-v2")
 
-# ---------- Helpers ----------
-def _load_json(filename):
-    with open(filename, "r", encoding="utf-8") as f:
-        data = json.load(f)
-    if isinstance(data, dict) and "results" in data:
-        return data["results"]
-    return data if isinstance(data, list) else []
-
-def _extract_text(item):
-    texts = []
-    for k in ("text", "description", "body", "content", "name"):
-        if k in item and item[k]:
-            texts.append(str(item[k]))
-    if "content_json" in item and isinstance(item["content_json"], dict):
-        for v in item["content_json"].values():
-            if isinstance(v, str) and v.strip():
-                texts.append(v)
-    return texts
-
-def _chunk_text(text, max_words=80):
-    sentences = re.split(r'(?<=[.!?]) +', text)
-    chunks, cur, count = [], [], 0
-    for s in sentences:
-        words = s.split()
-        if len(words) < 5:
-            continue
-        if count + len(words) > max_words and cur:
-            chunks.append(" ".join(cur))
-            cur, count = [s], len(words)
-        else:
-            cur.append(s)
-            count += len(words)
-    if cur:
-        chunks.append(" ".join(cur))
-    return chunks
-
+# --------- Build Index ---------
 def _build_index():
     print("🔄 Building index...")
-    content = _load_json(CONTENT_FILE)
-    resources = _load_json(RESOURCES_FILE)
+    content = load_json("Toolkit_Content_results.json")
+    resources = load_json("Toolkit_Resources_results.json")
 
-    chunks = []
-    chunk_to_doc_idx = []
-    documents = []
-    metadata = []  # will store dict with name, id, dates
+    chunks, chunk_to_doc_idx, documents, metadata = [], [], [], []
 
     for item in content + resources:
-        # Combine all text for embeddings
-        full_text = "\n".join(_extract_text(item))
+        full_text = "\n".join(extract_text(item))
         if not full_text.strip():
             continue
 
         doc_idx = len(documents)
         documents.append(full_text)
 
-        # --- Metadata ---
         meta = {
             "document_id": item.get("document_id"),
            "name": item.get("name"),
             "create_date": item.get("create_date"),
             "publish_date": item.get("publish_date"),
-            "categories": item.get("categories")
+            "categories": item.get("categories"),
         }
         metadata.append(meta)
 
-        # --- Chunking ---
-        for ch in _chunk_text(full_text):
+        for ch in chunk_text(full_text):
             chunks.append(ch)
             chunk_to_doc_idx.append(doc_idx)
 
@@ -110,14 +73,14 @@ def _load_or_build():
         metadata = pickle.load(f)
     return chunks, chunk_to_doc_idx, documents, metadata, embeddings
 
+# Load on import
 chunks, chunk_to_doc_idx, documents, metadata, embeddings = _load_or_build()
 
-# ---------- Retrieval ----------
+# --------- Retrieval ---------
 def retrieve(query, top_k=5):
     q_emb = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
     scores = (embeddings @ q_emb.T).squeeze()
 
-    # Aggregate: pick the max scoring chunk per document
     doc_best = defaultdict(lambda: (-np.inf, None))  # (score, best_snippet)
     for idx, sc in enumerate(scores):
         doc_id = chunk_to_doc_idx[idx]
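A minimal usage sketch of the retriever module (not part of the commit). It assumes retrieve() returns the list of dicts that app.py consumes above, i.e. entries with score, snippet, full_text and metadata; the query string is just an example.

# Illustrative usage sketch; assumes the result shape consumed by app.py.
from retriever import retrieve

results = retrieve("How do I publish a toolkit resource?", top_k=3)
for r in results:
    meta = r["metadata"]
    print(f"{r['score']:.3f}  {meta.get('name') or meta.get('document_id')}")
    print(f"    {r['snippet'][:100]}...")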