from __future__ import annotations from dataclasses import dataclass from typing import Any, Dict, Literal try: from transformers import pipeline except Exception: # pragma: no cover pipeline = None DEFAULT_MODEL_NAME = "MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary" @dataclass class CompatResult: status: Literal["compatible", "incompatible", "unknown"] compatible: bool score: float label: str model_name: str def to_dict(self) -> Dict[str, Any]: return { "status": self.status, "compatible": self.compatible, "score": self.score, "label": self.label, "model_name": self.model_name, } class CompatibilityGate: def __init__( self, model_name: str = DEFAULT_MODEL_NAME, enable_download: bool = True, compatible_threshold: float = 0.70, incompatible_threshold: float = 0.70, ): self.model_name = model_name or DEFAULT_MODEL_NAME self.enable_download = enable_download self.compatible_threshold = compatible_threshold self.incompatible_threshold = incompatible_threshold self.available = False self._kind = "disabled" self._pipe = None def _load(self) -> None: if pipeline is None: self.available = False self._kind = "unavailable" return try: self._pipe = pipeline( "zero-shot-classification", model=self.model_name, device=-1, ) self.available = True self._kind = "zero-shot" except Exception: self._pipe = None self.available = False self._kind = "disabled" def check(self, ingredient: str, diet: str) -> CompatResult: if not self.available or self._pipe is None: self._load() if not self.available or self._pipe is None: return CompatResult( status="unknown", compatible=False, score=0.0, label="unavailable", model_name=self.model_name, ) ingredient = (ingredient or "").strip() if not ingredient: return CompatResult( status="unknown", compatible=False, score=0.0, label="empty", model_name=self.model_name, ) diet = (diet or "vegan").strip().lower() hypothesis_template = f"This ingredient is {{}} with a {diet} diet." try: result = self._pipe( ingredient, candidate_labels=["compatible", "not compatible"], hypothesis_template=hypothesis_template, ) except Exception: return CompatResult( status="unknown", compatible=False, score=0.0, label="error", model_name=self.model_name, ) labels = result.get("labels", []) scores = result.get("scores", []) if not labels or not scores: return CompatResult( status="unknown", compatible=False, score=0.0, label="empty", model_name=self.model_name, ) label = str(labels[0]) score = float(scores[0]) if label == "compatible" and score >= self.compatible_threshold: return CompatResult( status="compatible", compatible=True, score=score, label=label, model_name=self.model_name, ) if label == "not compatible" and score >= self.incompatible_threshold: return CompatResult( status="incompatible", compatible=False, score=score, label=label, model_name=self.model_name, ) return CompatResult( status="unknown", compatible=False, score=score, label=label, model_name=self.model_name, )