hodfa840 committed on
Commit
d624b44
·
1 Parent(s): 89b7e25

perf: reduce per-query latency — shared model, single encode, float32 LLM

Browse files

- Add modules/shared.py: single SentenceTransformer singleton loaded once
- HybridRetriever and DriftDetector now share the same model instance
instead of each loading a separate copy (saves ~90MB RAM + load time)
- Encode query embedding once in process_query and pass to both
analyze_drift() and search() — eliminates redundant encode call
- Switch LLM to torch.float32 (bfloat16 is not hardware-accelerated on most CPUs)
- Reduce max_new_tokens 120→80 for faster token generation

Files changed (5) hide show
  1. app.py +6 -2
  2. modules/drift.py +10 -15
  3. modules/llm.py +2 -2
  4. modules/retrieval.py +8 -3
  5. modules/shared.py +20 -0
app.py CHANGED
@@ -11,6 +11,7 @@ import sys
11
  import gradio as gr
12
  import plotly.graph_objects as go
13
  from modules.data_simulation import generate_catalog, get_scenarios
 
14
  from modules.retrieval import HybridRetriever
15
  from modules.drift import DriftDetector
16
  from modules.adaptation import Adapter
@@ -229,11 +230,14 @@ def process_query(query: str, history: list):
229
 
230
  logger.info("Processing query: %r", query)
231
 
 
 
 
232
  # 1. Measure drift
233
- drift_state, scores = detector.analyze_drift(query)
234
 
235
  # 2. Retrieve products (hybrid: price-filter + semantic)
236
- retrieved = retriever.search(query, top_k=4)
237
 
238
  # 3. Adapt system prompt
239
  system_prompt = adapter.adapt_prompt(drift_state)
 
11
  import gradio as gr
12
  import plotly.graph_objects as go
13
  from modules.data_simulation import generate_catalog, get_scenarios
14
+ from modules.shared import get_embedding_model
15
  from modules.retrieval import HybridRetriever
16
  from modules.drift import DriftDetector
17
  from modules.adaptation import Adapter
 
230
 
231
  logger.info("Processing query: %r", query)
232
 
233
+ # Encode query once — shared by drift detection and retrieval
234
+ query_emb = get_embedding_model().encode([query], show_progress_bar=False)[0]
235
+
236
  # 1. Measure drift
237
+ drift_state, scores = detector.analyze_drift(query, query_emb=query_emb)
238
 
239
  # 2. Retrieve products (hybrid: price-filter + semantic)
240
+ retrieved = retriever.search(query, top_k=4, query_emb=query_emb)
241
 
242
  # 3. Adapt system prompt
243
  system_prompt = adapter.adapt_prompt(drift_state)
modules/drift.py CHANGED
@@ -16,19 +16,10 @@ from dataclasses import dataclass, field
16
  from typing import Any
17
 
18
  import numpy as np
19
- from sentence_transformers import SentenceTransformer
20
-
21
- logger = logging.getLogger(__name__)
22
-
23
- # Use shared model instance across retriever & drift detector
24
- _shared_model: SentenceTransformer | None = None
25
 
 
26
 
27
- def _get_model() -> SentenceTransformer:
28
- global _shared_model
29
- if _shared_model is None:
30
- _shared_model = SentenceTransformer("all-MiniLM-L6-v2")
31
- return _shared_model
32
 
33
 
34
  @dataclass
@@ -57,7 +48,7 @@ class DriftDetector:
57
  _concept_embs: dict[str, Any] = field(default_factory=dict, repr=False)
58
 
59
  def __post_init__(self) -> None:
60
- model = _get_model()
61
  # Multiple anchor phrases per concept → averaged embedding for robustness
