Harshit Jain committed on
Commit
b42a15f
·
1 Parent(s): 597fb81

Added spellcheck module + 'Did you mean' suggestions with custom vocab

Browse files
Files changed (5) hide show
  1. app.py +8 -2
  2. data/embeddings.npy +1 -1
  3. requirements.txt +1 -0
  4. retriever.py +8 -1
  5. spellcheck.py +35 -0
app.py CHANGED
@@ -3,6 +3,7 @@ from dash import dcc, html
3
  from dash.dependencies import Input, Output, State, MATCH
4
  import dash_bootstrap_components as dbc
5
  from retriever import retrieve
 
6
 
7
  app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
8
 
@@ -38,11 +39,16 @@ def search_callback(_, query):
38
  if not query:
39
  return dbc.Alert("Please enter a query.", color="warning")
40
 
41
- results = retrieve(query, top_k=5)
 
 
 
 
 
 
42
  if not results:
43
  return dbc.Alert("No results found.", color="danger")
44
 
45
- cards = []
46
  for idx, r in enumerate(results):
47
  meta = r["metadata"]
48
  title = meta.get("name") or f"Document {meta.get('document_id')}"
 
3
  from dash.dependencies import Input, Output, State, MATCH
4
  import dash_bootstrap_components as dbc
5
  from retriever import retrieve
6
+ from spellcheck import autocorrect_query
7
 
8
  app = dash.Dash(__name__, external_stylesheets=[dbc.themes.BOOTSTRAP])
9
 
 
39
  if not query:
40
  return dbc.Alert("Please enter a query.", color="warning")
41
 
42
+ corrected_query, suggestion = autocorrect_query(query)
43
+ results = retrieve(corrected_query, top_k=5)
44
+
45
+ cards = []
46
+ if suggestion:
47
+ cards.append(dbc.Alert(suggestion, color="info"))
48
+
49
  if not results:
50
  return dbc.Alert("No results found.", color="danger")
51
 
 
52
  for idx, r in enumerate(results):
53
  meta = r["metadata"]
54
  title = meta.get("name") or f"Document {meta.get('document_id')}"
data/embeddings.npy CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:bad33642819fb65ff5d780c1fe8ea60974a03a95218ace76d2e4f7a6a95c8577
3
  size 1313408
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc7063629c99f9ba0cb1063dc07a70706c82c217864ceb9b44c5d8e64f96967d
3
  size 1313408
requirements.txt CHANGED
@@ -3,3 +3,4 @@ dash-bootstrap-components==1.6.0
3
  sentence-transformers>=2.6.0
4
  numpy>=1.23
5
  huggingface-hub>=0.23.0
 
 
3
  sentence-transformers>=2.6.0
4
  numpy>=1.23
5
  huggingface-hub>=0.23.0
6
+ pyspellchecker
retriever.py CHANGED
@@ -5,6 +5,7 @@ from sentence_transformers import SentenceTransformer
5
  from collections import defaultdict
6
 
7
  from preprocess import load_json, extract_text, chunk_text
 
8
 
9
  # --------- Paths ---------
10
  DATA_DIR = Path("data")
@@ -64,13 +65,14 @@ def _load_or_build():
64
  if not (EMB_FILE.exists() and CHUNK_FILE.exists() and DOC_FILE.exists() and META_FILE.exists()):
65
  _build_index()
66
  print("🔄 Loading precomputed data...")
67
- embeddings = np.load(EMB_FILE)
68
  with open(CHUNK_FILE, "rb") as f:
69
  chunks, chunk_to_doc_idx = pickle.load(f)
70
  with open(DOC_FILE, "rb") as f:
71
  documents = pickle.load(f)
72
  with open(META_FILE, "rb") as f:
73
  metadata = pickle.load(f)
 
74
  return chunks, chunk_to_doc_idx, documents, metadata, embeddings
75
 
76
  # Load on import
@@ -78,6 +80,11 @@ chunks, chunk_to_doc_idx, documents, metadata, embeddings = _load_or_build()
78
 
79
  # --------- Retrieval ---------
80
  def retrieve(query, top_k=5):
 
 
 
 
 
81
  q_emb = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
