from __future__ import annotations from functools import lru_cache from typing import List import torch from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline from .config import settings from .text_utils import dedupe_preserve_order, strip_amounts_and_preps, tokenize_recipe_segments @lru_cache(maxsize=1) def get_ner_pipeline(): tokenizer = AutoTokenizer.from_pretrained(settings.ner_model_name) model = AutoModelForTokenClassification.from_pretrained(settings.ner_model_name) device = 0 if torch.cuda.is_available() else -1 return pipeline( "token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple", device=device, ) def _merge_adjacent_entities(entities: List[dict]) -> List[dict]: merged: List[dict] = [] for ent in entities: start = int(ent.get("start", 0)) end = int(ent.get("end", 0)) label = str(ent.get("entity_group", "")) if not label or end <= start: continue if merged and merged[-1]["label"] == label and -1 <= start - merged[-1]["end"] <= 1: merged[-1]["end"] = end continue merged.append({"start": start, "end": end, "label": label}) return merged def _extract_spans(text: str) -> List[str]: try: ner = get_ner_pipeline() entities = ner(text) except Exception: return [] merged = _merge_adjacent_entities(entities) spans: List[str] = [] for ent in merged: span = text[ent["start"]:ent["end"]].strip(" ,.;:\n\t") span = strip_amounts_and_preps(span) if len(span) >= 3: spans.append(span) return dedupe_preserve_order(spans) def extract_ingredients(text: str, max_items: int = 48) -> List[str]: """Extract ingredient-like tokens from a recipe. We keep the MVP behaviour simple: split the recipe into comma-separated chunks, run RoBERTa NER per chunk, and fall back to the cleaned chunk when the model misses it. """ chunks = tokenize_recipe_segments(text) if not chunks: chunks = [strip_amounts_and_preps(text)] extracted: List[str] = [] for chunk in chunks: if not chunk: continue spans = _extract_spans(chunk) if spans: extracted.extend(spans) else: cleaned = strip_amounts_and_preps(chunk) if cleaned and len(cleaned) >= 2: extracted.append(cleaned) # Keep only meaningful spans; short fragments tend to be false positives. cleaned = [] for item in dedupe_preserve_order(extracted): if len(item) < 3: continue if item in {"and", "or", "the", "a", "an"}: continue cleaned.append(item) return cleaned[:max_items]