62
  concept_phrases = {
63
  "price_sensitive": [
@@ -95,13 +86,17 @@ class DriftDetector:
95
  self._ewma[c] = 0.15
96
  # ── Public API ──────────────────────────────────────────────────────────
97
 
98
- def analyze_drift(self, query: str) -> tuple[str, dict[str, float]]:
 
 
99
  """
100
  Score *query* against all concept anchors and return
101
  ``(dominant_concept, raw_scores)``.
 
 
102
  """
103
- model = _get_model()
104
- query_emb = model.encode([query], show_progress_bar=False)[0]
105
 
106
  raw_scores: dict[str, float] = {}
107
  for concept, ref_emb in self._concept_embs.items():
 
16
  from typing import Any
17
 
18
  import numpy as np
 
 
 
 
 
 
19
 
20
+ from modules.shared import get_embedding_model
21
 
22
+ logger = logging.getLogger(__name__)
 
 
 
 
23
 
24
 
25
  @dataclass
 
48
  _concept_embs: dict[str, Any] = field(default_factory=dict, repr=False)
49
 
50
  def __post_init__(self) -> None:
51
+ model = get_embedding_model()
52
  # Multiple anchor phrases per concept → averaged embedding for robustness
53
  concept_phrases = {
54
  "price_sensitive": [
 
86
  self._ewma[c] = 0.15
87
  # ── Public API ──────────────────────────────────────────────────────────
88
 
89
+ def analyze_drift(
90
+ self, query: str, query_emb=None
91
+ ) -> tuple[str, dict[str, float]]:
92
  """
93
  Score *query* against all concept anchors and return
94
  ``(dominant_concept, raw_scores)``.
95
+
96
+ Pass *query_emb* to skip re-encoding when the caller already has it.
97
  """
98
+ if query_emb is None:
99
+ query_emb = get_embedding_model().encode([query], show_progress_bar=False)[0]
100
 
101
  raw_scores: dict[str, float] = {}
102
  for concept, ref_emb in self._concept_embs.items():
modules/llm.py CHANGED
@@ -30,7 +30,7 @@ def _get_pipeline():
30
  "text-generation",
31
  model="Qwen/Qwen2.5-0.5B-Instruct",
32
  device="cpu",
33
- torch_dtype=torch.bfloat16,
34
  )
35
  logger.info("Model loaded in %.1fs", time.time() - t0)
36
  return _generator
@@ -81,7 +81,7 @@ def generate_response(
81
  gen = _get_pipeline()
82
  result = gen(
83
  messages,
84
- max_new_tokens=120,
85
  do_sample=False,
86
  return_full_text=False,
87
  )
 
30
  "text-generation",
31
  model="Qwen/Qwen2.5-0.5B-Instruct",
32
  device="cpu",
33
+ torch_dtype=torch.float32,
34
  )
35
  logger.info("Model loaded in %.1fs", time.time() - t0)
36
  return _generator
 
81
  gen = _get_pipeline()
82
  result = gen(
83
  messages,
84
+ max_new_tokens=80,
85
  do_sample=False,
86
  return_full_text=False,
87
  )
modules/retrieval.py CHANGED
@@ -13,7 +13,8 @@ import re
13
  from typing import Any
14
 
15
  import numpy as np
16
- from sentence_transformers import SentenceTransformer
 
17
 
18
  logger = logging.getLogger(__name__)
19
 
@@ -23,7 +24,7 @@ class HybridRetriever:
23
 
24
  def __init__(self, catalog: list[dict]) -> None:
25
  self.catalog = catalog
26
- self.model = SentenceTransformer("all-MiniLM-L6-v2")
27
 
28
  # Build rich embedding texts that capture all searchable facets
29
  texts = [
@@ -47,6 +48,7 @@ class HybridRetriever:
47
  query: str,
48
  top_k: int = 4,
49
  category_filter: str | None = None,
 
50
  ) -> list[dict[str, Any]]:
51
  """
52
  Retrieve top-k products for *query*.
@@ -56,6 +58,8 @@ class HybridRetriever:
56
  2. Pre-filter catalog by price / category if applicable.
57
  3. Rank remaining items by cosine similarity.
58
  4. Return top-k with scores.
 
 
59
  """
60
  price_cap = self._extract_price_cap(query)
61
  cat_hint = category_filter or self._extract_category_hint(query)
@@ -64,7 +68,8 @@ class HybridRetriever:
64
  candidate_indices = self._prefilter(price_cap, cat_hint)
65
 
66
  # Stage 2 — semantic ranking over candidates
67
- query_emb = self.model.encode([query], show_progress_bar=False)[0]
 
68
  query_norm = np.linalg.norm(query_emb)
69
 
70
  if len(candidate_indices) == 0:
 
13
  from typing import Any
14
 
15
  import numpy as np
16
+
17
+ from modules.shared import get_embedding_model
18
 
19
  logger = logging.getLogger(__name__)
20
 
 
24
 
25
  def __init__(self, catalog: list[dict]) -> None:
26
  self.catalog = catalog
27
+ self.model = get_embedding_model()
28
 
29
  # Build rich embedding texts that capture all searchable facets
30
  texts = [
 
48
  query: str,
49
  top_k: int = 4,
50
  category_filter: str | None = None,
51
+ query_emb=None,
52
  ) -> list[dict[str, Any]]:
53
  """
54
  Retrieve top-k products for *query*.
 
58
  2. Pre-filter catalog by price / category if applicable.
59
  3. Rank remaining items by cosine similarity.
60
  4. Return top-k with scores.
61
+
62
+ Pass *query_emb* to skip re-encoding when the caller already has it.
63
  """
64
  price_cap = self._extract_price_cap(query)
65
  cat_hint = category_filter or self._extract_category_hint(query)
 
68
  candidate_indices = self._prefilter(price_cap, cat_hint)
69
 
70
  # Stage 2 — semantic ranking over candidates
71
+ if query_emb is None:
72
+ query_emb = self.model.encode([query], show_progress_bar=False)[0]
73
  query_norm = np.linalg.norm(query_emb)
74
 
75
  if len(candidate_indices) == 0:
modules/shared.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Shared SentenceTransformer singleton — loaded once, used everywhere."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+
7
+ from sentence_transformers import SentenceTransformer
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ _model: SentenceTransformer | None = None
12
+
13
+
14
+ def get_embedding_model() -> SentenceTransformer:
15
+ global _model
16
+ if _model is None:
17
+ logger.info("Loading SentenceTransformer (all-MiniLM-L6-v2)…")
18
+ _model = SentenceTransformer("all-MiniLM-L6-v2")
19
+ logger.info("SentenceTransformer ready.")
20
+ return _model