mastermap-cleaner / src /data_pipeline.py
andrewbejjani's picture
Added functional doc in README.md and added basic
c6a3f44
import pandas as pd
import torch
import re
from collections import Counter
from sentence_transformers import util
from tqdm import tqdm
from src.utils import (
clean_degree_text,
normalize_text,
strip_degrees_for_search,
smart_format
)
from src.config import TOP_K_CANDIDATES, DEFAULT_SIMILARITY_THRESHOLD
def self_cluster_degrees(raw_degrees_list, model, school_cache, threshold=0.93):
    """Cluster similar degree labels inside one institution.

    Args:
        raw_degrees_list: Raw degree labels for a single institution.
            Non-string entries are ignored.
        model: Encoder exposing ``encode(texts, convert_to_tensor=True)``.
        school_cache: Mutable dict of earlier merge decisions, keyed by
            ``"<a>|||<b>"`` with the pair sorted. Updated in place.
        threshold: Minimum cosine similarity for two cleaned labels to merge.

    Returns:
        Dict mapping every raw string label to a
        ``(final_label, match_source, confidence)`` tuple.
    """
    # Clean each distinct raw label once, but keep one cleaned entry per
    # occurrence so the frequency ranking below still reflects the data.
    raw_to_clean = {}
    cleaned_list = []
    for raw in raw_degrees_list:
        if not isinstance(raw, str):
            continue
        if raw not in raw_to_clean:
            raw_to_clean[raw] = clean_degree_text(raw)
        cleaned_list.append(raw_to_clean[raw])

    # Most frequent labels come first so they become the cluster heads.
    unique_cleans = [deg for deg, _ in Counter(cleaned_list).most_common() if deg]

    raw_to_meta = {}
    if len(unique_cleans) <= 1:
        # Nothing to compare against: every label simply keeps its cleaned form.
        for raw, clean in raw_to_clean.items():
            raw_to_meta[raw] = (clean, "Degree_Formatter", "HIGH")
        return raw_to_meta

    embeddings = model.encode(unique_cleans, convert_to_tensor=True)
    clean_to_clustered = {}
    merge_info = {}  # Track similarity scores for Blueprint transparency.

    for i, current_deg in enumerate(unique_cleans):
        if current_deg in clean_to_clustered:
            continue
        clean_to_clustered[current_deg] = current_deg
        if i + 1 >= len(unique_cleans):
            continue

        cos_scores = util.cos_sim(embeddings[i], embeddings[i + 1:])[0]
        for j, score in enumerate(cos_scores):
            target_deg = unique_cleans[i + 1 + j]
            similarity = score.item()
            if similarity < threshold or target_deg in clean_to_clustered:
                continue

            pair_key = f"{min(current_deg, target_deg)}|||{max(current_deg, target_deg)}"
            # Runtime cache avoids repeated decisions within one run only.
            cached_action = school_cache.get(pair_key)
            if cached_action:
                # Record the note under the label that is merged away, so the
                # final lookup below can actually find it.
                if cached_action == current_deg:
                    clean_to_clustered[target_deg] = current_deg
                    merge_info[target_deg] = "Cached (Runtime)"
                elif cached_action == target_deg:
                    clean_to_clustered[current_deg] = target_deg
                    merge_info[current_deg] = "Cached (Runtime)"
            else:
                clean_to_clustered[target_deg] = current_deg
                school_cache[pair_key] = current_deg
                merge_info[target_deg] = f"{similarity * 100:.1f}%"

    for raw, clean in raw_to_clean.items():
        final_val = clean
        was_merged = False
        sim_str = ""
        # Follow merge chains to their head; keep the first note encountered.
        while final_val in clean_to_clustered and clean_to_clustered[final_val] != final_val:
            if not sim_str:
                sim_str = merge_info.get(final_val, "")
            final_val = clean_to_clustered[final_val]
            was_merged = True

        if was_merged:
            src = f"Auto-Merge ({sim_str})" if sim_str else "Auto-Merge"
            conf = "MEDIUM"
        else:
            src = "Degree_Formatter"
            conf = "HIGH"
        raw_to_meta[raw] = (final_val, src, conf)
    return raw_to_meta
