anaygupta committed on
Commit
244053a
·
verified ·
1 Parent(s): f43f12b

Update services/semantic.py

Browse files
Files changed (1) hide show
  1. services/semantic.py +76 -80
services/semantic.py CHANGED
@@ -1,107 +1,103 @@
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass
4
- from typing import Iterable, List, Optional
 
5
 
6
- import numpy as np
7
 
8
- from .text_utils import normalize_text, singularize
 
 
 
 
 
 
9
 
10
 
@dataclass
class SemanticHit:
    """One scored result: a term together with its similarity score."""

    # The term this score belongs to.
    term: str
    # Similarity score for the term; larger means more similar.
    score: float
15
 
16
 
class WordVectorFallback:
    """Optional word-vector matcher built on gensim vectors.

    Loading is best effort: if gensim cannot be imported, the local file
    cannot be read, or the download fails, ``model`` stays ``None`` and
    every query method returns an empty result instead of raising.
    """

    def __init__(self, model_name: str = "glove-wiki-gigaword-50", model_path: str = "", enable_download: bool = True):
        self._model_name = model_name
        self._kind = "disabled"
        self.model = None
        self._load(model_name=model_name, model_path=model_path, enable_download=enable_download)

    def _load(self, model_name: str, model_path: str, enable_download: bool) -> None:
        """Populate ``self.model`` and ``self._kind`` from a local file or a download.

        A local ``model_path`` is tried first (native gensim format, then
        word2vec format). If that fails, or no path is given, the named
        model is downloaded when ``enable_download`` is true.
        """
        try:
            from gensim.models import KeyedVectors
            import gensim.downloader as api
        except Exception:
            self.model = None
            self._kind = "unavailable"
            return

        if model_path:
            # Loaders are tried in order; the first one that succeeds wins.
            local_loaders = (
                ("local", lambda: KeyedVectors.load(model_path, mmap="r")),
                (
                    "local-vec",
                    lambda: KeyedVectors.load_word2vec_format(model_path, binary=model_path.endswith(".bin")),
                ),
            )
            for label, read_vectors in local_loaders:
                try:
                    self.model = read_vectors()
                except Exception:
                    continue
                self._kind = f"{label}:{model_path}"
                return
            self.model = None

        if not enable_download:
            self.model = None
            self._kind = "disabled"
            return

        try:
            self.model = api.load(model_name)
        except Exception:
            self.model = None
            self._kind = "download-failed"
        else:
            self._kind = model_name

    @property
    def available(self) -> bool:
        """Whether a vector model has been loaded."""
        return self.model is not None

    def vector_for(self, phrase: str) -> Optional[np.ndarray]:
        """Return a vector for *phrase*, or ``None`` when nothing matches.

        The mean of the vectors of the singularized tokens found in the
        model is preferred. If no token is found, the underscore-joined
        phrase and then the normalized phrase itself are looked up.
        """
        if not self.available:
            return None

        cleaned = normalize_text(phrase)
        words = [singularize(piece) for piece in cleaned.split()]
        found = [self.model[word] for word in words if word in self.model]
        if found:
            return np.mean(np.stack(found), axis=0)

        for key in (cleaned.replace(" ", "_"), cleaned):
            if key in self.model:
                return self.model[key]
        return None

    def nearest(self, query: str, candidates: Iterable[str], top_k: int = 3) -> List[SemanticHit]:
        """Rank *candidates* by cosine similarity to *query*, best first.

        Candidates without a vector are skipped. Returns an empty list when
        the model is unavailable or the query has no vector.
        """
        if not self.available:
            return []

        query_vec = self.vector_for(query)
        if query_vec is None:
            return []

        # The small epsilon keeps the division safe for zero-length vectors.
        query_len = np.linalg.norm(query_vec) + 1e-8
        hits: List[SemanticHit] = []
        for name in candidates:
            cand_vec = self.vector_for(name)
            if cand_vec is not None:
                cosine = np.dot(query_vec, cand_vec) / (query_len * (np.linalg.norm(cand_vec) + 1e-8))
                hits.append(SemanticHit(term=name, score=float(cosine)))

        ranked = sorted(hits, key=lambda hit: hit.score, reverse=True)
        return ranked[:top_k]
 
 
 
 
 
 
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import List, Tuple
6
 
7
+ from .text_utils import normalize_text, singularize, ingredient_lookup_variants
8
 
