from __future__ import annotations from functools import lru_cache from typing import Literal import torch from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline from .config import settings from .text_utils import normalize_text RecipeType = Literal["baked", "cooked"] BAKE_KEYWORDS = [ "bake", "baking", "oven", "preheat", "flour", "dough", "batter", "cake", "cookie", "muffin", "bread", "pastry", "brownie", "tart", "pie", "scone", "loaf", "whisk", "fold in", "sift", "knead", "leavening", "baking soda", "baking powder", "yeast", ] COOK_KEYWORDS = [ "saute", "sauté", "fry", "boil", "simmer", "stir", "grill", "roast", "steam", "poach", "braise", "sear", "stove", "skillet", "pan", "wok", "sauce", "soup", "stew", "marinate", ] @lru_cache(maxsize=1) def get_qa_pipeline(): tokenizer = AutoTokenizer.from_pretrained(settings.qa_model_name) model = AutoModelForQuestionAnswering.from_pretrained(settings.qa_model_name) device = 0 if torch.cuda.is_available() else -1 return pipeline( "question-answering", model=model, tokenizer=tokenizer, device=device, ) def classify_recipe(recipe_text: str) -> RecipeType: text = normalize_text(recipe_text) bake_score = sum(1 for kw in BAKE_KEYWORDS if kw in text) cook_score = sum(1 for kw in COOK_KEYWORDS if kw in text) answer = "" try: qa = get_qa_pipeline() result = qa(question="Is this recipe for baking or cooking?", context=recipe_text) answer = normalize_text(str(result.get("answer", ""))) except Exception: pass if any(sig in answer for sig in ("bak", "oven", "pastry", "dough")): return "baked" if any(sig in answer for sig in ("cook", "fry", "boil", "saut", "grill", "stir")): return "cooked" if bake_score > cook_score: return "baked" if cook_score > bake_score: return "cooked" return "cooked"