def cluster_degrees_by_institution(df, degree_col, inst_col, model, master_cache, blueprint_data, threshold=0.93):
    """Cluster degree labels one institution at a time.

    Writes the result to ``Cleaned_<degree_col>`` and appends a review record
    to ``blueprint_data`` for every value that changed or is not HIGH
    confidence. Returns the same dataframe.
    """
    print(f"\n[INFO] Auto-Clustering '{degree_col}'. (Merges will be logged to Blueprint...)")

    cleaned_col_name = f'Cleaned_{degree_col}'
    df[cleaned_col_name] = df[degree_col].copy()

    decisions_by_school = master_cache.setdefault("Degree_Decisions", {})
    school_mappings = {}

    # First pass: compute every institution's mapping before touching rows.
    for school in tqdm(df[inst_col].dropna().unique(), desc=f"Mapping {degree_col}s by Institution"):
        rows_for_school = (df[inst_col] == school) & (df[degree_col].notna())
        raw_degs = df.loc[rows_for_school, degree_col].astype(str).tolist()
        if raw_degs:
            school_cache = decisions_by_school.setdefault(school, {})
            school_mappings[school] = self_cluster_degrees(raw_degs, model, school_cache, threshold)

    # Second pass: write mapped values and log only what a reviewer must see.
    for idx, row in tqdm(df.iterrows(), total=len(df), desc=f"Applying & Logging {degree_col}s"):
        raw_cell = row[degree_col]
        school = row[inst_col]
        if pd.isna(raw_cell) or school not in school_mappings:
            continue

        raw_deg = str(raw_cell)
        mapping_data = school_mappings[school].get(raw_deg)
        if not mapping_data:
            continue

        final_val, src, conf = mapping_data
        df.at[idx, cleaned_col_name] = final_val

        text_unchanged = raw_deg.strip() == final_val.strip()
        if text_unchanged and conf == "HIGH":
            continue
        blueprint_data.append({
            "Row_Index": idx + 3,
            "Column": degree_col,
            "Original_Raw_Text": raw_deg,
            "AI_Suggested_Match": final_val,
            "Human_Override": "",
            "Confidence": conf,
            "Match_Source": src
        })
    return df
def get_deterministic_match(value, combined_valid_targets):
    """Match obvious aliases/acronyms without calling embeddings or Groq.

    Two passes are made over ``combined_valid_targets``:

    1. Prefix pass: the normalised value must start the normalised target and
       be followed by a word boundary or one of whitespace, ``(``, ``/``,
       ``\\`` or ``-``.
    2. Acronym pass: the upper-cased value must appear in parentheses inside
       the upper-cased normalised target.

    Args:
        value: Raw text to resolve.
        combined_valid_targets: Re-iterable collection of reference strings.

    Returns:
        The first matching target (original spelling), or ``None``. A value
        that normalises to an empty string never matches.
    """
    val_clean = normalize_text(value)
    if not val_clean:
        # An empty needle turns the prefix pattern into "^(\b|...)", which
        # matches nearly every target and would return an arbitrary reference.
        return None

    prefix_pattern = re.compile(rf"^{re.escape(val_clean)}(\b|[\s\(\/\\\-])")
    normalized_targets = []
    for target in combined_valid_targets:
        target_clean = normalize_text(target)
        if prefix_pattern.search(target_clean):
            return target
        # Remember the normalised form so the acronym pass need not redo it.
        normalized_targets.append((target, target_clean))

    acronym = f"({val_clean.upper()})"
    for target, target_clean in normalized_targets:
        if acronym in target_clean.upper():
            return target
    return None
