| from __future__ import annotations |
|
|
| from functools import lru_cache |
| from typing import Literal |
|
|
| import torch |
| from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline |
|
|
| from .config import settings |
| from .text_utils import normalize_text |
|
|
# The two labels classify_recipe() can return.
RecipeType = Literal["baked", "cooked"]

# Substrings that count toward the "baked" score: classify_recipe() adds one
# point for each entry found anywhere in the normalized recipe text.
BAKE_KEYWORDS = [
    "bake", "baking", "oven", "preheat", "flour",
    "dough", "batter", "cake", "cookie", "muffin",
    "bread", "pastry", "brownie", "tart", "pie",
    "scone", "loaf", "whisk", "fold in", "sift",
    "knead", "leavening", "baking soda", "baking powder", "yeast",
]

# Substrings that count toward the "cooked" score, matched the same way.
# "saute" is listed both with and without the accent.
COOK_KEYWORDS = [
    "saute", "sauté", "fry", "boil", "simmer",
    "stir", "grill", "roast", "steam", "poach",
    "braise", "sear", "stove", "skillet", "pan",
    "wok", "sauce", "soup", "stew", "marinate",
]
|
|
|
|
@lru_cache(maxsize=1)
def get_qa_pipeline():
    """Build the question-answering pipeline for ``settings.qa_model_name``.

    The result is memoized with ``lru_cache(maxsize=1)``, so the tokenizer
    and model are loaded on the first successful call and the same pipeline
    object is returned afterwards.

    Returns:
        A transformers "question-answering" pipeline placed on CUDA device 0
        when ``torch.cuda.is_available()`` is true, otherwise on device -1
        (the transformers convention for CPU).
    """
    model_name = settings.qa_model_name
    qa_tokenizer = AutoTokenizer.from_pretrained(model_name)
    qa_model = AutoModelForQuestionAnswering.from_pretrained(model_name)

    if torch.cuda.is_available():
        device_index = 0
    else:
        device_index = -1

    return pipeline(
        "question-answering",
        model=qa_model,
        tokenizer=qa_tokenizer,
        device=device_index,
    )
|
|
|
|
# Stems searched for (as substrings) in the QA model's normalized answer.
# The bake stems are checked first, so they win if an answer contains both.
_BAKE_ANSWER_SIGNALS = ("bak", "oven", "pastry", "dough")
_COOK_ANSWER_SIGNALS = ("cook", "fry", "boil", "saut", "grill", "stir")


def _ask_qa_model(recipe_text: str) -> str:
    """Return the QA model's normalized answer, or "" when none is available."""
    import logging  # local import so this block needs no file-level change

    if not recipe_text.strip():
        # Nothing for the model to extract an answer from; skip loading and
        # calling it and let the caller fall back to keyword scores.
        return ""

    try:
        qa = get_qa_pipeline()
        result = qa(
            question="Is this recipe for baking or cooking?",
            context=recipe_text,
        )
        # Tolerate a list-shaped result by taking its first candidate.
        if isinstance(result, list):
            result = result[0] if result else {}
        return normalize_text(str(result.get("answer", "")))
    except Exception:
        # Best effort by design: model loading or inference may fail for many
        # reasons, and classification must still work. Log it rather than
        # hiding it, then fall back to keywords.
        logging.getLogger(__name__).warning(
            "QA model unavailable; classifying recipe by keywords only",
            exc_info=True,
        )
        return ""


def classify_recipe(recipe_text: str) -> RecipeType:
    """Classify a recipe as "baked" or "cooked".

    The QA model is asked first. If its answer contains a bake stem the
    result is "baked"; otherwise, if it contains a cook stem, the result is
    "cooked". When the answer is empty or inconclusive, the recipe is scored
    by how many entries of BAKE_KEYWORDS and COOK_KEYWORDS occur in the
    normalized text.

    Args:
        recipe_text: Raw recipe text. It is passed unmodified to the QA model
            and through ``normalize_text`` for keyword scoring.

    Returns:
        "baked" if the model answer or a strictly higher bake score says so,
        otherwise "cooked" (ties and empty input default to "cooked").
    """
    text = normalize_text(recipe_text)

    answer = _ask_qa_model(recipe_text)
    if any(sig in answer for sig in _BAKE_ANSWER_SIGNALS):
        return "baked"
    if any(sig in answer for sig in _COOK_ANSWER_SIGNALS):
        return "cooked"

    # NOTE(review): keywords are matched as plain substrings, so e.g. "tart"
    # also matches inside "start" and "pie" inside "piece". Left unchanged to
    # keep compound words such as "cupcake" matching; revisit if scores drift.
    bake_score = sum(1 for kw in BAKE_KEYWORDS if kw in text)
    cook_score = sum(1 for kw in COOK_KEYWORDS if kw in text)

    # Ties default to "cooked".
    return "baked" if bake_score > cook_score else "cooked"
|
|