Spaces:
Runtime error
Runtime error
| # modules/tagger_refinement.py | |
| import os | |
| import re | |
| import csv | |
| import joblib | |
| import numpy as np | |
| import pathlib | |
| import logging | |
| from collections import Counter, OrderedDict | |
| from lark import Lark, Token | |
| from lark.exceptions import ParseError | |
| import hnswlib | |
| import compress_fasttext | |
| from scipy.sparse import csr_matrix | |
# --- LOGGING SETUP (silence noisy libraries) ---
# Root logger shows warnings and above; the two third-party loggers listed
# below are further restricted to errors only.
logging.basicConfig(level=logging.WARNING)
for _name in ["gensim", "hnswlib"]:
    logging.getLogger(_name).setLevel(logging.ERROR)
# --- File path constants (data and models) ---
# All paths are relative, i.e. resolved against the process working directory.
# joblib bundle with the TF-IDF / SVD components (see load_tf_idf_components).
PATH_TF_IDF: str = "models/binaries/tf_idf_files_420.joblib"
# Prebuilt hnswlib index over the reduced TF-IDF tag vectors.
PATH_HNSW_TAGS: str = "models/binaries/tfidf_hnsw_tags.bin"
# Compressed FastText vectors used to find replacements for unknown tags.
PATH_FASTTEXT: str = "models/binaries/e621FastTextModel010Replacement_small.bin"
# CSV: tag, probability (tags with probability >= 0.95 are treated as NSFW).
PATH_DB_RATING: str = "data/databases/word_rating_probabilities.csv"
# CSV: col 0 tag name, col 2 count, col 3 comma-separated aliases or "null".
PATH_DB_FLUFFYROCK: str = "data/databases/fluffyrock_3m.csv"
# --- GLOBAL MODEL VARIABLES ---
# Lazily-populated module state. Every slot starts as None and is filled on
# first use by the loader functions in this module.
# HNSW tag index, its vector dimensionality, and the element count used to cap k.
_HNSW_TAG = _HNSW_DIM = _HNSW_N_TAG = None
# TF-IDF component dict, compressed FastText vectors, tag alias/count tables.
_TF_IDF_COMPONENTS = _FASTTEXT_MODEL = _TAG_DATA_CACHE = None
# Set of NSFW tags and set of artist names (``by_`` prefix removed).
_NSFW_TAGS = _ARTIST_SET = None
# --- PARSER (Lark Grammar) ---
# Grammar for comma-separated prompt strings.
# - `plain` is one tag: a run of any characters except , \ [ ] ( ) : |
#   (spaces included), where a backslash escapes the next character, e.g. `\(`.
# - `emphasized` accepts "(prompt)" and "(prompt:NUMBER)" weight syntax.
# - `start` also accepts runs of stray [ ] ( ) : characters at top level,
#   presumably so that unbalanced brackets in user prompts still parse.
# NOTE: a literal `|` is matched by no terminal, so input containing it fails
# to parse (callers catch ParseError).
grammar = r"""
!start: (prompt | /[][():]/+)*
prompt: (emphasized | plain | comma | WHITESPACE)*
!emphasized: "(" prompt ")"
| "(" prompt ":" [WHITESPACE] NUMBER [WHITESPACE] ")"
comma: ","
WHITESPACE: /\s+/
plain: /([^,\\\[\]():|]|\\.)+/
%import common.SIGNED_NUMBER -> NUMBER
"""
# Module-level parser, built once at import time.
parser = Lark(grammar, start='start')
def load_tf_idf_components():
    """Lazily load the TF-IDF bundle, the NSFW tag set and the artist set.

    Each of these module globals is populated at most once:

    * ``_TF_IDF_COMPONENTS`` -- the dict stored in ``PATH_TF_IDF`` (loaded with
      ``joblib.load``), or ``{}`` if the file is missing or cannot be loaded.
      A ``row_to_tag`` reverse map is derived from ``tag_to_row_index`` when
      present, and a dict-valued ``idf`` (term -> weight) is converted into a
      float32 array indexed by column (terms without an IDF default to 1.0).
    * ``_NSFW_TAGS`` -- tags from ``PATH_DB_RATING`` whose second column is
      >= 0.95. The first CSV row is skipped as a header.
    * ``_ARTIST_SET`` -- names of ``by_*`` tags in ``PATH_DB_FLUFFYROCK`` with
      the ``by_`` prefix removed.

    Malformed CSV rows (blank lines, too few columns, non-numeric probability)
    are skipped instead of aborting the load.
    """
    global _TF_IDF_COMPONENTS, _NSFW_TAGS, _ARTIST_SET
    logger = logging.getLogger(__name__)

    if _TF_IDF_COMPONENTS is None:
        components = {}
        if os.path.exists(PATH_TF_IDF):
            try:
                # joblib.load unpickles arbitrary objects: PATH_TF_IDF must only
                # ever point at a trusted, locally shipped file.
                components = joblib.load(PATH_TF_IDF)
                if 'tag_to_row_index' in components:
                    components['row_to_tag'] = {
                        idx: tag for tag, idx in components['tag_to_row_index'].items()
                    }
                idf = components['idf']
                if isinstance(idf, dict):
                    # Re-index the IDF weights by matrix column for fast lookup.
                    t2c = components['tag_to_column_index']
                    idf_by_col = np.ones(max(t2c.values()) + 1, dtype=np.float32)
                    for term, col in t2c.items():
                        idf_by_col[col] = float(idf.get(term, 1.0))
                    components['idf'] = idf_by_col
            except Exception:
                logger.warning("Could not load TF-IDF components from %s", PATH_TF_IDF, exc_info=True)
                components = {}
        _TF_IDF_COMPONENTS = components

    if _NSFW_TAGS is None:
        _NSFW_TAGS = set()
        if os.path.exists(PATH_DB_RATING):
            with open(PATH_DB_RATING, 'r', newline='', encoding='utf-8') as csvfile:
                reader = csv.reader(csvfile)
                next(reader, None)  # first row is treated as a header
                for row in reader:
                    try:
                        probability = float(row[1])
                    except (IndexError, ValueError):
                        continue  # blank or malformed row
                    if probability >= 0.95:
                        _NSFW_TAGS.add(row[0])

    if _ARTIST_SET is None:
        _ARTIST_SET = set()
        if os.path.exists(PATH_DB_FLUFFYROCK):
            with open(PATH_DB_FLUFFYROCK, 'r', newline='', encoding='utf-8') as csvfile:
                for row in csv.reader(csvfile):
                    # csv.reader yields [] for blank lines; guard before row[0].
                    if row and row[0].startswith('by_'):
                        _ARTIST_SET.add(row[0][3:])