9
+ try:
10
+ import gensim.downloader as api
11
+ except Exception: # pragma: no cover
12
+ api = None
13
+
14
+
15
+ DEFAULT_MODEL = "glove-wiki-gigaword-50"
16
 
17
 
@dataclass
class SemanticCandidate:
    """A term paired with a numeric score."""

    # The candidate term text.
    term: str
    # The score attached to the term.
    score: float
22
 
23
 
class WordVectorFallback:
    """Optional word-vector lookup that suggests terms similar to an ingredient.

    The model is loaded eagerly in ``__init__``. Loading is best effort:
    when gensim is not installed, the local file cannot be read, or the
    download fails, the instance stays usable with ``available`` set to
    ``False`` and ``most_similar`` returning an empty list.

    Args:
        model_name: Model identifier passed to ``api.load``. Falls back to
            ``DEFAULT_MODEL`` when empty.
        model_path: Optional path to a locally stored vector file. When the
            path exists it is tried before any download.
        enable_download: Whether ``api.load`` may be called for
            ``model_name``.
    """

    def __init__(self, model_name: str = DEFAULT_MODEL, model_path: str | None = None, enable_download: bool = True):
        self.model_name = model_name or DEFAULT_MODEL
        self.model_path = model_path
        self.enable_download = enable_download
        self.available = False
        self._kind = "disabled"
        self._model = None

        self._load()

    def _load_local(self, path: Path):
        """Return vectors read from *path*, or ``None`` if no loader accepts it."""
        # Imported here so the module still imports when gensim is missing.
        from gensim.models import KeyedVectors

        loaders = (
            # Native gensim save format, memory-mapped read-only.
            lambda: KeyedVectors.load(str(path), mmap="r"),
            # word2vec text/binary format; ".bin" selects the binary reader.
            lambda: KeyedVectors.load_word2vec_format(str(path), binary=path.suffix == ".bin"),
        )
        for loader in loaders:
            try:
                return loader()
            except Exception:
                continue
        return None

    def _load(self) -> None:
        """Populate ``_model``, ``available`` and ``_kind``.

        An existing ``model_path`` is read from disk first. A download via
        ``api.load`` happens only when no local model was loaded and
        ``enable_download`` is true. Any failure leaves the instance
        disabled rather than raising.
        """
        if api is None:
            self.available = False
            self._kind = "unavailable"
            return

        model = None
        try:
            if self.model_path:
                path = Path(self.model_path)
                if path.exists():
                    model = self._load_local(path)
            # Only touch the network when permitted and nothing was loaded locally.
            if model is None and self.enable_download:
                model = api.load(self.model_name)
        except Exception:
            # Deliberate best effort: a failed load must not break start-up.
            model = None

        self._model = model
        self.available = model is not None
        self._kind = "glove" if self.available else "disabled"

    def _normalize_candidate(self, term: str) -> str:
        """Apply ``normalize_text`` and then ``singularize`` to *term*."""
        return singularize(normalize_text(term))

    def most_similar(self, ingredient: str, topn: int = 10) -> List[Tuple[str, float]]:
        """Return up to *topn* ``(term, score)`` neighbours of *ingredient*.

        Neighbour terms are normalized and de-duplicated. The result is
        empty when the model is unavailable, *topn* is not positive, the
        normalized query is empty, or the model lookup raises.
        """
        if topn <= 0 or not self.available or self._model is None:
            return []

        query = self._normalize_candidate(ingredient)
        if not query:
            return []

        try:
            # Ask for at least 10 so filtering below still leaves some results.
            raw = self._model.most_similar(query, topn=max(topn, 10))
        except Exception:
            # Best effort: typically a KeyError for an out-of-vocabulary query.
            return []

        query_parts = set(query.split())
        out: List[Tuple[str, float]] = []
        seen = set()
        for term, score in raw:
            term = self._normalize_candidate(term.replace("_", " "))
            if not term or term in seen:
                continue
            seen.add(term)

            term_parts = set(term.split())
            if query_parts and term_parts and not (query_parts & term_parts):
                # No shared word with the query: keep the term only if the
                # lookup helper yields variants for it.
                # NOTE(review): presumably this marks food-like terms - confirm
                # against ingredient_lookup_variants.
                if not ingredient_lookup_variants(term):
                    continue

            out.append((term, float(score)))
            if len(out) >= topn:
                break

        return out