| from __future__ import annotations |
|
|
| from functools import lru_cache |
| from typing import List |
|
|
| import torch |
| from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline |
|
|
| from .config import settings |
| from .text_utils import dedupe_preserve_order, strip_amounts_and_preps, tokenize_recipe_segments |
|
|
|
|
@lru_cache(maxsize=1)
def get_ner_pipeline():
    """Return the shared token-classification pipeline, building it on first use.

    Both the tokenizer and the model are loaded from ``settings.ner_model_name``.
    Inference is placed on CUDA device 0 when ``torch.cuda.is_available()``
    reports a GPU, otherwise on the CPU (device ``-1``).  Entities are grouped
    with the ``"simple"`` aggregation strategy; callers in this module read
    ``entity_group``, ``start`` and ``end`` from each result.

    ``lru_cache(maxsize=1)`` keeps the single built pipeline for the life of
    the process.  Exceptions are not cached, so a failed load is attempted
    again on the next call.
    """
    model_id = settings.ner_model_name
    tok = AutoTokenizer.from_pretrained(model_id)
    token_model = AutoModelForTokenClassification.from_pretrained(model_id)

    target_device = -1  # CPU unless a CUDA device is present
    if torch.cuda.is_available():
        target_device = 0

    return pipeline(
        "token-classification",
        model=token_model,
        tokenizer=tok,
        aggregation_strategy="simple",
        device=target_device,
    )
|
|
|
|
| def _merge_adjacent_entities(entities: List[dict]) -> List[dict]: |
| merged: List[dict] = [] |
| for ent in entities: |
| start = int(ent.get("start", 0)) |
| end = int(ent.get("end", 0)) |
| label = str(ent.get("entity_group", "")) |
| if not label or end <= start: |
| continue |
|
|
| if merged and merged[-1]["label"] == label and -1 <= start - merged[-1]["end"] <= 1: |
| merged[-1]["end"] = end |
| continue |
|
|
| merged.append({"start": start, "end": end, "label": label}) |
| return merged |
|
|
|
|
def _extract_spans(text: str) -> List[str]:
    """Run the NER pipeline on *text* and return cleaned, de-duplicated spans.

    Adjacent same-label entities are merged with ``_merge_adjacent_entities``.
    Each merged span is then:

    * trimmed of surrounding spaces, ``,.;:``, newlines and tabs;
    * passed through ``strip_amounts_and_preps``;
    * kept only if the result is at least 3 characters long.

    Returns:
        Spans in first-seen order.  Returns an empty list if loading or running
        the pipeline raises, so the caller can fall back to its heuristic.  The
        failure is logged rather than swallowed silently.
    """
    import logging

    try:
        entities = get_ner_pipeline()(text)
    except Exception:
        # Best-effort by design, but make a broken model/config visible.
        # get_ner_pipeline's lru_cache stores only successful builds, so a
        # load failure is retried (and logged) on every call.
        logging.getLogger(__name__).warning(
            "NER extraction failed; falling back to heuristic extraction",
            exc_info=True,
        )
        return []

    spans: List[str] = []
    for ent in _merge_adjacent_entities(entities):
        raw = text[ent["start"]:ent["end"]]
        span = strip_amounts_and_preps(raw.strip(" ,.;:\n\t"))
        if len(span) >= 3:
            spans.append(span)
    return dedupe_preserve_order(spans)
|
|
|
|
def extract_ingredients(text: str, max_items: int = 48) -> List[str]:
    """Extract ingredient-like strings from free-form recipe text.

    The MVP behaviour is deliberately simple:

    1. Split the recipe into segments with ``tokenize_recipe_segments``.  If
       that yields nothing, the whole text cleaned by
       ``strip_amounts_and_preps`` is used as the single segment.
    2. Run the NER pipeline (model chosen by ``settings.ner_model_name``) on
       each segment.
    3. When the model yields no usable span for a segment, fall back to the
       segment cleaned by ``strip_amounts_and_preps``.

    Args:
        text: Raw recipe text.
        max_items: Maximum number of items to return.  Negative values are
            treated as 0; ``None`` means no limit.

    Returns:
        De-duplicated items in first-seen order.  Each item is at least 3
        characters long and is not a bare article or conjunction
        (case-insensitive).  The list is truncated to ``max_items``.
    """
    segments = tokenize_recipe_segments(text)
    if not segments:
        segments = [strip_amounts_and_preps(text)]

    candidates: List[str] = []
    for segment in segments:
        if not segment:
            continue

        spans = _extract_spans(segment)
        if spans:
            candidates.extend(spans)
            continue

        # Model found nothing usable: keep the heuristically cleaned segment.
        fallback = strip_amounts_and_preps(segment)
        if fallback and len(fallback) >= 2:
            candidates.append(fallback)

    # Compare lower-cased so "And"/"The" are filtered like "and"/"the".
    stopwords = {"and", "or", "the", "a", "an"}
    results = [
        item
        for item in dedupe_preserve_order(candidates)
        if len(item) >= 3 and item.lower() not in stopwords
    ]

    # A negative limit would otherwise slice from the end (e.g. [:-1]).
    limit = None if max_items is None else max(0, max_items)
    return results[:limit]
|
|