def get_tag_data_cache():
    """Build (once) and return tag lookup tables from ``PATH_DB_FLUFFYROCK``.

    Returns None if the CSV does not exist. Otherwise returns a cached dict:

    * ``'tag2aliases'`` -- tag (non-ASCII characters removed) -> list of its
      aliases (column 3, comma-separated, non-ASCII removed; ``"null"`` or a
      missing column means no aliases).
    * ``'alias2tags'`` -- alias -> list of tags declaring it, in file order.
    * ``'tag2count'`` -- raw tag (column 0) -> int count (column 2), for rows
      whose count column consists only of digits.

    The CSV is read in a single pass. Blank lines are skipped and short rows
    no longer raise IndexError.
    """
    global _TAG_DATA_CACHE
    if _TAG_DATA_CACHE is not None:
        return _TAG_DATA_CACHE
    if not os.path.exists(PATH_DB_FLUFFYROCK):
        return None

    def _ascii_only(text):
        # Drop every non-ASCII character.
        return ''.join(c for c in text if ord(c) < 128)

    tag2aliases, alias2tags, tag2count = {}, {}, {}
    with open(PATH_DB_FLUFFYROCK, 'r', newline='', encoding='utf-8') as f:
        for row in csv.reader(f):
            if not row:
                continue  # blank line
            raw_tag = row[0]
            if len(row) > 2 and row[2].isdigit():
                tag2count[raw_tag] = int(row[2])
            alias_field = row[3] if len(row) > 3 else "null"
            tag = _ascii_only(raw_tag)
            aliases = [] if alias_field == "null" else [_ascii_only(a) for a in alias_field.split(',')]
            tag2aliases[tag] = aliases
            for alias in aliases:
                alias2tags.setdefault(alias, []).append(tag)

    _TAG_DATA_CACHE = {'tag2aliases': tag2aliases, 'alias2tags': alias2tags, 'tag2count': tag2count}
    return _TAG_DATA_CACHE
