anaygupta commited on
Commit
0b26499
·
verified ·
1 Parent(s): c5e7cb1

Upload services_compat.py

Browse files
Files changed (1) hide show
  1. services/services_compat.py +147 -0
services/services_compat.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, Literal
5
+
6
+ try:
7
+ from transformers import pipeline
8
+ except Exception: # pragma: no cover
9
+ pipeline = None
10
+
11
+
12
+ DEFAULT_MODEL_NAME = "MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary"
13
+
14
+
15
+ @dataclass
16
+ class CompatResult:
17
+ status: Literal["compatible", "incompatible", "unknown"]
18
+ compatible: bool
19
+ score: float
20
+ label: str
21
+ model_name: str
22
+
23
+ def to_dict(self) -> Dict[str, Any]:
24
+ return {
25
+ "status": self.status,
26
+ "compatible": self.compatible,
27
+ "score": self.score,
28
+ "label": self.label,
29
+ "model_name": self.model_name,
30
+ }
31
+
32
+
33
+ class CompatibilityGate:
34
+ def __init__(
35
+ self,
36
+ model_name: str = DEFAULT_MODEL_NAME,
37
+ enable_download: bool = True,
38
+ compatible_threshold: float = 0.70,
39
+ incompatible_threshold: float = 0.70,
40
+ ):
41
+ self.model_name = model_name or DEFAULT_MODEL_NAME
42
+ self.enable_download = enable_download
43
+ self.compatible_threshold = compatible_threshold
44
+ self.incompatible_threshold = incompatible_threshold
45
+ self.available = False
46
+ self._kind = "disabled"
47
+ self._pipe = None
48
+
49
+ def _load(self) -> None:
50
+ if pipeline is None:
51
+ self.available = False
52
+ self._kind = "unavailable"
53
+ return
54
+
55
+ try:
56
+ self._pipe = pipeline(
57
+ "zero-shot-classification",
58
+ model=self.model_name,
59
+ device=-1,
60
+ )
61
+ self.available = True
62
+ self._kind = "zero-shot"
63
+ except Exception:
64
+ self._pipe = None
65
+ self.available = False
66
+ self._kind = "disabled"
67
+
68
+ def check(self, ingredient: str, diet: str) -> CompatResult:
69
+ if not self.available or self._pipe is None:
70
+ self._load()
71
+
72
+ if not self.available or self._pipe is None:
73
+ return CompatResult(
74
+ status="unknown",
75
+ compatible=False,
76
+ score=0.0,
77
+ label="unavailable",
78
+ model_name=self.model_name,
79
+ )
80
+
81
+ ingredient = (ingredient or "").strip()
82
+ if not ingredient:
83
+ return CompatResult(
84
+ status="unknown",
85
+ compatible=False,
86
+ score=0.0,
87
+ label="empty",
88
+ model_name=self.model_name,
89
+ )
90
+
91
+ diet = (diet or "vegan").strip().lower()
92
+ hypothesis_template = f"This ingredient is {{}} with a {diet} diet."
93
+
94
+ try:
95
+ result = self._pipe(
96
+ ingredient,
97
+ candidate_labels=["compatible", "not compatible"],
98
+ hypothesis_template=hypothesis_template,
99
+ )
100
+ except Exception:
101
+ return CompatResult(
102
+ status="unknown",
103
+ compatible=False,
104
+ score=0.0,
105
+ label="error",
106
+ model_name=self.model_name,
107
+ )
108
+
109
+ labels = result.get("labels", [])
110
+ scores = result.get("scores", [])
111
+ if not labels or not scores:
112
+ return CompatResult(
113
+ status="unknown",
114
+ compatible=False,
115
+ score=0.0,
116
+ label="empty",
117
+ model_name=self.model_name,
118
+ )
119
+
120
+ label = str(labels[0])
121
+ score = float(scores[0])
122
+
123
+ if label == "compatible" and score >= self.compatible_threshold:
124
+ return CompatResult(
125
+ status="compatible",
126
+ compatible=True,
127
+ score=score,
128
+ label=label,
129
+ model_name=self.model_name,
130
+ )
131
+
132
+ if label == "not compatible" and score >= self.incompatible_threshold:
133
+ return CompatResult(
134
+ status="incompatible",
135
+ compatible=False,
136
+ score=score,
137
+ label=label,
138
+ model_name=self.model_name,
139
+ )
140
+
141
+ return CompatResult(
142
+ status="unknown",
143
+ compatible=False,
144
+ score=score,
145
+ label=label,
146
+ model_name=self.model_name,
147
+ )