bitewise / services /ner.py
anaygupta's picture
Upload 22 files
df8f88e verified
from __future__ import annotations
from functools import lru_cache
from typing import List
import torch
from transformers import AutoModelForTokenClassification, AutoTokenizer, pipeline
from .config import settings
from .text_utils import dedupe_preserve_order, strip_amounts_and_preps, tokenize_recipe_segments
@lru_cache(maxsize=1)
def get_ner_pipeline():
    """Build the token-classification pipeline, reusing it across calls.

    The tokenizer and model are both loaded from ``settings.ner_model_name``.
    Because the function takes no arguments and is wrapped in
    ``lru_cache(maxsize=1)``, a successfully built pipeline is returned
    unchanged on every later call instead of being rebuilt.

    Returns:
        A ``transformers`` pipeline for the ``"token-classification"`` task
        using ``aggregation_strategy="simple"``.
    """
    model_name = settings.ner_model_name
    ner_tokenizer = AutoTokenizer.from_pretrained(model_name)
    ner_model = AutoModelForTokenClassification.from_pretrained(model_name)

    # Device index 0 when CUDA is available, otherwise -1 (the value the
    # original passes for the non-CUDA case).
    if torch.cuda.is_available():
        device_index = 0
    else:
        device_index = -1

    pipeline_kwargs = {
        "model": ner_model,
        "tokenizer": ner_tokenizer,
        "aggregation_strategy": "simple",
        "device": device_index,
    }
    return pipeline("token-classification", **pipeline_kwargs)
def _merge_adjacent_entities(entities: List[dict]) -> List[dict]:
merged: List[dict] = []
for ent in entities:
start = int(ent.get("start", 0))
end = int(ent.get("end", 0))
label = str(ent.get("entity_group", ""))
if not label or end <= start:
continue
if merged and merged[-1]["label"] == label and -1 <= start - merged[-1]["end"] <= 1:
merged[-1]["end"] = end
continue
merged.append({"start": start, "end": end, "label": label})
return merged
def _extract_spans(text: str) -> List[str]:
    """Run NER over ``text`` and return the cleaned, de-duplicated entity spans.

    The pipeline from ``get_ner_pipeline()`` is applied to ``text``; the
    resulting entities are merged with ``_merge_adjacent_entities``, sliced
    out of ``text``, trimmed of surrounding punctuation/whitespace, passed
    through ``strip_amounts_and_preps``, and kept only when at least three
    characters long.

    Args:
        text: The text segment to analyse.

    Returns:
        The surviving spans after ``dedupe_preserve_order``. An empty list is
        returned when building or running the pipeline raises; the failure is
        logged instead of propagated so callers can fall back.
    """
    import logging  # local import keeps this block self-contained

    try:
        ner = get_ner_pipeline()
        entities = ner(text)
    except Exception:
        # Deliberate best-effort: callers treat "no spans" as a signal to use
        # their own fallback. Log the traceback so a broken model load or
        # inference error is visible instead of silently disappearing.
        logging.getLogger(__name__).warning(
            "NER inference failed; returning no spans", exc_info=True
        )
        return []

    spans: List[str] = []
    for ent in _merge_adjacent_entities(entities):
        span = text[ent["start"]:ent["end"]].strip(" ,.;:\n\t")
        span = strip_amounts_and_preps(span)
        # Fragments shorter than three characters are discarded.
        if len(span) >= 3:
            spans.append(span)
    return dedupe_preserve_order(spans)
def extract_ingredients(text: str, max_items: int = 48) -> List[str]:
    """Return up to ``max_items`` ingredient-like strings found in a recipe.

    The recipe is split with ``tokenize_recipe_segments``; if that yields
    nothing, the whole text (after ``strip_amounts_and_preps``) is used as a
    single segment. Each non-empty segment is run through ``_extract_spans``.
    When that returns no spans, the segment itself is cleaned with
    ``strip_amounts_and_preps`` and used as a fallback candidate if it is at
    least two characters long.

    Args:
        text: Raw recipe text.
        max_items: Upper bound applied by slicing the final list.

    Returns:
        De-duplicated candidates in first-seen order, excluding items shorter
        than three characters and a handful of filler words.
    """
    segments = tokenize_recipe_segments(text) or [strip_amounts_and_preps(text)]

    candidates: List[str] = []
    for segment in segments:
        if not segment:
            continue
        found = _extract_spans(segment)
        if found:
            candidates.extend(found)
            continue
        # No spans for this segment: fall back to the cleaned segment text.
        fallback = strip_amounts_and_preps(segment)
        if fallback and len(fallback) >= 2:
            candidates.append(fallback)

    # Keep only meaningful entries; short fragments tend to be false positives.
    filler_words = {"and", "or", "the", "a", "an"}
    kept = [
        candidate
        for candidate in dedupe_preserve_order(candidates)
        if len(candidate) >= 3 and candidate not in filler_words
    ]
    return kept[:max_items]