def is_artist(name):
    """Return True if *name* is a known artist.

    An artist is known when ``by_<name>`` appears as a tag in the tag database.
    Triggers the lazy load of the module data on first use.
    """
    load_tf_idf_components()
    known_artists = _ARTIST_SET
    return name in known_artists
def _l2_normalize_rows(mat):
    """Return a float32 copy of the 2-D *mat* with every row scaled to unit L2 norm.

    Rows whose norm is exactly zero are divided by 1.0 instead. They stay
    all-zero rather than becoming NaN.
    """
    arr = np.asarray(mat, dtype=np.float32)
    row_norms = np.linalg.norm(arr, axis=1, keepdims=True)
    safe_norms = np.where(row_norms == 0.0, np.float32(1.0), row_norms)
    return arr / safe_norms
def _ensure_hnsw_indexes():
    """Lazily load the prebuilt HNSW tag index into module globals.

    On success this sets:

    * ``_HNSW_TAG`` -- the loaded ``hnswlib.Index`` (cosine space).
    * ``_HNSW_DIM`` -- vector dimensionality, taken from the TF-IDF
      ``reduced_matrix``.
    * ``_HNSW_N_TAG`` -- the number of elements actually stored in the index.
      It is used to cap ``k`` in queries.

    Does nothing if the index is already loaded or the TF-IDF components,
    reduced matrix or index file are missing. A load failure is logged and
    leaves ``_HNSW_TAG`` as None.
    """
    global _HNSW_TAG, _HNSW_DIM, _HNSW_N_TAG
    if _HNSW_TAG is not None:
        return
    load_tf_idf_components()
    if not _TF_IDF_COMPONENTS:
        return
    reduced_matrix = _TF_IDF_COMPONENTS.get('reduced_matrix')
    if reduced_matrix is None:
        return
    shape = np.shape(reduced_matrix)  # only the shape is needed here, no copy
    if len(shape) != 2:
        return
    n_items, dim = shape
    _HNSW_DIM = dim

    tag_path = pathlib.Path(PATH_HNSW_TAGS)
    if not tag_path.exists():
        # Checked before the O(n) scan below, so repeated calls stay cheap
        # when no index is shipped.
        return

    # Count the non-artist rows (plus two placeholder artist tags). These are
    # presumably the rows the index was built from. The count is only used as
    # a capacity hint for load_index.
    row_to_tag = _TF_IDF_COMPONENTS.get('row_to_tag') or {}
    artists = _ARTIST_SET or set()
    always_keep = {"by_unknown_artist", "by_conditional_dnp"}
    n_tag_rows = 0
    for i in range(n_items):
        tag = row_to_tag.get(i, "")
        base = tag[3:] if tag.startswith("by_") else tag
        if tag in always_keep or base not in artists:
            n_tag_rows += 1

    try:
        idx = hnswlib.Index(space='cosine', dim=dim)
        idx.load_index(str(tag_path), max_elements=max(1, n_tag_rows))
        idx.set_ef(200)
    except Exception:
        logging.getLogger(__name__).warning("Could not load HNSW index from %s", tag_path, exc_info=True)
        return
    _HNSW_TAG = idx
    # Use the real stored count, not the row estimate. Asking for more
    # neighbours than the index holds makes knn_query fail.
    _HNSW_N_TAG = idx.get_current_count()
