Spaces:
Sleeping
Sleeping
perf: reduce per-query latency — shared model, single encode, float32 LLM
Browse files
- Add modules/shared.py: a SentenceTransformer singleton, loaded once
- HybridRetriever and DriftDetector now share the same model instance
instead of each loading a separate copy (saves ~90MB RAM + load time)
- Encode the query once in process_query and pass the embedding to both
analyze_drift() and search() — eliminates a redundant encode call
- Switch LLM to torch.float32 (bfloat16 is not hardware-accelerated on most CPUs)
- Reduce max_new_tokens 120→80 for faster token generation
- app.py +6 -2
- modules/drift.py +10 -15
- modules/llm.py +2 -2
- modules/retrieval.py +8 -3
- modules/shared.py +20 -0
app.py
CHANGED
|
@@ -11,6 +11,7 @@ import sys
|
|
| 11 |
import gradio as gr
|
| 12 |
import plotly.graph_objects as go
|
| 13 |
from modules.data_simulation import generate_catalog, get_scenarios
|
|
|
|
| 14 |
from modules.retrieval import HybridRetriever
|
| 15 |
from modules.drift import DriftDetector
|
| 16 |
from modules.adaptation import Adapter
|
|
@@ -229,11 +230,14 @@ def process_query(query: str, history: list):
|
|
| 229 |
|
| 230 |
logger.info("Processing query: %r", query)
|
| 231 |
|
|
|
|
|
|
|
|
|
|
| 232 |
# 1. Measure drift
|
| 233 |
-
drift_state, scores = detector.analyze_drift(query)
|
| 234 |
|
| 235 |
# 2. Retrieve products (hybrid: price-filter + semantic)
|
| 236 |
-
retrieved = retriever.search(query, top_k=4)
|
| 237 |
|
| 238 |
# 3. Adapt system prompt
|
| 239 |
system_prompt = adapter.adapt_prompt(drift_state)
|
|
|
|
| 11 |
import gradio as gr
|
| 12 |
import plotly.graph_objects as go
|
| 13 |
from modules.data_simulation import generate_catalog, get_scenarios
|
| 14 |
+
from modules.shared import get_embedding_model
|
| 15 |
from modules.retrieval import HybridRetriever
|
| 16 |
from modules.drift import DriftDetector
|
| 17 |
from modules.adaptation import Adapter
|
|
|
|
| 230 |
|
| 231 |
logger.info("Processing query: %r", query)
|
| 232 |
|
| 233 |
+
# Encode query once — shared by drift detection and retrieval
|
| 234 |
+
query_emb = get_embedding_model().encode([query], show_progress_bar=False)[0]
|
| 235 |
+
|
| 236 |
# 1. Measure drift
|
| 237 |
+
drift_state, scores = detector.analyze_drift(query, query_emb=query_emb)
|
| 238 |
|
| 239 |
# 2. Retrieve products (hybrid: price-filter + semantic)
|
| 240 |
+
retrieved = retriever.search(query, top_k=4, query_emb=query_emb)
|
| 241 |
|
| 242 |
# 3. Adapt system prompt
|
| 243 |
system_prompt = adapter.adapt_prompt(drift_state)
|
modules/drift.py
CHANGED
|
@@ -16,19 +16,10 @@ from dataclasses import dataclass, field
|
|
| 16 |
from typing import Any
|
| 17 |
|
| 18 |
import numpy as np
|
| 19 |
-
from sentence_transformers import SentenceTransformer
|
| 20 |
-
|
| 21 |
-
logger = logging.getLogger(__name__)
|
| 22 |
-
|
| 23 |
-
# Use shared model instance across retriever & drift detector
|
| 24 |
-
_shared_model: SentenceTransformer | None = None
|
| 25 |
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
global _shared_model
|
| 29 |
-
if _shared_model is None:
|
| 30 |
-
_shared_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 31 |
-
return _shared_model
|
| 32 |
|
| 33 |
|
| 34 |
@dataclass
|
|
@@ -57,7 +48,7 @@ class DriftDetector:
|
|
| 57 |
_concept_embs: dict[str, Any] = field(default_factory=dict, repr=False)
|
| 58 |
|
| 59 |
def __post_init__(self) -> None:
|
| 60 |
-
model =
|
| 61 |
# Multiple anchor phrases per concept → averaged embedding for robustness
|
| 62 |
concept_phrases = {
|
| 63 |
"price_sensitive": [
|
|
@@ -95,13 +86,17 @@ class DriftDetector:
|
|
| 95 |
self._ewma[c] = 0.15
|
| 96 |
# ── Public API ──────────────────────────────────────────────────────────
|
| 97 |
|
| 98 |
-
def analyze_drift(
|
|
|
|
|
|
|
| 99 |
"""
|
| 100 |
Score *query* against all concept anchors and return
|
| 101 |
``(dominant_concept, raw_scores)``.
|
|
|
|
|
|
|
| 102 |
"""
|
| 103 |
-
|
| 104 |
-
|
| 105 |
|
| 106 |
raw_scores: dict[str, float] = {}
|
| 107 |
for concept, ref_emb in self._concept_embs.items():
|
|
|
|
| 16 |
from typing import Any
|
| 17 |
|
| 18 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
from modules.shared import get_embedding_model
|
| 21 |
|
| 22 |
+
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
@dataclass
|
|
|
|
| 48 |
_concept_embs: dict[str, Any] = field(default_factory=dict, repr=False)
|
| 49 |
|
| 50 |
def __post_init__(self) -> None:
|
| 51 |
+
model = get_embedding_model()
|
| 52 |
# Multiple anchor phrases per concept → averaged embedding for robustness
|
| 53 |
concept_phrases = {
|
| 54 |
"price_sensitive": [
|
|
|
|
| 86 |
self._ewma[c] = 0.15
|
| 87 |
# ── Public API ──────────────────────────────────────────────────────────
|
| 88 |
|
| 89 |
+
def analyze_drift(
|
| 90 |
+
self, query: str, query_emb=None
|
| 91 |
+
) -> tuple[str, dict[str, float]]:
|
| 92 |
"""
|
| 93 |
Score *query* against all concept anchors and return
|
| 94 |
``(dominant_concept, raw_scores)``.
|
| 95 |
+
|
| 96 |
+
Pass *query_emb* to skip re-encoding when the caller already has it.
|
| 97 |
"""
|
| 98 |
+
if query_emb is None:
|
| 99 |
+
query_emb = get_embedding_model().encode([query], show_progress_bar=False)[0]
|
| 100 |
|
| 101 |
raw_scores: dict[str, float] = {}
|
| 102 |
for concept, ref_emb in self._concept_embs.items():
|
modules/llm.py
CHANGED
|
@@ -30,7 +30,7 @@ def _get_pipeline():
|
|
| 30 |
"text-generation",
|
| 31 |
model="Qwen/Qwen2.5-0.5B-Instruct",
|
| 32 |
device="cpu",
|
| 33 |
-
torch_dtype=torch.bfloat16,
|
| 34 |
)
|
| 35 |
logger.info("Model loaded in %.1fs", time.time() - t0)
|
| 36 |
return _generator
|
|
@@ -81,7 +81,7 @@ def generate_response(
|
|
| 81 |
gen = _get_pipeline()
|
| 82 |
result = gen(
|
| 83 |
messages,
|
| 84 |
-
max_new_tokens=120,
|
| 85 |
do_sample=False,
|
| 86 |
return_full_text=False,
|
| 87 |
)
|
|
|
|
| 30 |
"text-generation",
|
| 31 |
model="Qwen/Qwen2.5-0.5B-Instruct",
|
| 32 |
device="cpu",
|
| 33 |
+
torch_dtype=torch.float32,
|
| 34 |
)
|
| 35 |
logger.info("Model loaded in %.1fs", time.time() - t0)
|
| 36 |
return _generator
|
|
|
|
| 81 |
gen = _get_pipeline()
|
| 82 |
result = gen(
|
| 83 |
messages,
|
| 84 |
+
max_new_tokens=80,
|
| 85 |
do_sample=False,
|
| 86 |
return_full_text=False,
|
| 87 |
)
|
modules/retrieval.py
CHANGED
|
@@ -13,7 +13,8 @@ import re
|
|
| 13 |
from typing import Any
|
| 14 |
|
| 15 |
import numpy as np
|
| 16 |
-
|
|
|
|
| 17 |
|
| 18 |
logger = logging.getLogger(__name__)
|
| 19 |
|
|
@@ -23,7 +24,7 @@ class HybridRetriever:
|
|
| 23 |
|
| 24 |
def __init__(self, catalog: list[dict]) -> None:
|
| 25 |
self.catalog = catalog
|
| 26 |
-
self.model =
|
| 27 |
|
| 28 |
# Build rich embedding texts that capture all searchable facets
|
| 29 |
texts = [
|
|
@@ -47,6 +48,7 @@ class HybridRetriever:
|
|
| 47 |
query: str,
|
| 48 |
top_k: int = 4,
|
| 49 |
category_filter: str | None = None,
|
|
|
|
| 50 |
) -> list[dict[str, Any]]:
|
| 51 |
"""
|
| 52 |
Retrieve top-k products for *query*.
|
|
@@ -56,6 +58,8 @@ class HybridRetriever:
|
|
| 56 |
2. Pre-filter catalog by price / category if applicable.
|
| 57 |
3. Rank remaining items by cosine similarity.
|
| 58 |
4. Return top-k with scores.
|
|
|
|
|
|
|
| 59 |
"""
|
| 60 |
price_cap = self._extract_price_cap(query)
|
| 61 |
cat_hint = category_filter or self._extract_category_hint(query)
|
|
@@ -64,7 +68,8 @@ class HybridRetriever:
|
|
| 64 |
candidate_indices = self._prefilter(price_cap, cat_hint)
|
| 65 |
|
| 66 |
# Stage 2 — semantic ranking over candidates
|
| 67 |
-
query_emb
|
|
|
|
| 68 |
query_norm = np.linalg.norm(query_emb)
|
| 69 |
|
| 70 |
if len(candidate_indices) == 0:
|
|
|
|
| 13 |
from typing import Any
|
| 14 |
|
| 15 |
import numpy as np
|
| 16 |
+
|
| 17 |
+
from modules.shared import get_embedding_model
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
|
|
|
| 24 |
|
| 25 |
def __init__(self, catalog: list[dict]) -> None:
|
| 26 |
self.catalog = catalog
|
| 27 |
+
self.model = get_embedding_model()
|
| 28 |
|
| 29 |
# Build rich embedding texts that capture all searchable facets
|
| 30 |
texts = [
|
|
|
|
| 48 |
query: str,
|
| 49 |
top_k: int = 4,
|
| 50 |
category_filter: str | None = None,
|
| 51 |
+
query_emb=None,
|
| 52 |
) -> list[dict[str, Any]]:
|
| 53 |
"""
|
| 54 |
Retrieve top-k products for *query*.
|
|
|
|
| 58 |
2. Pre-filter catalog by price / category if applicable.
|
| 59 |
3. Rank remaining items by cosine similarity.
|
| 60 |
4. Return top-k with scores.
|
| 61 |
+
|
| 62 |
+
Pass *query_emb* to skip re-encoding when the caller already has it.
|
| 63 |
"""
|
| 64 |
price_cap = self._extract_price_cap(query)
|
| 65 |
cat_hint = category_filter or self._extract_category_hint(query)
|
|
|
|
| 68 |
candidate_indices = self._prefilter(price_cap, cat_hint)
|
| 69 |
|
| 70 |
# Stage 2 — semantic ranking over candidates
|
| 71 |
+
if query_emb is None:
|
| 72 |
+
query_emb = self.model.encode([query], show_progress_bar=False)[0]
|
| 73 |
query_norm = np.linalg.norm(query_emb)
|
| 74 |
|
| 75 |
if len(candidate_indices) == 0:
|
modules/shared.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared SentenceTransformer singleton — loaded once, used everywhere."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
|
| 7 |
+
from sentence_transformers import SentenceTransformer
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
|
| 11 |
+
_model: SentenceTransformer | None = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_embedding_model() -> SentenceTransformer:
|
| 15 |
+
global _model
|
| 16 |
+
if _model is None:
|
| 17 |
+
logger.info("Loading SentenceTransformer (all-MiniLM-L6-v2)…")
|
| 18 |
+
_model = SentenceTransformer("all-MiniLM-L6-v2")
|
| 19 |
+
logger.info("SentenceTransformer ready.")
|
| 20 |
+
return _model
|