# Scraped from a Hugging Face Spaces file view (Space status: Sleeping;
# file size: 4,279 bytes; revision 0b26499). The viewer's line-number
# gutter (1-148) carried no content and has been dropped.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, Literal
try:
from transformers import pipeline
except Exception: # pragma: no cover
pipeline = None
# Model identifier passed to the transformers pipeline when CompatibilityGate
# is constructed without (or with an empty) model_name.
DEFAULT_MODEL_NAME = "MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary"
@dataclass
class CompatResult:
    """Outcome of a single ingredient/diet compatibility check."""

    # Overall verdict of the check.
    status: Literal["compatible", "incompatible", "unknown"]
    # Boolean form of the verdict.
    compatible: bool
    # Score reported alongside ``label``.
    score: float
    # Label attached to the result (a classifier label or a reason tag).
    label: str
    # Name of the model the result is attributed to.
    model_name: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain dict, keyed by field name."""
        exported = ("status", "compatible", "score", "label", "model_name")
        return {key: getattr(self, key) for key in exported}
class CompatibilityGate:
    """Zero-shot gate deciding whether an ingredient fits a diet.

    Wraps a transformers ``zero-shot-classification`` pipeline that is built
    lazily on the first :meth:`check` call. If ``transformers`` could not be
    imported, or the pipeline fails to build or to run, checks degrade to an
    ``"unknown"`` result instead of raising.
    """

    # Candidate labels handed to the classifier; the top-ranked one is
    # compared against the matching threshold.
    _LABEL_COMPATIBLE = "compatible"
    _LABEL_INCOMPATIBLE = "not compatible"
    # Diet assumed when the caller supplies none (or only whitespace).
    _DEFAULT_DIET = "vegan"

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        enable_download: bool = True,
        compatible_threshold: float = 0.70,
        incompatible_threshold: float = 0.70,
    ):
        """Store the configuration; the model itself is not loaded here.

        Args:
            model_name: Model identifier for the pipeline; a falsy value
                falls back to ``DEFAULT_MODEL_NAME``.
            enable_download: Stored on the instance. NOTE(review): nothing in
                this class reads it yet - confirm the intended semantics.
            compatible_threshold: Minimum top score for a "compatible" verdict.
            incompatible_threshold: Minimum top score for an "incompatible"
                verdict.
        """
        self.model_name = model_name or DEFAULT_MODEL_NAME
        self.enable_download = enable_download
        self.compatible_threshold = compatible_threshold
        self.incompatible_threshold = incompatible_threshold
        self.available = False
        self._kind = "disabled"
        self._pipe = None

    def _load(self) -> None:
        """Try to build the pipeline and record the outcome on the instance.

        Sets ``available``, ``_kind`` and ``_pipe``. Never raises: a missing
        ``transformers`` install yields ``_kind == "unavailable"`` and any
        construction failure yields ``_kind == "disabled"``.
        """
        if pipeline is None:
            self.available = False
            self._kind = "unavailable"
            return
        try:
            self._pipe = pipeline(
                "zero-shot-classification",
                model=self.model_name,
                device=-1,  # -1 selects CPU in the transformers pipeline API
            )
        except Exception:
            # Deliberate best-effort: a failed load leaves the gate
            # unavailable, so the next check() call will retry it.
            self._pipe = None
            self.available = False
            self._kind = "disabled"
        else:
            self.available = True
            self._kind = "zero-shot"

    def _unknown(self, label: str, score: float = 0.0) -> CompatResult:
        """Build an "unknown" result tagged with the reason ``label``."""
        return CompatResult(
            status="unknown",
            compatible=False,
            score=score,
            label=label,
            model_name=self.model_name,
        )

    @classmethod
    def _hypothesis_template(cls, diet: str) -> str:
        """Return the hypothesis template for ``diet``.

        The diet is stripped and lower-cased *before* the default is applied,
        so a whitespace-only value falls back to the default diet instead of
        producing an empty slot. Literal braces are doubled because the
        pipeline fills the template with ``str.format``; unescaped braces in
        the diet would raise or be mistaken for the label placeholder.
        """
        diet_name = (diet or "").strip().lower() or cls._DEFAULT_DIET
        escaped = diet_name.replace("{", "{{").replace("}", "}}")
        return f"This ingredient is {{}} with a {escaped} diet."

    def check(self, ingredient: str, diet: str) -> CompatResult:
        """Classify ``ingredient`` as compatible or not with ``diet``.

        Args:
            ingredient: Ingredient text; surrounding whitespace is ignored.
            diet: Diet name; empty, ``None`` or whitespace-only means "vegan".

        Returns:
            A ``CompatResult``. ``status`` is "compatible" or "incompatible"
            only when the top-ranked label reaches its threshold; otherwise
            it is "unknown" with ``label`` set to "unavailable" (no model),
            "empty" (blank ingredient or empty classifier output), "error"
            (the pipeline raised) or the below-threshold top label.
        """
        if not self.available or self._pipe is None:
            self._load()
        if not self.available or self._pipe is None:
            return self._unknown("unavailable")

        ingredient = (ingredient or "").strip()
        if not ingredient:
            return self._unknown("empty")

        hypothesis_template = self._hypothesis_template(diet)
        try:
            result = self._pipe(
                ingredient,
                candidate_labels=[
                    self._LABEL_COMPATIBLE,
                    self._LABEL_INCOMPATIBLE,
                ],
                hypothesis_template=hypothesis_template,
            )
        except Exception:
            # Inference failures are reported as an "unknown" verdict.
            return self._unknown("error")

        labels = result.get("labels", [])
        scores = result.get("scores", [])
        if not labels or not scores:
            return self._unknown("empty")

        # Only the first (top-ranked) label/score pair is considered.
        top_label = str(labels[0])
        top_score = float(scores[0])

        if (
            top_label == self._LABEL_COMPATIBLE
            and top_score >= self.compatible_threshold
        ):
            status = "compatible"
        elif (
            top_label == self._LABEL_INCOMPATIBLE
            and top_score >= self.incompatible_threshold
        ):
            status = "incompatible"
        else:
            # Top label did not clear its threshold: keep label and score
            # but report an unknown verdict.
            return self._unknown(top_label, top_score)

        return CompatResult(
            status=status,
            compatible=status == "compatible",
            score=top_score,
            label=top_label,
            model_name=self.model_name,
        )
|