82
  scores = (embeddings @ q_emb.T).squeeze()
83
 
 
5
  from collections import defaultdict
6
 
7
  from preprocess import load_json, extract_text, chunk_text
8
+ from spellcheck import autocorrect_query, load_custom_vocab
9
 
10
  # --------- Paths ---------
11
  DATA_DIR = Path("data")
 
65
  if not (EMB_FILE.exists() and CHUNK_FILE.exists() and DOC_FILE.exists() and META_FILE.exists()):
66
  _build_index()
67
  print("🔄 Loading precomputed data...")
68
+ embeddings = np.load(EMB_FILE, allow_pickle=True)
69
  with open(CHUNK_FILE, "rb") as f:
70
  chunks, chunk_to_doc_idx = pickle.load(f)
71
  with open(DOC_FILE, "rb") as f:
72
  documents = pickle.load(f)
73
  with open(META_FILE, "rb") as f:
74
  metadata = pickle.load(f)
75
+ load_custom_vocab(documents)
76
  return chunks, chunk_to_doc_idx, documents, metadata, embeddings
77
 
78
  # Load on import
 
80
 
81
  # --------- Retrieval ---------
82
  def retrieve(query, top_k=5):
83
+ # Autocorrect step
84
+ query, suggestion = autocorrect_query(query)
85
+ if suggestion:
86
+ print(suggestion) # Logs correction suggestion in console
87
+
88
  q_emb = model.encode([query], convert_to_numpy=True, normalize_embeddings=True)
89
  scores = (embeddings @ q_emb.T).squeeze()
90
 
spellcheck.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
# spellcheck.py
# Spelling-correction helpers used by the search flow.  Module-level state
# below is shared by load_custom_vocab() and autocorrect_query().
from spellchecker import SpellChecker

# Dictionary/frequency model from pyspellchecker's bundled English word list.
spell = SpellChecker()
# Vocabulary harvested from the project's own documents.  autocorrect_query()
# consults this set so domain-specific terms win over generic dictionary words.
custom_vocab = set()


def load_custom_vocab(docs: list[str]) -> None:
    """Populate the shared custom vocabulary from project documents.

    Each document is whitespace-tokenized and every token is lowercased
    before being added to the module-level ``custom_vocab`` set.

    NOTE(review): tokens keep any surrounding punctuation ("word," is
    stored distinct from "word") — confirm whether stripping is wanted.

    Args:
        docs: list of raw document texts.
    """
    # The set is mutated in place, never rebound, so no `global` is needed.
    for doc in docs:
        custom_vocab.update(token.lower() for token in doc.split())
def autocorrect_query(query: str) -> tuple[str, str]:
    """Autocorrect *query* word by word.

    Words present in ``custom_vocab`` are trusted as-is; for other words,
    candidate corrections that appear in the custom vocabulary are preferred
    over the generic dictionary correction.

    Args:
        query: raw user query string.

    Returns:
        ``(corrected_query, suggestion)`` where ``suggestion`` is a
        "Did you mean '...'?" string, or ``""`` when nothing changed.
    """
    corrected = []
    for word in query.split():
        # Domain terms loaded by load_custom_vocab() must never be
        # "corrected" away, even when the generic dictionary doesn't know
        # them.
        if word.lower() in custom_vocab:
            corrected.append(word)
            continue

        # pyspellchecker >= 0.7 returns None (not an empty set) from
        # candidates() when it has no suggestion — guard before iterating.
        candidates = spell.candidates(word) or set()
        custom_candidates = [c for c in candidates if c in custom_vocab]
        if custom_candidates:
            # candidates() yields an unordered set; sort so the chosen
            # correction is deterministic across runs.
            corrected.append(min(custom_candidates))
        else:
            # correction() may also return None; fall back to the original.
            corrected.append(spell.correction(word) or word)

    corrected_query = " ".join(corrected)
    if corrected_query != query:
        # NOTE(review): corrections come back lowercase, so a capitalized
        # but correctly-spelled query can still trigger a suggestion —
        # confirm whether a case-insensitive comparison is intended here.
        return corrected_query, f"Did you mean '{corrected_query}'?"
    return query, ""