def _hnsw_query(vec, k=2000):
    """Return the nearest tag rows to ``vec`` in the HNSW tag index.

    ``vec`` is flattened and L2-normalised before querying. ``k`` is capped at
    the index size. Returns ``(labels, similarities)``, where similarity is
    computed as ``1 - distance`` from the cosine-space index. Returns
    ``([], [])`` when the index is unavailable or empty, or the query fails.
    """
    _ensure_hnsw_indexes()
    if _HNSW_TAG is None:
        return [], []
    k = min(k, _HNSW_N_TAG or 0)
    if k <= 0:
        return [], []
    q = np.asarray(vec, dtype=np.float32).reshape(-1)
    q_norm = np.linalg.norm(q)
    if q_norm > 0:
        q = q / q_norm
    try:
        labels, dists = _HNSW_TAG.knn_query(q, k=k)
    except RuntimeError:
        # hnswlib raises RuntimeError e.g. for a dimensionality mismatch or when
        # it cannot return k results. Degrade to "no context" instead of
        # aborting the caller.
        logging.getLogger(__name__).warning("HNSW query failed", exc_info=True)
        return [], []
    return labels[0], 1.0 - dists[0]
# Meta tags that never reach the refinement step: score:0 .. score:9 and ratings.
special_tags = [f"score:{n}" for n in range(10)] + ["rating:s", "rating:q", "rating:e"]
# Quality/source tags specific to particular image models; dropped from output.
MODEL_SPECIFIC_TAGS = {
    "masterpiece", "best quality", "good quality", "normal quality",
    "low quality", "worst quality",
    "highres", "lowres", "absurdres",
    "source_pony", "source_furry",
    "rating_safe", "rating_explicit", "rating_questionable",
}


def remove_special_tags(original_string):
    """Drop score/rating meta tags from a comma-separated tag string.

    Every comma-separated piece is whitespace-stripped. Pieces equal to an
    entry of ``special_tags`` are removed and the rest are re-joined with
    ", ". Empty pieces are kept.
    """
    kept = []
    for piece in original_string.split(","):
        piece = piece.strip()
        if piece not in special_tags:
            kept.append(piece)
    return ", ".join(kept)
def extract_tags(tree):
    """Collect the plain-text tags of a parsed prompt, in source order.

    Returns a list of ``(tag_text, start_pos)`` tuples, one per token that is a
    direct child of a ``plain`` rule node. ``tag_text`` is the raw token value:
    escapes such as ``\\(`` and any surrounding spaces matched by the ``plain``
    regex are kept. ``start_pos`` is the token's offset in the parsed string.

    Matching on the parent rule name, rather than on Lark's auto-generated
    terminal name (``__ANON_1``), keeps this correct if the grammar gains or
    reorders anonymous regex terminals.
    """
    tags_with_positions = []

    def _traverse(node, parent_rule):
        if isinstance(node, Token):
            if parent_rule == 'plain':
                tags_with_positions.append((node.value, node.start_pos))
            return
        for child in node.children:
            _traverse(child, node.data)

    _traverse(tree, None)
    return tags_with_positions
def construct_pseudo_vector(pseudo_doc_terms):
    """Build a 1 x n_columns sparse TF-IDF row for an ad-hoc "document".

    ``pseudo_doc_terms`` maps term -> raw weight (e.g. a term count). Terms
    missing from ``tag_to_column_index`` are ignored. Every other weight is
    multiplied by that column's IDF. Returns None when the TF-IDF components
    are unavailable.
    """
    load_tf_idf_components()
    if not _TF_IDF_COMPONENTS:
        return None
    idf = _TF_IDF_COMPONENTS['idf']
    column_of = _TF_IDF_COMPONENTS['tag_to_column_index']
    matched = [
        (j, weight)
        for term, weight in pseudo_doc_terms.items()
        if (j := column_of.get(term)) is not None
    ]
    indices = [col for col, _ in matched]
    values = [weight * idf[col] for col, weight in matched]
    # Single-row CSR: the indptr [0, nnz] spans all stored entries.
    return csr_matrix(
        (values, indices, [0, len(indices)]),
        shape=(1, len(idf)),
        dtype=np.float32,
    )
