# extractors/supplier_extractor.py
"""
Экстрактор поставщиков из текста.
Использует комбинацию методов:
- TF-IDF для символьных n-грамм
- Фонетическое сравнение
- Выравнивание токенов
- Расстояние Левенштейна
"""
from __future__ import annotations

import re
import unicodedata
from typing import Any

import iuliia
from rapidfuzz import fuzz
from rapidfuzz.distance import Levenshtein
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from extractors.date_extractor import UniversalDateParser


def normalize_text(text: str) -> str:
    """Normalize text: lowercase, strip diacritics and punctuation."""
    text = unicodedata.normalize("NFKD", text.lower())
    text = "".join(ch for ch in text if not unicodedata.combining(ch))
    return re.sub(r"[^\w\s]", "", text).strip()
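
# For instance (an illustrative check, not part of the module's tests):
#     normalize_text("Café «Мёд», 7!")  ->  "cafe мед 7"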


def variants(text: str) -> list[str]:
    """Generate text variants (transliterations plus a cleaned core)."""
    base = normalize_text(text)
    result = [base]
    for schema in (iuliia.WIKIPEDIA, iuliia.MOSMETRO, iuliia.ALA_LC):
        try:
            v = normalize_text(schema.translate(base))
            if v and v not in result:
                result.append(v)
        except Exception:
            pass
    for v in list(result):
        core = " ".join(w for w in v.split() if len(w) > 1 and any(ch.isalpha() for ch in w))
        core = normalize_text(core)
        if core and core not in result:
            result.insert(0, core)
    return result
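
# Illustrative output (the exact list depends on the installed iuliia schemas):
#     variants("Метро")  ->  ["метро", "metro"]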


def token_alignment_score(phrase_variant: str, candidate_tokens: list[str]) -> float:
    """Average best-match similarity between phrase tokens and candidate tokens."""
    phrase_tokens = [t for t in phrase_variant.split() if len(t) > 2]
    if not phrase_tokens or not candidate_tokens:
        return 0.0
    best_scores = []
    for pt in phrase_tokens:
        best = 0.0
        for ct in candidate_tokens:
            sim = Levenshtein.normalized_similarity(pt, ct)
            if sim > best:
                best = sim
        best_scores.append(best)
    return sum(best_scores) / len(best_scores)
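
# Each phrase token (longer than two characters) is paired with its closest
# candidate token, then the per-token similarities are averaged:
#     token_alignment_score("metro cash", ["metro", "cash", "carry"])  ->  1.0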


def length_penalty(phrase_len: int, candidate_len: int) -> float:
    """Penalty for the length difference between phrase and candidate."""
    if phrase_len == 0 or candidate_len == 0:
        return 0.0
    ratio = min(phrase_len, candidate_len) / max(phrase_len, candidate_len)
    if ratio >= 0.80:
        return 1.0
    if ratio >= 0.60:
        return 0.90
    if ratio >= 0.40:
        return 0.70
    return 0.50
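
# The tiers in practice: length_penalty(8, 10) -> 1.0 (ratio 0.80),
# length_penalty(6, 10) -> 0.90, length_penalty(4, 10) -> 0.70, and any
# ratio below 0.40 is floored at 0.50.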


def canonicalize_for_similarity(text: str) -> str:
    """Canonicalize text for phonetic comparison (collapse common digraphs)."""
    t = normalize_text(text).replace(" ", "")
    replacements = (
        ("sch", "sh"),
        ("tch", "ch"),
        ("dzh", "j"),
        ("zh", "j"),
        ("sh", "s"),
        ("ch", "c"),
        ("kh", "h"),
        ("ph", "f"),
        ("ck", "k"),
        ("qu", "k"),
        ("q", "k"),
        ("w", "v"),
        ("x", "ks"),
        ("ts", "z"),
        ("tz", "z"),
    )
    for src, dst in replacements:
        t = t.replace(src, dst)
    return re.sub(r"(.)\1+", r"\1", t)
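
# After canonicalization, common transliteration spellings collapse to the same
# key (replacements apply in order, then repeated characters are squeezed):
#     canonicalize_for_similarity("photograph")  ->  "fotograf"
#     canonicalize_for_similarity("pizza")       ->  "piza"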


def phonetic_similarity(left: str, right: str) -> float:
    """Phonetic similarity: mean of char ratio and Levenshtein over canonical forms."""
    l = canonicalize_for_similarity(left)
    r = canonicalize_for_similarity(right)
    if not l or not r:
        return 0.0
    char = fuzz.ratio(l, r) / 100.0
    lev = Levenshtein.normalized_similarity(l, r)
    return 0.50 * char + 0.50 * lev
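
# Identical canonical spellings therefore score a full 1.0:
#     phonetic_similarity("photograph", "fotograf")  ->  1.0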


class ExpenseSupplierExtractor:
    """
    Supplier extractor for free-form text.

    Finds the most similar supplier from a list of known ones.
    """

    def __init__(self, suppliers: list[str]) -> None:
        self.suppliers = suppliers
        self.sup_norm = [normalize_text(s) for s in suppliers]
        self.sup_tokens = [s.split() for s in self.sup_norm]
        self.sup_num_sets = [self.numeric_tokens(s) for s in self.sup_norm]
        self.sup_number_tokens = {num for nums in self.sup_num_sets for num in nums}
        self.supplier_lexicon = [
            token
            for token in sorted({tok for tokens in self.sup_tokens for tok in tokens})
            if token and not token.isdigit()
        ]
        self.tfidf = TfidfVectorizer(analyzer="char_wb", ngram_range=(3, 5))
        self.sup_mat = self.tfidf.fit_transform(self.sup_norm)
        self.max_words = max(len(s.split()) for s in self.sup_norm)
        self.variant_cache: dict[str, list[str]] = {}
        self.lexical_token_cache: dict[str, float] = {}
        self.phrase_support_cache: dict[str, float] = {}
        self.noise_terms = {
            "для", "под", "над", "при", "без", "или",
            "купил", "купила", "купили", "покупка", "заказал", "заказала", "заказали",
            "оплатил", "оплатила", "оплатили", "заплатил", "заплатила", "заплатили",
            "был", "была", "было", "были", "утром", "днем", "днём", "вечером", "ночью",
            "товар", "товары", "продукт", "продукты", "десерт", "еда",
            "лей", "лея", "леи", "целых", "сотых", "сом", "сомов", "руб", "рублей", "грн", "usd", "eur",
        }
        self.noise_terms.update(UniversalDateParser.temporal_vocabulary())

    @staticmethod
    def numeric_tokens(text: str) -> set[str]:
        """Extract numeric tokens."""
        return set(re.findall(r"\d+", text))

    def cached_variants(self, text: str) -> list[str]:
        """Return variants for the text, memoized by normalized key."""
        key = normalize_text(text)
        cached = self.variant_cache.get(key)
        if cached is None:
            cached = variants(key)
            self.variant_cache[key] = cached
        return cached

    @staticmethod
    def split_words(text: str) -> list[str]:
        """Split text into normalized words."""
        return [w for w in normalize_text(text).split() if w]

    @classmethod
    def is_supplier_extension(cls, base_supplier: str, extended_supplier: str) -> bool:
        """Check whether one supplier name is a prefix extension of another."""
        base_tokens = cls.split_words(base_supplier)
        extended_tokens = cls.split_words(extended_supplier)
        return len(base_tokens) < len(extended_tokens) and extended_tokens[:len(base_tokens)] == base_tokens

    @classmethod
    def phrase_token_count(cls, phrase: str | None) -> int:
        """Count the tokens in a phrase."""
        return len(cls.split_words(phrase or ""))

    @classmethod
    def resolve_overlapping_suppliers(cls, ranking: list[dict[str, Any]]) -> dict[str, Any]:
        """Resolve conflicts between similar suppliers, preferring longer matches."""
        if not ranking:
            return {"supplier": None, "score": -1.0, "phrase": None}
        best = ranking[0]
        best_combined = float(best.get("combined", best.get("score", -1.0)))
        best_phrase_len = cls.phrase_token_count(best.get("phrase"))
        for alt in ranking[1:]:
            if not cls.is_supplier_extension(str(best.get("supplier") or ""), str(alt.get("supplier") or "")):
                continue
            alt_combined = float(alt.get("combined", alt.get("score", -1.0)))
            alt_phrase_len = cls.phrase_token_count(alt.get("phrase"))
            if alt_phrase_len > best_phrase_len and alt_combined >= best_combined - 0.15:
                best = alt
                best_combined = alt_combined
                best_phrase_len = alt_phrase_len
        return best

    @staticmethod
    def numeric_compatibility_multiplier(phrase_nums: set[str], candidate_nums: set[str]) -> float:
        """Compatibility multiplier for numeric tokens."""
        if not phrase_nums and not candidate_nums:
            return 1.0
        if phrase_nums == candidate_nums:
            return 1.08
        if phrase_nums and candidate_nums:
            return 1.03 if phrase_nums & candidate_nums else 0.80
        return 0.82
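
    # At a glance: identical digit sets boost the score (1.08), a partial
    # overlap boosts it slightly (1.03), disjoint non-empty sets dampen it
    # (0.80), digits on only one side dampen it mildly (0.82), and no digits
    # anywhere is neutral (1.0).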

    def lexical_support(self, phrase: str) -> float:
        """Average best similarity of phrase tokens against the supplier lexicon."""
        tokens = [token for token in normalize_text(phrase).split() if token and not token.isdigit()]
        if not tokens or not self.supplier_lexicon:
            return 0.0
        support_scores: list[float] = []
        for token in tokens:
            cached = self.lexical_token_cache.get(token)
            if cached is not None:
                support_scores.append(cached)
                continue
            best = 0.0
            for token_variant in self.cached_variants(token):
                for lex in self.supplier_lexicon:
                    lev = Levenshtein.normalized_similarity(token_variant, lex)
                    phon = phonetic_similarity(token_variant, lex)
                    sim = max(lev, phon)
                    if sim > best:
                        best = sim
            self.lexical_token_cache[token] = best
            support_scores.append(best)
        return sum(support_scores) / len(support_scores)

    def score_phrase(self, phrase: str) -> dict[str, Any]:
        """Score a phrase against all known suppliers."""
        vs = self.cached_variants(phrase)
        q = self.tfidf.transform(vs)
        tf = cosine_similarity(q, self.sup_mat)
        best: dict[str, Any] = {"supplier": None, "score": -1.0, "phrase": phrase, "variant": ""}
        for i, cand in enumerate(self.sup_norm):
            local = -1.0
            local_variant = ""
            candidate_nums = self.sup_num_sets[i]
            for j, v in enumerate(vs):
                char = fuzz.ratio(v, cand) / 100.0
                tf_val = float(tf[j, i])
                penalty = length_penalty(len(v), len(cand))
                phon = phonetic_similarity(v, cand)
                phrase_nums = self.numeric_tokens(v)
                if len(v.split()) == 1 and len(cand.split()) == 1:
                    # Single-word pair: lean on Levenshtein.
                    lev = Levenshtein.normalized_similarity(v, cand)
                    val = (0.45 * lev + 0.25 * char + 0.10 * tf_val + 0.20 * phon) * penalty
                else:
                    # Multi-word pair: add token alignment and token-set ratio.
                    align = token_alignment_score(v, self.sup_tokens[i])
                    tok = fuzz.token_set_ratio(v, cand) / 100.0
                    val = (0.30 * char + 0.20 * tok + 0.10 * tf_val + 0.20 * align + 0.20 * phon) * penalty
                # Space-insensitive rescue: also compare the compacted strings.
                compact_v = v.replace(" ", "")
                compact_cand = cand.replace(" ", "")
                compact_char = fuzz.ratio(compact_v, compact_cand) / 100.0
                compact_lev = Levenshtein.normalized_similarity(compact_v, compact_cand)
                compact_phon = phonetic_similarity(compact_v, compact_cand)
                compact = max(compact_char, compact_lev, compact_phon)
                if compact > 0.55:
                    val = max(val, compact * penalty)
                val *= self.numeric_compatibility_multiplier(phrase_nums, candidate_nums)
                if val > local:
                    local = val
                    local_variant = v
            if local > best["score"]:
                best = {"supplier": self.suppliers[i], "score": local, "phrase": phrase, "variant": local_variant}
        return best

    def extract(
        self,
        text: str,
        date_phrase: str | None = None,
        excluded_phrases: list[str] | None = None,
        debug: bool = False,
        score_threshold: float = 0.50,
        combined_threshold: float = 0.48,
    ) -> dict[str, Any]:
        """
        Extract a supplier from the text.

        Args:
            text: Text to analyze
            date_phrase: Date phrase to exclude
            excluded_phrases: Additional phrases to exclude
            debug: Include debug information
            score_threshold: Minimum raw score to accept a match
            combined_threshold: Minimum combined score to accept a match

        Returns:
            Dict with supplier, supplier_score, matched_supplier_phrase
        """
        excluded_tokens: set[str] = set()
        if date_phrase:
            excluded_tokens.update(normalize_text(date_phrase).split())
        if excluded_phrases:
            for phrase in excluded_phrases:
                if phrase:
                    excluded_tokens.update(normalize_text(phrase).split())
        excluded_tokens.update(self.noise_terms)
        raw_tokens = normalize_text(text).split()
        tokens: list[str] = []
        for token in raw_tokens:
            if token in excluded_tokens:
                continue
            if token.isdigit():
                # Keep digits only if some supplier name actually contains them.
                if token in self.sup_number_tokens:
                    tokens.append(token)
                continue
            if len(token) > 1:
                tokens.append(token)
        # Defensive re-filter (the loop above already enforces these constraints).
        tokens = [t for t in tokens if (len(t) > 1 or t.isdigit()) and t not in excluded_tokens]
        # Build every contiguous phrase up to the longest supplier name length.
        phrases: list[str] = []
        seen: set[str] = set()
        for i in range(len(tokens)):
            for j in range(i + 1, min(i + 1 + self.max_words, len(tokens) + 1)):
                p = " ".join(tokens[i:j])
                if p not in seen:
                    seen.add(p)
                    phrases.append(p)
        results = [self.score_phrase(p) for p in phrases]
        candidate_rows: list[dict[str, Any]] = []
        best_by_supplier: dict[str, dict[str, Any]] = {}
        for row in results:
            supplier = row["supplier"]
            score = float(row.get("score", -1.0))
            phrase = str(row.get("phrase") or "")
            support = self.phrase_support_cache.get(phrase)
            if support is None:
                support = self.lexical_support(phrase)
                self.phrase_support_cache[phrase] = support
            combined = 0.75 * score + 0.25 * support
            if debug:
                candidate_rows.append({
                    "supplier": supplier,
                    "phrase": phrase,
                    "score": round(score, 4),
                    "support": round(support, 4),
                    "combined": round(combined, 4),
                })
            enriched = {**row, "combined": combined}
            passes = score >= score_threshold or combined >= combined_threshold
            if passes and (supplier not in best_by_supplier or combined > float(best_by_supplier[supplier].get("combined", -1.0))):
                best_by_supplier[supplier] = enriched
        if not best_by_supplier and results:
            # Fallback: accept the best overall candidate under relaxed thresholds.
            def support_for_phrase(phrase: str) -> float:
                cached_support = self.phrase_support_cache.get(phrase)
                if cached_support is None:
                    cached_support = self.lexical_support(phrase)
                    self.phrase_support_cache[phrase] = cached_support
                return cached_support

            fallback = max(
                results,
                key=lambda item: 0.75 * float(item.get("score", -1.0)) + 0.25 * support_for_phrase(str(item.get("phrase") or "")),
            )
            fallback_score = float(fallback.get("score", -1.0))
            fallback_phrase = str(fallback.get("phrase") or "")
            fallback_support = support_for_phrase(fallback_phrase)
            fallback_combined = 0.75 * fallback_score + 0.25 * fallback_support
            if fallback_score >= 0.40 and fallback_support >= 0.43 and fallback_combined >= 0.43:
                best_by_supplier[fallback["supplier"]] = {**fallback, "combined": fallback_combined}
        supplier_ranking = sorted(best_by_supplier.values(), key=lambda x: float(x.get("combined", x["score"])), reverse=True)
        best = self.resolve_overlapping_suppliers(supplier_ranking)
        payload = {
            "supplier": best["supplier"],
            "supplier_score": round(best["score"], 4) if best["score"] >= 0 else None,
            "matched_supplier_phrase": best.get("phrase"),
        }
        if debug:
            top_candidates = sorted(candidate_rows, key=lambda item: item["combined"], reverse=True)[:8]
            payload["supplier_debug"] = {
                "tokens": tokens,
                "phrases_count": len(phrases),
                "excluded_tokens": sorted(excluded_tokens)[:80],
                "score_threshold": score_threshold,
                "combined_threshold": combined_threshold,
                "top_candidates": top_candidates,
            }
        return payload
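

if __name__ == "__main__":
    # Minimal smoke test. The supplier names below are hypothetical examples,
    # not the real dataset; running this also assumes the sibling
    # extractors.date_extractor module is importable (e.g. from the repo root).
    extractor = ExpenseSupplierExtractor(["Metro Cash and Carry", "Linella", "Kaufland"])
    # "yesterday I bought groceries at metro cash and carry for 250 lei"
    result = extractor.extract("вчера купил продукты в метро кеш енд керри за 250 лей")
    # Expected shape: {"supplier": ..., "supplier_score": ..., "matched_supplier_phrase": ...};
    # with this input the match should typically resolve to "Metro Cash and Carry".
    print(result)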