def get_top_candidates(model, value, combined_valid_targets, reference_embeddings, k=5):
    """Return the nearest reference candidates for one raw value.

    Args:
        model: Encoder exposing ``encode(text, convert_to_tensor=True)``.
        value: Raw text to embed and compare.
        combined_valid_targets: Reference strings, index-aligned with
            ``reference_embeddings``.
        reference_embeddings: Tensor of reference embeddings, or ``None`` when
            no references were encoded.
        k: Maximum number of candidates to return.

    Returns:
        Up to ``k`` reference strings ordered by descending cosine similarity.
        Empty when there are no targets, no embeddings, or ``k <= 0``.
    """
    if not combined_valid_targets or reference_embeddings is None or k <= 0:
        return []

    query_embedding = model.encode(value, convert_to_tensor=True)
    similarities = util.cos_sim(query_embedding, reference_embeddings)[0]

    # Never ask topk for more entries than exist on either side.
    actual_k = min(k, len(combined_valid_targets), similarities.shape[0])
    if actual_k <= 0:
        return []

    top_matches = torch.topk(similarities, actual_k)
    # Convert to plain ints before indexing the Python list.
    return [combined_valid_targets[idx] for idx in top_matches.indices.tolist()]
def get_dict_exact_match(value, combined_dict):
    """Return the canonical value whose alias or own text equals ``value``.

    Comparison is done on normalised text. Alias keys are checked first,
    canonical values second; ``None`` means nothing matched.
    """
    wanted = normalize_text(value)

    # Sentinel distinguishes "no alias matched" from a falsy canonical value.
    no_hit = object()
    via_alias = next(
        (canonical for alias, canonical in combined_dict.items() if normalize_text(alias) == wanted),
        no_hit,
    )
    if via_alias is not no_hit:
        return via_alias

    return next(
        (canonical for canonical in combined_dict.values() if normalize_text(canonical) == wanted),
        None,
    )
def get_dict_rule_match(value, combined_dict):
    """Apply the deterministic rules to an alias dictionary.

    Aliases are tried first and translated to their canonical value; if no
    alias matches, the distinct canonical values are tried directly.
    """
    alias_names = [alias for alias in combined_dict]
    distinct_canonicals = list(dict.fromkeys(combined_dict.values()))

    matched_alias = get_deterministic_match(value, alias_names)
    if matched_alias:
        return combined_dict[matched_alias]

    matched_canonical = get_deterministic_match(value, distinct_canonicals)
    return matched_canonical if matched_canonical else None
def as_reference_list(ref_data):
    """Return reference data as a list of display values.

    A list is handed back as-is (same object). A dict contributes its values,
    de-duplicated in first-seen order. Any other input yields an empty list.
    """
    if isinstance(ref_data, dict):
        first_seen = {}
        for canonical in ref_data.values():
            first_seen.setdefault(canonical, None)
        return list(first_seen)
    return ref_data if isinstance(ref_data, list) else []
def as_reference_dict(ref_data):
    """Return reference data as an alias-to-canonical mapping.

    A dict is handed back as-is (same object). A list becomes an identity
    mapping over its string items. Any other input yields an empty dict.
    """
    if isinstance(ref_data, list):
        identity_map = {}
        for entry in ref_data:
            if isinstance(entry, str):
                identity_map[entry] = entry
        return identity_map
    return ref_data if isinstance(ref_data, dict) else {}
def update_match_postfix(progress, source_counts):
    """Show the main match-source tallies on the progress bar.

    Only the three headline sources are displayed, and the bar is not
    force-refreshed, so no per-row output is produced.
    """
    shown_sources = ("Exact_Match", "Rule_Match", "LLM_Judged")
    tallies = {source: source_counts[source] for source in shown_sources}
    progress.set_postfix(tallies, refresh=False)
def match_cache_key(column_name, value):
    """Build the cache key shared by the matching and reconstruction steps.

    For the "Institution" and "Degree" columns the value first goes through
    ``strip_degrees_for_search``. Trailing dots are then removed and the
    result is normalised.
    """
    needs_degree_strip = column_name in ("Institution", "Degree")
    searchable = strip_degrees_for_search(value) if needs_degree_strip else value
    without_trailing_dots = str(searchable).rstrip(".")
    return normalize_text(without_trailing_dots)