def get_similar_tags_tfidf(pseudo_doc_terms):
    """Rank tags by TF-IDF/SVD co-occurrence similarity to a set of terms.

    ``pseudo_doc_terms`` (term -> weight) is turned into a TF-IDF row,
    projected with the stored ``svd_model`` and queried against the HNSW tag
    index. NSFW tags are excluded.

    Returns an OrderedDict of display name -> similarity, sorted by descending
    similarity. In display names ``_`` becomes a space and parentheses are
    backslash-escaped. If two tags map to the same display name, the higher
    similarity is kept. Returns ``{}`` when components are missing or none of
    the terms is in the vocabulary, since a zero vector carries no context.
    """
    load_tf_idf_components()
    if not _TF_IDF_COMPONENTS:
        return {}
    pseudo = construct_pseudo_vector(pseudo_doc_terms)
    if pseudo is None or pseudo.nnz == 0:
        return {}
    reduced = _TF_IDF_COMPONENTS['svd_model'].transform(pseudo)
    top_inds, top_sims = _hnsw_query(reduced)
    row_to_tag = _TF_IDF_COMPONENTS.get('row_to_tag') or {}
    nsfw = _NSFW_TAGS or set()

    sim_dict = {}
    for i, sim in zip(top_inds, top_sims):
        tag = row_to_tag.get(int(i))
        if tag and tag not in nsfw:
            sim_dict[tag] = float(sim)

    result = OrderedDict()
    for tag, sim in sorted(sim_dict.items(), key=lambda kv: kv[1], reverse=True):
        key = tag.replace('_', ' ').replace('(', '\\(').replace(')', '\\)')
        if key not in result:  # iterating high -> low, so keep the first (highest)
            result[key] = sim
    return result
def _load_fasttext_model():
    """Load the compressed FastText vectors once and return them, or None.

    A failed load is logged and leaves ``_FASTTEXT_MODEL`` as None, so a
    later call tries again.
    """
    global _FASTTEXT_MODEL
    if _FASTTEXT_MODEL is None and os.path.exists(PATH_FASTTEXT):
        try:
            _FASTTEXT_MODEL = compress_fasttext.models.CompressedFastTextKeyedVectors.load(PATH_FASTTEXT)
        except Exception:
            logging.getLogger(__name__).warning(
                "Could not load FastText model from %s", PATH_FASTTEXT, exc_info=True)
    return _FASTTEXT_MODEL


def _best_replacement(model, search, data_cache, suggested_context, context_weight):
    """Pick the best known tag to replace the unknown tag *search*, or None.

    Candidates are the 20 FastText neighbours of *search* that are canonical
    tags, plus the canonical tags of neighbours that are aliases. Each is
    scored as ``(1 - context_weight) * fasttext_sim + context_weight *
    tfidf_context_sim``. The top candidate is returned (underscores turned into
    spaces) only if its score exceeds 0.6.
    """
    try:
        similar_words = model.most_similar(search, topn=20)
    except KeyError:
        # No vector could be built for this word. Treat it as "no candidates"
        # instead of aborting the refinement of the whole prompt.
        return None
    tag2aliases = data_cache['tag2aliases']
    alias2tags = data_cache['alias2tags']

    candidates = []
    seen = set()
    for sim_word, sim in similar_words:
        if sim_word in seen:
            continue
        if sim_word in tag2aliases:
            candidates.append((sim_word.replace('_', ' '), sim))
            seen.add(sim_word)
        else:
            for target in alias2tags.get(sim_word, []):
                if target not in seen:
                    candidates.append((target.replace('_', ' '), sim))
                    seen.add(target)

    scored = []
    for word, sim in candidates:
        # Context keys escape parentheses, so try both spellings.
        escaped = word.replace('(', '\\(').replace(')', '\\)')
        ctx_score = float(suggested_context.get(word) or suggested_context.get(escaped) or 0.0)
        scored.append((word, (1.0 - context_weight) * sim + context_weight * ctx_score))
    if not scored:
        return None
    scored.sort(key=lambda x: x[1], reverse=True)
    best_word, best_score = scored[0]
    return best_word if best_score > 0.6 else None


