j-js committed on
Commit
a214825
·
verified ·
1 Parent(s): 75970da

Update retrieval_engine.py

Browse files
Files changed (1) hide show
  1. retrieval_engine.py +85 -18
retrieval_engine.py CHANGED
@@ -2,7 +2,7 @@ from __future__ import annotations
2
 
3
  import json
4
  import os
5
- from typing import List, Optional
6
 
7
  from models import RetrievedChunk
8
  from utils import clean_math_text, score_token_overlap
@@ -24,18 +24,24 @@ class RetrievalEngine:
24
  self.rows = self._load_rows(data_path)
25
  self.encoder = None
26
  self.embeddings = None
 
27
  if SentenceTransformer is not None and self.rows:
28
  try:
29
  self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
30
- self.embeddings = self.encoder.encode([r["text"] for r in self.rows], convert_to_numpy=True, normalize_embeddings=True)
 
 
 
 
31
  except Exception:
32
  self.encoder = None
33
  self.embeddings = None
34
 
35
- def _load_rows(self, data_path: str):
36
- rows = []
37
  if not os.path.exists(data_path):
38
  return rows
 
39
  with open(data_path, "r", encoding="utf-8") as f:
40
  for line in f:
41
  line = line.strip()
@@ -45,43 +51,95 @@ class RetrievalEngine:
45
  item = json.loads(line)
46
  except Exception:
47
  continue
48
- rows.append({
49
- "text": item.get("text", ""),
50
- "topic": item.get("topic", item.get("section", "general")) or "general",
51
- "source": item.get("source", "local_corpus"),
52
- })
 
 
 
53
  return rows
54
 
55
  def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
56
  desired_topic = (desired_topic or "").lower()
57
  row_topic = (row_topic or "").lower()
58
  intent = (intent or "").lower()
 
59
  bonus = 0.0
 
60
  if desired_topic and desired_topic in row_topic:
61
  bonus += 1.25
 
62
  if desired_topic == "algebra" and row_topic in {"algebra", "linear equations", "equations"}:
63
  bonus += 1.0
 
64
  if desired_topic == "percent" and "percent" in row_topic:
65
  bonus += 1.0
66
- if intent in {"method", "step_by_step", "full_working", "hint"}:
67
- if any(k in row_topic for k in ["algebra", "percent", "fractions", "word_problems", "general"]):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  bonus += 0.25
 
69
  return bonus
70
 
71
- def search(self, query: str, topic: str = "", intent: str = "answer", k: int = 3) -> List[RetrievedChunk]:
 
 
 
 
 
 
72
  if not self.rows:
73
  return []
74
- combined_query = clean_math_text(query)
75
 
 
76
  scores = []
 
77
  if self.encoder is not None and self.embeddings is not None and np is not None:
78
  try:
79
- q = self.encoder.encode([combined_query], convert_to_numpy=True, normalize_embeddings=True)[0]
 
 
 
 
80
  semantic_scores = self.embeddings @ q
 
81
  for row, sem in zip(self.rows, semantic_scores.tolist()):
82
  lexical = score_token_overlap(combined_query, row["text"])
83
  bonus = self._topic_bonus(topic, row["topic"], intent)
84
- scores.append((0.7 * sem + 0.3 * lexical + bonus, row))
 
85
  except Exception:
86
  scores = []
87
 
@@ -92,7 +150,16 @@ class RetrievalEngine:
92
  scores.append((lexical + bonus, row))
93
 
94
  scores.sort(key=lambda x: x[0], reverse=True)
95
- results = []
 
96
  for score, row in scores[:k]:
97
- results.append(RetrievedChunk(text=row["text"], topic=row["topic"], source=row["source"], score=float(score)))
98
- return results
 
 
 
 
 
 
 
 
 
2
 
3
  import json
4
  import os
5
+ from typing import List
6
 
7
  from models import RetrievedChunk
8
  from utils import clean_math_text, score_token_overlap
 
24
  self.rows = self._load_rows(data_path)
25
  self.encoder = None
26
  self.embeddings = None
27
+
28
  if SentenceTransformer is not None and self.rows:
29
  try:
30
  self.encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