def append_unique_cleaned_part(cleaned_parts, value):
    """Add the comma-separated pieces of ``value`` that are not present yet.

    Pieces are compared on their normalised text against everything already
    in ``cleaned_parts`` (which is itself split on commas). New pieces are
    appended in order of appearance. Returns ``True`` if anything was added.
    """
    known_keys = set()
    for entry in cleaned_parts:
        fragment_keys = (normalize_text(fragment.strip()) for fragment in str(entry).split(","))
        known_keys.update(fragment_key for fragment_key in fragment_keys if fragment_key)

    size_before = len(cleaned_parts)
    for fragment in str(value).split(","):
        candidate = fragment.strip()
        if candidate:
            candidate_key = normalize_text(candidate)
            if candidate_key not in known_keys:
                known_keys.add(candidate_key)
                cleaned_parts.append(candidate)
    return len(cleaned_parts) > size_before
# Severity order for confidence labels; a larger rank needs more review.
_CONFIDENCE_RANK = {"HIGH": 0, "MEDIUM": 1, "LOW": 2}


def _worse_confidence(current, candidate):
    """Return whichever of two confidence labels is less trustworthy.

    Labels outside HIGH/MEDIUM/LOW rank worst, so an unexpected label is
    never hidden by a later, better-looking part.
    """
    unknown_rank = len(_CONFIDENCE_RANK)
    if _CONFIDENCE_RANK.get(candidate, unknown_rank) > _CONFIDENCE_RANK.get(current, unknown_rank):
        return candidate
    return current


