Spaces:
Sleeping
Sleeping
from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any, Dict, Literal

try:
    from transformers import pipeline
except Exception:  # pragma: no cover
    pipeline = None
# Default model passed as `model=` to the transformers zero-shot-classification
# pipeline when CompatibilityGate is given no (or an empty) model_name.
DEFAULT_MODEL_NAME = "MoritzLaurer/DeBERTa-v3-xsmall-mnli-fever-anli-ling-binary"
@dataclass
class CompatResult:
    """Outcome of a single ingredient/diet compatibility check.

    Attributes:
        status: "compatible", "incompatible", or "unknown" when no confident
            decision could be made.
        compatible: Whether the ingredient was judged compatible.
        score: Score of the reported label (0.0 when no prediction was made).
        label: Top label returned by the model, or a marker describing why no
            prediction was made (e.g. "unavailable", "empty", "error").
        model_name: Identifier of the model associated with this result.
    """

    # Declared as a dataclass so instances can be built with keyword
    # arguments, e.g. CompatResult(status=..., compatible=..., ...).
    status: Literal["compatible", "incompatible", "unknown"]
    compatible: bool
    score: float
    label: str
    model_name: str

    def to_dict(self) -> Dict[str, Any]:
        """Return the result as a plain, JSON-friendly dictionary."""
        return {
            "status": self.status,
            "compatible": self.compatible,
            "score": self.score,
            "label": self.label,
            "model_name": self.model_name,
        }
class CompatibilityGate:
    """Zero-shot check of whether an ingredient fits a given diet.

    A transformers ``zero-shot-classification`` pipeline is created lazily on
    the first :meth:`check` call. If transformers cannot be imported or the
    model fails to load, :meth:`check` returns an ``"unknown"`` result with
    label ``"unavailable"`` instead of raising.
    """

    _log = logging.getLogger(__name__)

    # Candidate labels scored by the model for every check.
    _LABELS = ("compatible", "not compatible")

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        enable_download: bool = True,
        compatible_threshold: float = 0.70,
        incompatible_threshold: float = 0.70,
    ):
        """Configure the gate; no model is loaded here.

        Args:
            model_name: Model passed to ``transformers.pipeline``; a falsy
                value falls back to ``DEFAULT_MODEL_NAME``.
            enable_download: When False, the model is loaded with
                ``local_files_only=True``, so only a locally cached copy is used.
            compatible_threshold: Minimum top score needed to report
                "compatible".
            incompatible_threshold: Minimum top score needed to report
                "incompatible".
        """
        self.model_name = model_name or DEFAULT_MODEL_NAME
        self.enable_download = enable_download
        self.compatible_threshold = compatible_threshold
        self.incompatible_threshold = incompatible_threshold
        self.available = False
        self._kind = "disabled"
        self._pipe = None
        # Loading is attempted at most once: retrying a failed load on every
        # check() call could trigger a fresh (slow) hub download each time.
        self._load_attempted = False

    def _load(self) -> None:
        """Try once to build the pipeline; record the outcome in ``available``/``_kind``."""
        self._load_attempted = True
        if pipeline is None:
            # transformers is not importable in this environment.
            self._pipe = None
            self.available = False
            self._kind = "unavailable"
            return

        load_kwargs: Dict[str, Any] = {}
        if not self.enable_download:
            # Forwarded to from_pretrained: use the local cache only.
            load_kwargs["model_kwargs"] = {"local_files_only": True}

        try:
            self._pipe = pipeline(
                "zero-shot-classification",
                model=self.model_name,
                device=-1,  # CPU
                **load_kwargs,
            )
        except Exception:
            self._log.warning(
                "Could not load zero-shot model %r; compatibility checks will "
                "return 'unknown'.",
                self.model_name,
                exc_info=True,
            )
            self._pipe = None
            self.available = False
            self._kind = "disabled"
            return

        self.available = True
        self._kind = "zero-shot"

    def _result(
        self,
        status: Literal["compatible", "incompatible", "unknown"],
        score: float,
        label: str,
        compatible: bool = False,
    ) -> CompatResult:
        """Build a CompatResult tagged with this gate's model name."""
        return CompatResult(
            status=status,
            compatible=compatible,
            score=score,
            label=label,
            model_name=self.model_name,
        )

    def check(self, ingredient: str, diet: str) -> CompatResult:
        """Classify whether ``ingredient`` is compatible with ``diet``.

        Args:
            ingredient: Free-text ingredient; surrounding whitespace is ignored.
            diet: Diet name, case-insensitive; None or blank means "vegan".

        Returns:
            A CompatResult whose status is "compatible" or "incompatible" only
            when the top label reaches its threshold. Otherwise the status is
            "unknown" and the label is either the top label or one of
            "unavailable" (no model), "empty" (no ingredient / no prediction)
            or "error" (inference raised).
        """
        if self._pipe is None and not self._load_attempted:
            self._load()
        if not self.available or self._pipe is None:
            return self._result("unknown", 0.0, "unavailable")

        ingredient = (ingredient or "").strip()
        if not ingredient:
            return self._result("unknown", 0.0, "empty")

        # Fall back to "vegan" for None *and* for whitespace-only input.
        diet = (diet or "").strip().lower() or "vegan"
        # The pipeline fills the template via str.format(label), so literal
        # braces coming from the diet text must be escaped.
        safe_diet = diet.replace("{", "{{").replace("}", "}}")
        hypothesis_template = f"This ingredient is {{}} with a {safe_diet} diet."

        try:
            result = self._pipe(
                ingredient,
                candidate_labels=list(self._LABELS),
                hypothesis_template=hypothesis_template,
            )
        except Exception:
            self._log.warning(
                "Zero-shot classification failed for ingredient %r",
                ingredient,
                exc_info=True,
            )
            return self._result("unknown", 0.0, "error")

        labels = result.get("labels", [])
        scores = result.get("scores", [])
        if not labels or not scores:
            return self._result("unknown", 0.0, "empty")

        # labels[0] / scores[0] are treated as the top prediction.
        label = str(labels[0])
        score = float(scores[0])
        if label == "compatible" and score >= self.compatible_threshold:
            return self._result("compatible", score, label, compatible=True)
        if label == "not compatible" and score >= self.incompatible_threshold:
            return self._result("incompatible", score, label)
        return self._result("unknown", score, label)