31
+ self.embeddings = self.encoder.encode(
32
+ [r["text"] for r in self.rows],
33
+ convert_to_numpy=True,
34
+ normalize_embeddings=True,
35
+ )
36
  except Exception:
37
  self.encoder = None
38
  self.embeddings = None
39
 
40
+ def _load_rows(self, data_path: str) -> List[dict]:
41
+ rows: List[dict] = []
42
  if not os.path.exists(data_path):
43
  return rows
44
+
45
  with open(data_path, "r", encoding="utf-8") as f:
46
  for line in f:
47
  line = line.strip()
 
51
  item = json.loads(line)
52
  except Exception:
53
  continue
54
+
55
+ rows.append(
56
+ {
57
+ "text": item.get("text", ""),
58
+ "topic": item.get("topic", item.get("section", "general")) or "general",
59
+ "source": item.get("source", "local_corpus"),
60
+ }
61
+ )
62
  return rows
63
 
64
  def _topic_bonus(self, desired_topic: str, row_topic: str, intent: str) -> float:
65
  desired_topic = (desired_topic or "").lower()
66
  row_topic = (row_topic or "").lower()
67
  intent = (intent or "").lower()
68
+
69
  bonus = 0.0
70
+
71
  if desired_topic and desired_topic in row_topic:
72
  bonus += 1.25
73
+
74
  if desired_topic == "algebra" and row_topic in {"algebra", "linear equations", "equations"}:
75
  bonus += 1.0
76
+
77
  if desired_topic == "percent" and "percent" in row_topic:
78
  bonus += 1.0
79
+
80
+ if desired_topic in {"number_theory", "number_properties"} and any(
81
+ k in row_topic for k in ["number", "divisible", "remainder", "prime", "factor"]
82
+ ):
83
+ bonus += 1.0
84
+
85
+ if desired_topic == "geometry" and any(
86
+ k in row_topic for k in ["geometry", "circle", "triangle", "area", "perimeter"]
87
+ ):
88
+ bonus += 1.0
89
+
90
+ if desired_topic == "probability" and "probability" in row_topic:
91
+ bonus += 1.0
92
+
93
+ if desired_topic == "statistics" and any(
94
+ k in row_topic for k in ["statistics", "mean", "median", "average", "distribution"]
95
+ ):
96
+ bonus += 1.0
97
+
98
+ if intent in {"method", "step_by_step", "full_working", "hint", "walkthrough", "instruction"}:
99
+ if any(
100
+ k in row_topic
101
+ for k in [
102
+ "algebra",
103
+ "percent",
104
+ "fractions",
105
+ "word_problems",
106
+ "general",
107
+ "ratio",
108
+ "probability",
109
+ "statistics",
110
+ ]
111
+ ):
112
  bonus += 0.25
113
+
114
  return bonus
115
 
116
+ def search(
117
+ self,
118
+ query: str,
119
+ topic: str = "",
120
+ intent: str = "answer",
121
+ k: int = 3,
122
+ ) -> List[RetrievedChunk]:
123
  if not self.rows:
124
  return []
 
125
 
126
+ combined_query = clean_math_text(query)
127
  scores = []
128
+
129
  if self.encoder is not None and self.embeddings is not None and np is not None:
130
  try:
131
+ q = self.encoder.encode(
132
+ [combined_query],
133
+ convert_to_numpy=True,
134
+ normalize_embeddings=True,
135
+ )[0]
136
  semantic_scores = self.embeddings @ q
137
+
138
  for row, sem in zip(self.rows, semantic_scores.tolist()):
139
  lexical = score_token_overlap(combined_query, row["text"])
140
  bonus = self._topic_bonus(topic, row["topic"], intent)
141
+ total = 0.7 * sem + 0.3 * lexical + bonus
142
+ scores.append((total, row))
143
  except Exception:
144
  scores = []
145
 
 
150
  scores.append((lexical + bonus, row))
151
 
152
  scores.sort(key=lambda x: x[0], reverse=True)
153
+
154
+ results: List[RetrievedChunk] = []
155
  for score, row in scores[:k]:
156
+ results.append(
157
+ RetrievedChunk(
158
+ text=row["text"],
159
+ topic=row["topic"],
160
+ source=row["source"],
161
+ score=float(score),
162
+ )
163
+ )
164
+
165
+ return results