def process_column(df, column_name, model, groq_router, official_refs, manual_refs, master_cache, split_pattern, blueprint_data):
    """Clean one dataframe column using refs, embeddings, then Groq fallback.

    Every cell is split with ``split_pattern``. Each unique part is resolved
    once (memory cache, exact match, rule match, then the Groq judge on the
    nearest reference candidates). Cells are then rebuilt from those
    decisions and written to ``Cleaned_<column_name>``.

    Args:
        df: DataFrame to clean; mutated in place and also returned.
        column_name: Column to clean. If absent, ``df`` is returned untouched.
        model: Encoder exposing ``encode(..., convert_to_tensor=True)``.
        groq_router: Object exposing ``ask_judge(word, candidates,
            column_name)`` that returns ``(value, source, confidence)``.
        official_refs: Mapping of column name to a reference list or an
            alias-to-canonical dict. A dict selects dictionary mode.
        manual_refs: Extra references for the same columns; in dictionary
            mode its entries override ``official_refs``.
        master_cache: Mapping of column name to ``{cache_key: value}``; an
            empty slot is created for ``column_name`` when missing.
        split_pattern: Regex string used to split multi-valued cells
            (applied case-insensitively).
        blueprint_data: List that receives one review record per changed or
            non-HIGH-confidence cell.

    Returns:
        The same ``df`` object.

    Raises:
        ValueError: If no references are loaded for ``column_name``.
    """
    if column_name not in df.columns:
        return df

    core_data = official_refs.get(column_name, [])
    added_data = manual_refs.get(column_name, [])
    if column_name not in master_cache:
        master_cache[column_name] = {}
    detailed_cache = {}
    is_dict_mode = isinstance(core_data, dict)

    def split_parts(text):
        """Split raw text into stripped, non-empty parts.

        One helper keeps the split flags identical in the analysis and the
        reconstruction phases, so rebuilt parts always map to analysed ones.
        """
        pieces = re.split(split_pattern, str(text), flags=re.IGNORECASE)
        return [piece.strip() for piece in pieces if piece.strip()]

    def get_updated_embeddings():
        """Build current reference candidates after manual memory is loaded."""
        if is_dict_mode:
            c_dict = {**as_reference_dict(core_data), **as_reference_dict(added_data)}
            c_keys = list(c_dict.keys())
            # Order-preserving de-duplication keeps candidate order stable
            # between runs (a set would not).
            u_vals = list(dict.fromkeys(c_dict.values()))
            k_emb = model.encode(c_keys, convert_to_tensor=True) if c_keys else None
            v_emb = model.encode(u_vals, convert_to_tensor=True) if u_vals else None
            return c_dict, c_keys, k_emb, u_vals, v_emb
        comb = as_reference_list(core_data) + as_reference_list(added_data)
        comb = list(dict.fromkeys(item for item in comb if isinstance(item, str) and item.strip()))
        emb = model.encode(comb, convert_to_tensor=True) if comb else None
        return None, comb, emb, None, None

    combined_dict, combined_valid_targets, reference_embeddings, unique_values, value_embeddings = get_updated_embeddings()

    if is_dict_mode and not combined_dict:
        raise ValueError(f"No dictionary references loaded for '{column_name}'. Refusing to call Groq for every value.")
    if not is_dict_mode and not combined_valid_targets:
        raise ValueError(f"No list references loaded for '{column_name}'. Refusing to call Groq for every value.")

    # List mode: normalise every reference once; the first spelling wins,
    # which matches a linear first-hit scan.
    exact_lookup = {}
    if not is_dict_mode:
        for target in combined_valid_targets:
            exact_lookup.setdefault(normalize_text(target), target)

    # Work on unique split values first so repeated cells reuse one decision.
    uniques = set()
    for cell in df[column_name].dropna():
        uniques.update(split_parts(cell))

    print(f"\n[INFO] Analyzing {len(uniques)} unique entities in '{column_name}'...")

    source_counts = Counter()
    # The raw word breaks ties so processing order is stable between runs.
    ordered_words = sorted(uniques, key=lambda raw_word: (normalize_text(raw_word), raw_word))
    progress = tqdm(ordered_words, desc=f"Cleaning {column_name}")
    for word in progress:
        word_clean = match_cache_key(column_name, word)

        # Fast path: reuse a decision made earlier in this run.
        if word_clean in master_cache[column_name]:
            detailed_cache[word_clean] = {"val": master_cache[column_name][word_clean], "src": "Memory_Cache", "conf": "HIGH"}
            source_counts["Memory_Cache"] += 1
            update_match_postfix(progress, source_counts)
            continue

        # Exact/rule matches are trusted and avoid LLM calls.
        if is_dict_mode:
            exact = get_dict_exact_match(word, combined_dict)
        else:
            exact = exact_lookup.get(normalize_text(word_clean))
        if exact:
            detailed_cache[word_clean] = {"val": exact, "src": "Exact_Match", "conf": "HIGH"}
            source_counts["Exact_Match"] += 1
            update_match_postfix(progress, source_counts)
            continue

        if is_dict_mode:
            suggested_match = get_dict_rule_match(word, combined_dict)
        else:
            suggested_match = get_deterministic_match(word, combined_valid_targets) if combined_valid_targets else None
        if suggested_match:
            detailed_cache[word_clean] = {"val": suggested_match, "src": "Rule_Match", "conf": "HIGH"}
            source_counts["Rule_Match"] += 1
            update_match_postfix(progress, source_counts)
            continue

        # Last resort: send only the top reference candidates to Groq.
        if is_dict_mode:
            cand_keys = get_top_candidates(model, word, combined_valid_targets, reference_embeddings)
            cand_vals = get_top_candidates(model, word, unique_values, value_embeddings)
            candidates = list(dict.fromkeys(cand_keys + cand_vals))[:TOP_K_CANDIDATES]
        else:
            candidates = get_top_candidates(model, word, combined_valid_targets, reference_embeddings)

        ans_val, src, conf = groq_router.ask_judge(word, candidates, column_name)
        source_counts[src] += 1
        update_match_postfix(progress, source_counts)

        # Re-check Groq output against refs so canonical casing/names are preserved.
        if "API_Error" not in conf and ans_val != "UNKNOWN" and ans_val != "LLM_Failed":
            # Drop empty fragments (e.g. a trailing comma) so they are never
            # handed to the rule matcher as if they were real answers.
            llm_parts = [p.strip() for p in ans_val.split(",") if p.strip()]
            corrected_parts = []
            all_matched = True  # Tracks whether every piece exists in our data.

            for part in llm_parts:
                if is_dict_mode:
                    verified = get_dict_exact_match(part, combined_dict)
                    if not verified:
                        verified = get_dict_rule_match(part, combined_dict)
                else:
                    verified = next((c for c in candidates if c.lower() == part.lower()), None)
                    if not verified:
                        verified = get_deterministic_match(part, candidates)

                if verified:
                    corrected_parts.append(verified)
                else:
                    # Keep unverifiable LLM text, but do not upgrade confidence.
                    corrected_parts.append(part)
                    all_matched = False

            unique_parts = list(dict.fromkeys(corrected_parts))
            ans_val = ", ".join(unique_parts)

            raw_parts_for_check = split_parts(word) or [word]
            raw_lookup_keys = {normalize_text(part) for part in raw_parts_for_check}
            verified_lookup_keys = {normalize_text(part) for part in unique_parts}

            # Only upgrade when the LLM answer was verified against the refs and
            # also matches the original text directly. Otherwise it is still an
            # LLM judgment, even if the chosen answer exists in the references.
            # An empty answer is never treated as verified.
            if unique_parts and all_matched and verified_lookup_keys.issubset(raw_lookup_keys):
                conf = "HIGH"
                src = "LLM_Rule_Verified"

        detailed_cache[word_clean] = {"val": ans_val, "src": src, "conf": conf}

    # Reconstruct full cell values in original row order for workbook injection.
    for idx, row in df.iterrows():
        cell_val = row[column_name]
        if pd.isna(cell_val):
            continue

        raw_parts = split_parts(cell_val)
        cleaned_parts = []
        lowest_conf = "HIGH"
        cell_sources = []

        i = 0
        while i < len(raw_parts):
            curr = raw_parts[i]

            # Recover obvious accidental splits such as "University of, Manchester".
            if i + 1 < len(raw_parts):
                combo_clean = match_cache_key(column_name, f"{curr}, {raw_parts[i + 1]}")
                if combo_clean in detailed_cache:
                    combo_hit = detailed_cache[combo_clean]
                    append_unique_cleaned_part(cleaned_parts, combo_hit["val"])
                    lowest_conf = _worse_confidence(lowest_conf, combo_hit["conf"])
                    cell_sources.append(combo_hit["src"])
                    i += 2
                    continue

            p_clean = match_cache_key(column_name, curr)
            if len(p_clean) <= 1:
                # Keys of zero or one character are skipped entirely.
                i += 1
                continue

            cache_hit = detailed_cache.get(p_clean)
            if cache_hit:
                append_unique_cleaned_part(cleaned_parts, cache_hit["val"])
                # Keep the worst label seen; a later MEDIUM must not mask a LOW.
                lowest_conf = _worse_confidence(lowest_conf, cache_hit["conf"])
                cell_sources.append(cache_hit["src"])
            else:
                append_unique_cleaned_part(cleaned_parts, smart_format(curr))
                lowest_conf = _worse_confidence(lowest_conf, "LOW")
                cell_sources.append("Fallback_Format")
            i += 1

        final_stitched_val = ", ".join(cleaned_parts)
        df.at[idx, f'Cleaned_{column_name}'] = final_stitched_val

        # Review every changed cell and every low/medium-confidence result.
        if str(cell_val).strip() != final_stitched_val.strip() or lowest_conf != "HIGH":
            blueprint_data.append({
                # NOTE(review): the +3 offset presumably maps the dataframe
                # index to a workbook row number — confirm against the writer.
                "Row_Index": idx + 3,
                "Column": column_name,
                "Original_Raw_Text": cell_val,
                "AI_Suggested_Match": final_stitched_val,
                "Human_Override": "",
                # First-seen order keeps the source list stable between runs.
                "Match_Source": " | ".join(dict.fromkeys(cell_sources)),
                "Confidence": lowest_conf
            })

    return df