def rafine_etiketler(tags_string, context_weight=0.5):
    """Refine a comma-separated prompt tag string against the known tag database.

    The input is lower-cased. ``score:N``/``rating:X`` tags and model-specific
    quality tags are removed, and emphasis/weight syntax such as
    ``(tag:1.2)`` is discarded (only the tag text is kept). Every remaining
    tag is handled as follows:

    * It is kept if it is an artist, a known tag or a known alias.
    * Otherwise it is replaced by the best FastText neighbour, or dropped when
      none scores above 0.6.
    * Without a FastText model, unknown tags are kept only when
      ``context_weight < 0.5``.

    Args:
        tags_string: Raw prompt, e.g. ``"cat, (blue eyes:1.2), dog"``.
        context_weight: Weight of the TF-IDF co-occurrence context versus
            FastText similarity when ranking replacements (presumably in [0, 1]).

    Returns:
        The refined tags joined with ", ". Returns ``""`` for blank input. The
        input is returned unchanged in these cases:

        * it cannot be parsed;
        * required resources are missing and ``context_weight < 0.01``;
        * an unexpected error occurs (the error is logged).
    """
    if not tags_string or not tags_string.strip():
        return ""
    load_tf_idf_components()
    data_cache = get_tag_data_cache()
    model = _load_fasttext_model()
    if (model is None or not data_cache or not _TF_IDF_COMPONENTS) and context_weight < 0.01:
        return tags_string
    try:
        input_str = remove_special_tags(tags_string.lower())
        try:
            parsed = parser.parse(input_str)
        except ParseError:
            return tags_string

        # (original tag, display form, TF-IDF vocabulary term) per tag.
        tag_data = []
        for tag_text, _start_pos in extract_tags(parsed):
            # `plain` tokens may carry surrounding spaces. Strip them so the
            # checks below and the ", " join behave.
            original = tag_text.strip()
            if not original:
                continue
            modified = tag_text.replace('_', ' ').replace('\\(', '(').replace('\\)', ')').strip()
            term = re.sub(r'\\([()])', r'\1', original.removeprefix('by ').removeprefix('by_').replace(' ', '_'))
            tag_data.append((original, modified, term))

        suggested_context = {}
        if _TF_IDF_COMPONENTS:
            # The HNSW index is loaded lazily. Make sure loading has been
            # attempted before checking for it, otherwise context is never used.
            _ensure_hnsw_indexes()
            if _HNSW_TAG is not None:
                suggested_context = get_similar_tags_tfidf(dict(Counter(t[2] for t in tag_data)))

        valid_tags = []
        for orig, mod, _term in tag_data:
            search = mod.replace(' ', '_')
            if mod in special_tags or orig in MODEL_SPECIFIC_TAGS:
                continue
            if (not data_cache
                    or is_artist(search.lower().removeprefix('by_'))
                    or search in data_cache['tag2count']
                    or search in data_cache['tag2aliases']):
                valid_tags.append(orig)
                continue
            if model is None:
                if context_weight < 0.5:
                    valid_tags.append(orig)
                continue
            best = _best_replacement(model, search, data_cache, suggested_context, context_weight)
            if best is not None:
                valid_tags.append(best)
        return ", ".join(valid_tags)
    except Exception:
        logging.getLogger(__name__).warning("Tag refinement failed; returning input unchanged", exc_info=True)
        return tags_string